def test_accuracy_computation(self):
    accuracy = BooleanAccuracy()
    predictions = torch.Tensor([[0, 1],
                                [2, 3],
                                [4, 5],
                                [6, 7]])
    targets = torch.Tensor([[0, 1],
                            [2, 2],
                            [4, 5],
                            [7, 7]])
    # Rows 0 and 2 match exactly; rows 1 and 3 each differ in one element.
    accuracy(predictions, targets)
    assert accuracy.get_metric() == 2. / 4

    # Masking out the mismatched element of row 1 makes that row count as
    # correct, so this call adds 3 correct rows (cumulative: 5 of 8).
    mask = torch.ones(4, 2)
    mask[1, 1] = 0
    accuracy(predictions, targets, mask)
    assert accuracy.get_metric() == 5. / 8

    # With the target fixed, row 1 matches even without a mask (8 of 12).
    targets[1, 1] = 3
    accuracy(predictions, targets)
    assert accuracy.get_metric() == 8. / 12

    # reset() clears the accumulated counts.
    accuracy.reset()
    accuracy(predictions, targets)
    assert accuracy.get_metric() == 3. / 4
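
These assertions pin down the metric's semantics: a row is correct only when every unmasked element matches exactly, and counts accumulate across calls until reset(). A minimal sketch of that logic (illustrative only, not the actual AllenNLP implementation):

import torch

class SimpleBooleanAccuracy:
    """A row is correct only if every unmasked element matches exactly."""

    def __init__(self) -> None:
        self._correct_count = 0.0
        self._total_count = 0.0

    def __call__(self, predictions, gold_labels, mask=None):
        if mask is not None:
            # Masked positions are zeroed on both sides, so they always
            # compare equal and cannot break a match.
            predictions = predictions * mask
            gold_labels = gold_labels * mask
        # A row counts as correct only when all of its elements agree.
        correct = predictions.eq(gold_labels).all(dim=-1)
        self._correct_count += correct.sum().item()
        self._total_count += predictions.size(0)

    def reset(self) -> None:
        self._correct_count = 0.0
        self._total_count = 0.0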
Example #2
# The start/end accuracies are computed over position logits with
# CategoricalAccuracy (as in the DialogQA example below); BooleanAccuracy
# checks that both span endpoints are correct at once.
span_start_accuracy_function = CategoricalAccuracy()
span_end_accuracy_function = CategoricalAccuracy()
span_accuracy_function = BooleanAccuracy()
squad_metrics_function = SquadEmAndF1()

# Compute the loss for training.
if span_start is not None:
    span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
    span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
    loss = span_start_loss + span_end_loss
    
    span_start_accuracy_function(span_start_logits, span_start.squeeze(-1))
    span_end_accuracy_function(span_end_logits, span_end.squeeze(-1))
    span_accuracy_function(best_span, torch.stack([span_start, span_end], -1))

    span_start_accuracy = span_start_accuracy_function.get_metric()
    span_end_accuracy = span_end_accuracy_function.get_metric()
    span_accuracy = span_accuracy_function.get_metric()

    print("Loss: ", loss)
    print("span_start_accuracy: ", span_start_accuracy)
    print("span_end_accuracy: ", span_end_accuracy)
    print("span_accuracy: ", span_accuracy)
    
# Compute the EM and F1 on SQuAD and add the tokenized input to the output.
if metadata is not None:
    best_span_str = []
    question_tokens = []
    passage_tokens = []
    for i in range(batch_size):
        question_tokens.append(metadata[i]['question_tokens'])
        passage_tokens.append(metadata[i]['passage_tokens'])
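        # -- Hedged sketch of the continuation; the excerpt is cut off here. --
        # Recover the span string via token offsets, mirroring the offset lookup
        # in the DialogQA example below.  The 'original_passage', 'token_offsets',
        # and 'answer_texts' metadata keys are assumptions, as is feeding
        # SquadEmAndF1 the recovered string directly.
        passage_str = metadata[i]['original_passage']
        offsets = metadata[i]['token_offsets']
        predicted_span = tuple(best_span[i].detach().cpu().numpy())
        start_offset = offsets[predicted_span[0]][0]
        end_offset = offsets[predicted_span[1]][1]
        best_span_str.append(passage_str[start_offset:end_offset])
        answer_texts = metadata[i].get('answer_texts', [])
        if answer_texts:
            squad_metrics_function(best_span_str[-1], answer_texts)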
Example #3
def test_does_not_divide_by_zero_with_no_count(self, device: str):
    accuracy = BooleanAccuracy()
    assert accuracy.get_metric() == pytest.approx(0.0)
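
The metric presumably guards its division; a plausible sketch of get_metric that makes this test pass (illustrative, continuing the sketch after Example #1):

def get_metric(self, reset: bool = False) -> float:
    # Guard against dividing by zero when no predictions have been observed.
    if self._total_count > 0:
        accuracy = float(self._correct_count) / float(self._total_count)
    else:
        accuracy = 0.0
    if reset:
        self.reset()
    return accuracy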
Example #4
class DialogQA(Model):
    """
    This class implements a modified version of the BiDAF model (with self attention
    and a residual layer, from Clark and Gardner, ACL 2017) as used in the
    Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog: a list of question-answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    residual_encoder : ``Seq2SeqEncoder``
        The encoder used in the residual self-attention block that refines the
        question-aware passage representation.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    initializer : ``InitializerApplicator``, optional
        If provided, will be used to initialize the model parameters.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    marker_embedding_dim : ``int``, optional (default=10)
        The dimension of the embeddings used to mark previous answers in the passage.
    max_span_length : ``int``, optional (default=30)
        Maximum token length of the output span.
    max_turn_length : ``int``, optional (default=12)
        Maximum length of an interaction.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        phrase_layer: Seq2SeqEncoder,
        residual_encoder: Seq2SeqEncoder,
        span_start_encoder: Seq2SeqEncoder,
        span_end_encoder: Seq2SeqEncoder,
        initializer: Optional[InitializerApplicator] = None,
        dropout: float = 0.2,
        num_context_answers: int = 0,
        marker_embedding_dim: int = 10,
        max_span_length: int = 30,
        max_turn_length: int = 12,
    ) -> None:
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim,
                                                       self._encoding_dim,
                                                       "x,y,x*y")
        self._merge_atten = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(
                max_turn_length, marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding(
                (num_context_answers * 4) + 1, marker_embedding_dim)

        self._self_attention = LinearMatrixAttention(self._encoding_dim,
                                                     self._encoding_dim,
                                                     "x,y,x*y")

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(
            torch.nn.Linear(self._encoding_dim * 3, self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(
            torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        check_dimensions_match(
            phrase_layer.get_input_dim(),
            text_field_embedder.get_output_dim() +
            marker_embedding_dim * num_context_answers,
            "phrase layer input dim",
            "embedding dim + marker dim * num context answers",
        )

        if initializer is not None:
            initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        p1_answer_marker: torch.IntTensor = None,
        p2_answer_marker: torch.IntTensor = None,
        p3_answer_marker: torch.IntTensor = None,
        yesno_list: torch.IntTensor = None,
        followup_list: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            It is a tensor of shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens are assigned 'O', except the passage tokens that belong to the
            previous answer in the dialog, which are assigned labels such as <1_start>, <1_in>,
            <1_end>.  For more details, see dataset_readers/util/make_reading_comprehension_instance_quac.
        p2_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two turns back.
        p3_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three turns back.
        yesno_list : ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (yes / no / not a yes-no question).
        followup_list : ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (follow up / maybe follow up / don't follow up).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the entries below.  Each entry is a nested list:
        the outer list iterates over dialogs, the inner list over the questions in each dialog.

        qid : List[List[str]]
            A list of lists of question ids.
        followup : List[List[int]]
            A list of lists of continuation marker prediction indices.
            (y: yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of lists of affirmation marker prediction indices.
            (y: yes, x: not a yes/no question, n: no)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        token_character_ids = question["token_characters"]["token_characters"]
        batch_size, max_qa_count, max_q_len, _ = token_character_ids.size()
        total_qa_count = batch_size * max_qa_count
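        # Padded (non-existent) QA pairs are marked with -1 in followup_list,
        # so `followup_list >= 0` selects exactly the real question-answer pairs.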
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question,
                                                      num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
            self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(
            self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1)
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage)

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(
                max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(
                question_num_ind)
            embedded_question = torch.cat(
                [embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = (embedded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1,
                1).view(total_qa_count, passage_length,
                        self._text_field_embedder.get_output_dim()))
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count,
                                                     passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat(
                [repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(
                    total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat(
                    [repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(
                        total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(
                        p3_answer_marker)
                    repeated_embedded_passage = torch.cat(
                        [repeated_embedded_passage, p3_answer_marker_emb],
                        dim=-1)

            repeated_encoded_passage = self._variational_dropout(
                self._phrase_layer(repeated_embedded_passage,
                                   repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(
                self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(
                total_qa_count, passage_length, self._encoding_dim)

        encoded_question = self._variational_dropout(
            self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(
            repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(total_qa_count, passage_length, self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                repeated_encoded_passage,
                passage_question_vectors,
                repeated_encoded_passage * passage_question_vectors,
                repeated_encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1,
        )

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(
            self._residual_encoder(final_merged_passage,
                                   repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer,
                                                     residual_layer)

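        # Build a pairwise (total_qa_count, passage_length, passage_length) mask
        # from the passage mask, and knock out the diagonal so a token cannot
        # attend to itself in the self-attention layer.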
        mask = repeated_passage_mask.reshape(
            total_qa_count, passage_length, 1) * repeated_passage_mask.reshape(
                total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              dtype=torch.bool,
                              device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask & ~self_mask

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs,
                                           residual_layer)
        self_attention_vecs = torch.cat([
            self_attention_vecs, residual_layer,
            residual_layer * self_attention_vecs
        ],
                                        dim=-1)
        residual_layer = F.relu(
            self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # Shape: (batch_size * max_qa_count, max_passage_len, encoding_dim)
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage,
                                             repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(
            torch.cat([final_merged_passage, start_rep], dim=-1),
            repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(
            -1)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        # Shape: (batch_size * max_qa_count, max_passage_len)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(
            span_start_logits,
            span_end_logits,
            span_yesno_logits,
            span_followup_logits,
            self._max_span_length,
        )
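        # Each row of best_span is [span_start, span_end, yesno, followup],
        # as produced by _get_best_span_yesno_followup below.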

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits,
                                        repeated_passage_mask),
                span_start.view(-1),
                ignore_index=-1,
            )
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits,
                                        repeated_passage_mask),
                span_end.view(-1),
                ignore_index=-1,
            )
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(
                best_span[:, 0:2],
                torch.stack([span_start, span_end],
                            -1).view(total_qa_count, 2),
                mask=qa_mask.unsqueeze(1).expand(-1, 2),
            )
            # Select the logits at each gold span-end token to compute the yes/no
            # and followup losses.  span_yesno_logits has shape
            # (total_qa_count, passage_length, 3); once flattened, the three logits
            # for QA pair i at token j sit at i * passage_length * 3 + j * 3 + {0, 1, 2}.
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1),
                             followup_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)
            self._span_followup_accuracy(_followup,
                                         followup_list.view(-1),
                                         mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and prepare the output dictionary.
        output_dict["best_span_str"] = []
        output_dict["qid"] = []
        output_dict["followup"] = []
        output_dict["yesno"] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]["original_passage"]
            offsets = metadata[i]["token_offsets"]
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over the other N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad.metric_max_over_ground_truths(
                                    squad.f1_score, best_span_string, refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad.metric_max_over_ground_truths(
                            squad.f1_score, best_span_string, answer_texts)
                self._official_f1(100 * f1_score)
            output_dict["qid"].append(per_dialog_query_id_list)
            output_dict["best_span_str"].append(per_dialog_best_span_list)
            output_dict["yesno"].append(per_dialog_yesno_list)
            output_dict["followup"].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        yesno_tags = [[
            self.vocab.get_token_from_index(x, namespace="yesno_labels")
            for x in yn_list
        ] for yn_list in output_dict.pop("yesno")]
        followup_tags = [[
            self.vocab.get_token_from_index(x, namespace="followup_labels")
            for x in followup_list
        ] for followup_list in output_dict.pop("followup")]
        output_dict["yesno"] = yesno_tags
        output_dict["followup"] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "yesno": self._span_yesno_accuracy.get_metric(reset),
            "followup": self._span_followup_accuracy.get_metric(reset),
            "f1": self._official_f1.get_metric(reset),
        }

    @staticmethod
    def _get_best_span_yesno_followup(
        span_start_logits: torch.Tensor,
        span_end_logits: torch.Tensor,
        span_yesno_logits: torch.Tensor,
        span_followup_logits: torch.Tensor,
        max_span_length: int,
    ) -> torch.Tensor:
        # Returns the index of the highest-scoring span no longer than
        # `max_span_length` tokens, together with the yes/no and followup
        # predictions taken at the predicted span-end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()
        for b_i in range(batch_size):
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
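
A quick shape-and-invariant smoke test for the decoder above (illustrative sizes; assumes the module-level numpy/torch imports are in place and DialogQA is in scope):

import torch

batch_size, passage_length = 2, 10
start_logits = torch.randn(batch_size, passage_length)
end_logits = torch.randn(batch_size, passage_length)
yesno_logits = torch.randn(batch_size, passage_length, 3)
followup_logits = torch.randn(batch_size, passage_length, 3)

best = DialogQA._get_best_span_yesno_followup(
    start_logits, end_logits, yesno_logits, followup_logits, max_span_length=5)
# One row per instance: [span_start, span_end, yesno_index, followup_index].
assert best.shape == (batch_size, 4)
# The decoder never returns a span that ends before it starts.
assert (best[:, 0] <= best[:, 1]).all()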
Example #5
class BertQA(Model):
    """
    This class implements a BERT-based reading comprehension model.  A BERT encoder
    produces a pooled representation, used to classify the dialog action, and per-token
    passage representations, used to predict an answer span.  Optionally, a separate
    Scenario Interpretation Module (SIM) encodes the scenario and feeds its token-level
    predictions back into the main encoder's input.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the main BERT input.
    sim_text_field_embedder : ``TextFieldEmbedder``
        Used to embed the input to the Scenario Interpretation Module.
    loss_weights : ``Dict``
        Weights for combining the span loss and the action loss.
    sim_class_weights : ``List``
        Per-class weights for the Scenario Interpretation Module's token-label loss.
    pretrained_sim_path : ``str``, optional (default=``None``)
        If provided, pretrained weights are loaded from this path and the Scenario
        Interpretation Module's embedder is frozen.
    use_scenario_encoding : ``bool``, optional (default=``True``)
        If ``True``, the Scenario Interpretation Module's predictions are fed to the main model.
    sim_pretraining : ``bool``, optional (default=``False``)
        If ``True``, only the Scenario Interpretation Module is trained.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after the encoder.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sim_text_field_embedder: TextFieldEmbedder,
                 loss_weights: Dict,
                 sim_class_weights: List,
                 pretrained_sim_path: str = None,
                 use_scenario_encoding: bool = True,
                 sim_pretraining: bool = False,
                 dropout: float = 0.2,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BertQA, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        if use_scenario_encoding:
            self._sim_text_field_embedder = sim_text_field_embedder
        self.loss_weights = loss_weights
        self.sim_class_weights = sim_class_weights
        self.use_scenario_encoding = use_scenario_encoding
        self.sim_pretraining = sim_pretraining

        if self.sim_pretraining and not self.use_scenario_encoding:
            raise ValueError(
                "When pretraining Scenario Interpretation Module, you should use it."
            )

        embedding_dim = self._text_field_embedder.get_output_dim()
        self._action_predictor = torch.nn.Linear(embedding_dim, 4)
        self._sim_token_label_predictor = torch.nn.Linear(embedding_dim, 4)
        self._span_predictor = torch.nn.Linear(embedding_dim, 2)
        self._action_accuracy = CategoricalAccuracy()
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_loss_metric = Average()
        self._action_loss_metric = Average()
        self._sim_loss_metric = Average()
        self._sim_yes_f1 = F1Measure(2)
        self._sim_no_f1 = F1Measure(3)

        if use_scenario_encoding and pretrained_sim_path is not None:
            logger.info("Loading pretrained model..")
            self.load_state_dict(torch.load(pretrained_sim_path))
            for param in self._sim_text_field_embedder.parameters():
                param.requires_grad = False

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        initializer(self)

    def get_passage_representation(self, bert_output, bert_input):
        # Shape: (batch_size, bert_input_len)
        input_type_ids = self.get_input_type_ids(
            bert_input['bert-type-ids'], bert_input['bert-offsets'],
            self._text_field_embedder._token_embedders['bert']).float()
        # Shape: (batch_size, bert_input_len)
        input_mask = util.get_text_field_mask(bert_input).float()
        passage_mask = input_mask - input_type_ids  # works only with one [SEP]
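        # input_mask is 1 for real tokens, and BERT type ids are 0 for the first
        # segment; the subtraction therefore keeps exactly the segment-0 tokens,
        # which this model treats as the passage.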
        # Shape: (batch_size, bert_input_len, embedding_dim)
        passage_representation = bert_output * passage_mask.unsqueeze(2)
        # Shape: (batch_size, passage_len, embedding_dim)
        passage_representation = passage_representation[:, passage_mask.sum(dim=0) > 0, :]
        # Shape: (batch_size, passage_len)
        passage_mask = passage_mask[:, passage_mask.sum(dim=0) > 0]

        return passage_representation, passage_mask

    def forward(
            self,  # type: ignore
            bert_input: Dict[str, torch.LongTensor],
            sim_bert_input: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        bert_input : Dict[str, torch.LongTensor]
            From a ``TextField``.  The main BERT input; the model predicts the beginning
            and ending positions of the answer span within its passage segment.
        sim_bert_input : Dict[str, torch.LongTensor]
            From a ``TextField``.  The input to the Scenario Interpretation Module.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.
        label : ``torch.LongTensor``, optional
            The gold dialog action for each instance.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """

        if self.use_scenario_encoding:
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_bert_input_token_labels_wp = sim_bert_input[
                'scenario_gold_encoding']
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_bert_output_wp = self._sim_text_field_embedder(sim_bert_input)
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_input_mask_wp = (sim_bert_input['bert'] != 0).float()
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_passage_mask_wp = sim_input_mask_wp - sim_bert_input[
                'bert-type-ids'].float()  # works only with one [SEP]
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_passage_representation_wp = (
                sim_bert_output_wp * sim_passage_mask_wp.unsqueeze(2))
            # Shape: (batch_size, passage_len_wp, embedding_dim)
            sim_passage_representation_wp = sim_passage_representation_wp[
                :, sim_passage_mask_wp.sum(dim=0) > 0, :]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_token_labels_wp = sim_bert_input_token_labels_wp[
                :, sim_passage_mask_wp.sum(dim=0) > 0]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_mask_wp = sim_passage_mask_wp[
                :, sim_passage_mask_wp.sum(dim=0) > 0]

            # Shape: (batch_size, passage_len_wp, 4)
            sim_token_logits_wp = self._sim_token_label_predictor(
                sim_passage_representation_wp)

            if span_start is not None:  # during training and validation
                class_weights = torch.tensor(self.sim_class_weights,
                                             device=sim_token_logits_wp.device,
                                             dtype=torch.float)
                sim_loss = cross_entropy(sim_token_logits_wp.view(-1, 4),
                                         sim_passage_token_labels_wp.view(-1),
                                         ignore_index=0,
                                         weight=class_weights)
                self._sim_loss_metric(sim_loss.item())
                self._sim_yes_f1(sim_token_logits_wp,
                                 sim_passage_token_labels_wp,
                                 sim_passage_mask_wp)
                self._sim_no_f1(sim_token_logits_wp,
                                sim_passage_token_labels_wp,
                                sim_passage_mask_wp)
                if self.sim_pretraining:
                    return {'loss': sim_loss}

            if not self.sim_pretraining:
                # Shape: (batch_size, passage_len_wp)
                bert_input['scenario_encoding'] = (sim_token_logits_wp.argmax(
                    dim=2)) * sim_passage_mask_wp.long()
                # Shape: (batch_size, bert_input_len_wp)
                bert_input_wp_len = bert_input['history_encoding'].size(1)
                if bert_input['scenario_encoding'].size(1) > bert_input_wp_len:
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = bert_input[
                        'scenario_encoding'][:, :bert_input_wp_len]
                else:
                    batch_size = bert_input['scenario_encoding'].size(0)
                    difference = bert_input_wp_len - bert_input[
                        'scenario_encoding'].size(1)
                    zeros = torch.zeros(
                        batch_size,
                        difference,
                        dtype=bert_input['scenario_encoding'].dtype,
                        device=bert_input['scenario_encoding'].device)
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = torch.cat(
                        [bert_input['scenario_encoding'], zeros], dim=1)

        # Shape: (batch_size, bert_input_len + 1, embedding_dim)
        bert_output = self._text_field_embedder(bert_input)
        # Shape: (batch_size, embedding_dim)
        pooled_output = bert_output[:, 0]
        # Shape: (batch_size, bert_input_len, embedding_dim)
        bert_output = bert_output[:, 1:, :]
        # Shape: (batch_size, passage_len, embedding_dim), (batch_size, passage_len)
        passage_representation, passage_mask = self.get_passage_representation(
            bert_output, bert_input)

        # Shape: (batch_size, 4)
        action_logits = self._action_predictor(pooled_output)
        # Shape: (batch_size, passage_len, 2)
        span_logits = self._span_predictor(passage_representation)
        # Shape: (batch_size, passage_len, 1), (batch_size, passage_len, 1)
        span_start_logits, span_end_logits = span_logits.split(1, dim=2)
        # Shape: (batch_size, passage_len)
        span_start_logits = span_start_logits.squeeze(2)
        # Shape: (batch_size, passage_len)
        span_end_logits = span_end_logits.squeeze(2)

        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)
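        # best_span holds inclusive [start, end] token indices, one row per instance.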

        output_dict = {
            "pooled_output": pooled_output,
            "passage_representation": passage_representation,
            "action_logits": action_logits,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if self.use_scenario_encoding:
            output_dict["sim_token_logits"] = sim_token_logits_wp

        # Compute the loss for training (and for validation)
        if span_start is not None:
            # Shape: (batch_size,)
            span_loss = nll_loss(util.masked_log_softmax(
                span_start_logits, passage_mask),
                                 span_start.squeeze(1),
                                 reduction='none')
            # Shape: (batch_size,)
            span_loss += nll_loss(util.masked_log_softmax(
                span_end_logits, passage_mask),
                                  span_end.squeeze(1),
                                  reduction='none')
            # Shape: (batch_size,)
            more_mask = (label == self.vocab.get_token_index(
                'More', namespace="labels")).float()
            # Average the span loss over 'More' instances only; the 1e-6 keeps the
            # division finite when a batch contains none of them.
            span_loss = (span_loss * more_mask).sum() / (more_mask.sum() + 1e-6)
            if more_mask.sum() > 1e-7:
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(1), more_mask)
                self._span_end_accuracy(span_end_logits, span_end.squeeze(1),
                                        more_mask)
                # Shape: (batch_size, 2)
                span_acc_mask = more_mask.unsqueeze(1).expand(-1, 2).long()
                self._span_accuracy(best_span,
                                    torch.cat([span_start, span_end], dim=1),
                                    span_acc_mask)

            action_loss = cross_entropy(action_logits, label)
            self._action_accuracy(action_logits, label)

            self._span_loss_metric(span_loss.item())
            self._action_loss_metric(action_loss.item())
            output_dict['loss'] = (
                self.loss_weights['span_loss'] * span_loss +
                self.loss_weights['action_loss'] * action_loss)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if not self.training:  # true during validation and test
            output_dict['best_span_str'] = []
            batch_size = len(metadata)
            for i in range(batch_size):
                passage_text = metadata[i]['passage_text']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_str = passage_text[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_str)
                if 'gold_span' in metadata[i]:
                    if metadata[i]['action'] == 'More':
                        gold_span = metadata[i]['gold_span']
                        self._squad_metrics(best_span_str, [gold_span])
        return output_dict

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        action_probs = softmax(output_dict['action_logits'], dim=1)
        output_dict['action_probs'] = action_probs

        predictions = action_probs.cpu().data.numpy()
        argmax_indices = numpy.argmax(predictions, axis=1)
        labels = [
            self.vocab.get_token_from_index(x, namespace="labels")
            for x in argmax_indices
        ]
        output_dict['label'] = labels
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        if self.use_scenario_encoding:
            sim_loss = self._sim_loss_metric.get_metric(reset)
            _, _, yes_f1 = self._sim_yes_f1.get_metric(reset)
            _, _, no_f1 = self._sim_no_f1.get_metric(reset)

        if self.sim_pretraining:
            return {'sim_macro_f1': (yes_f1 + no_f1) / 2}

        try:
            action_acc = self._action_accuracy.get_metric(reset)
        except ZeroDivisionError:
            action_acc = 0
        try:
            start_acc = self._span_start_accuracy.get_metric(reset)
        except ZeroDivisionError:
            start_acc = 0
        try:
            end_acc = self._span_end_accuracy.get_metric(reset)
        except ZeroDivisionError:
            end_acc = 0
        try:
            span_acc = self._span_accuracy.get_metric(reset)
        except ZeroDivisionError:
            span_acc = 0

        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        span_loss = self._span_loss_metric.get_metric(reset)
        action_loss = self._action_loss_metric.get_metric(reset)
        agg_metric = span_acc + action_acc * 0.45

        metrics = {
            'action_acc': action_acc,
            'span_acc': span_acc,
            'span_loss': span_loss,
            'action_loss': action_loss,
            'agg_metric': agg_metric
        }

        if self.use_scenario_encoding:
            metrics['sim_macro_f1'] = (yes_f1 + no_f1) / 2

        if not self.training:  # during validation
            metrics['em'] = exact_match
            metrics['f1'] = f1_score

        return metrics

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.
        span_log_mask = torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0)
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)

    def get_input_type_ids(self, type_ids, offsets, embedder):
        "Converts (bsz, seq_len_wp) to (bsz, seq_len_wp) by indexing."
        batch_size = type_ids.size(0)
        full_seq_len = type_ids.size(1)
        if full_seq_len > embedder.max_pieces:  # Recombine if we had used sliding window approach
            assert batch_size == 1 and type_ids.max() > 0
            num_question_tokens = type_ids[0][:embedder.max_pieces].nonzero(
            ).size(0)
            select_indices = embedder.indices_to_select(
                full_seq_len, num_question_tokens)
            type_ids = type_ids[:, select_indices]

        range_vector = util.get_range_vector(
            batch_size, device=util.get_device_of(type_ids)).unsqueeze(1)
        type_ids = type_ids[range_vector, offsets]
        return type_ids
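
As a sanity check on the upper-triangular decoding in get_best_span, here is a small hand-checked example (illustrative; assumes torch is imported and BertQA is in scope):

import torch

start = torch.tensor([[0.1, 2.0, 0.3]])
end = torch.tensor([[1.5, 0.2, 0.9]])
# Candidate scores are start[i] + end[j], restricted to j >= i by the triu mask.
# The best legal pair is (1, 2) with 2.0 + 0.9 = 2.9, even though end position 0
# has the highest end score on its own.
assert BertQA.get_best_span(start, end).tolist() == [[1, 2]]
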
class QaNet(Model):
    """
    This class implements Adams Wei Yu's `QANet Model <https://openreview.net/forum?id=B14TlG-RW>`_
    for machine reading comprehension published at ICLR 2018.

    The overall architecture of QANet is very similar to BiDAF. The main difference is that QANet
    replaces the RNN encoder with CNN + self-attention. There are also some minor differences in the
    modeling layer and output layer.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the passage-question attention.
    matrix_attention_layer : ``MatrixAttention``
        The matrix attention function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    dropout_prob : ``float``, optional (default=0.1)
        If greater than 0, we will apply dropout with this probability between layers.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        matrix_attention_layer: MatrixAttention,
        modeling_layer: Seq2SeqEncoder,
        dropout_prob: float = 0.1,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        text_embed_dim = text_field_embedder.get_output_dim()
        encoding_in_dim = phrase_layer.get_input_dim()
        encoding_out_dim = phrase_layer.get_output_dim()
        modeling_in_dim = modeling_layer.get_input_dim()
        modeling_out_dim = modeling_layer.get_output_dim()

        self._text_field_embedder = text_field_embedder

        self._embedding_proj_layer = torch.nn.Linear(text_embed_dim, encoding_in_dim)
        self._highway_layer = Highway(encoding_in_dim, num_highway_layers)

        self._encoding_proj_layer = torch.nn.Linear(encoding_in_dim, encoding_in_dim)
        self._phrase_layer = phrase_layer

        self._matrix_attention = matrix_attention_layer

        self._modeling_proj_layer = torch.nn.Linear(encoding_out_dim * 4, modeling_in_dim)
        self._modeling_layer = modeling_layer

        self._span_start_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)
        self._span_end_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._metrics = SquadEmAndF1()
        self._dropout = torch.nn.Dropout(p=dropout_prob) if dropout_prob > 0 else lambda x: x

        initializer(self)

    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask)
        )
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask)
        )

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2), passage_mask, memory_efficient=True
        )
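        # Composing the two maps below gives, for each passage position, a distribution over
        # passage positions routed through the question (the co-attention trick QANet borrows).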
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat(
                [
                    encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * passage_passage_vectors,
                ],
                dim=-1,
            )
        )

        modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]

        for _ in range(3):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask)
            )
            modeled_passage_list.append(modeled_passage)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)
            )
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)
            )
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                passage_str = metadata[i]["original_passage"]
                offsets = metadata[i]["token_offsets"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)
                answer_texts = metadata[i].get("answer_texts", [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }
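Below is a self-contained sketch of the constrained span decoding that `get_best_span` performs above; `get_best_span_sketch` is an illustrative name, not a library function. It scores every (start, end) pair, masks spans with end < start, and recovers indices from the flattened argmax:

import torch

def get_best_span_sketch(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
    batch_size, passage_length = span_start_logits.size()
    # Score of span (i, j) is start_logit[i] + end_logit[j].
    span_log_probs = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Keep only the upper triangle, i.e. spans with end >= start; log(0) = -inf masks the rest.
    span_mask = torch.triu(torch.ones(passage_length, passage_length)).log()
    best = (span_log_probs + span_mask).view(batch_size, -1).argmax(-1)
    return torch.stack([best // passage_length, best % passage_length], dim=-1)

start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.0]])
end_logits = torch.tensor([[0.0, 0.5, 3.0, 0.1]])
print(get_best_span_sketch(start_logits, end_logits))  # tensor([[1, 2]])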
Example #7
class RobertaSpanPredictionModel(Model):
    """

    """
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 transformer_weights_model: str = None,
                 layer_freeze_regexes: List[str] = None,
                 on_load: bool = False,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        if on_load:
            logging.info(f"Skipping loading of initial Transformer weights")
            transformer_config = RobertaConfig.from_pretrained(
                pretrained_model)
            self._transformer_model = RobertaModel(transformer_config)

        elif transformer_weights_model:
            logging.info(
                f"Loading Transformer weights model from {transformer_weights_model}"
            )
            transformer_model_loaded = load_archive(transformer_weights_model)
            self._transformer_model = transformer_model_loaded.model._transformer_model
        else:
            self._transformer_model = RobertaModel.from_pretrained(
                pretrained_model)

        for name, param in self._transformer_model.named_parameters():
            grad = requires_grad
            if layer_freeze_regexes and grad:
                grad = not any(
                    [bool(re.search(r, name)) for r in layer_freeze_regexes])
            param.requires_grad = grad

        transformer_config = self._transformer_model.config
        num_labels = 2  # For start/end
        self.qa_outputs = Linear(transformer_config.hidden_size, num_labels)

        # Import GPT-2 machinery to get from tokens to actual text
        self.byte_decoder = {v: k for k, v in bytes_to_unicode().items()}

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._debug = 2
        self._padding_value = 1  # The index of the RoBERTa padding token

    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                segment_ids: torch.LongTensor = None,
                start_positions: torch.LongTensor = None,
                end_positions: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:

        self._debug -= 1
        input_ids = tokens['tokens']

        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)

        tokens_mask = (input_ids != self._padding_value).long()

        if self._debug > 0:
            print(f"batch_size = {batch_size}")
            print(f"num_choices = {num_choices}")
            print(f"tokens_mask = {tokens_mask}")
            print(f"input_ids.size() = {input_ids.size()}")
            print(f"input_ids = {input_ids}")
            print(f"segment_ids = {segment_ids}")
            print(f"start_positions = {start_positions}")
            print(f"end_positions = {end_positions}")

        # Segment ids are not used by RoBERTa

        transformer_outputs = self._transformer_model(
            input_ids=input_ids,
            # token_type_ids=segment_ids,
            attention_mask=tokens_mask)
        sequence_output = transformer_outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        span_start_logits = util.replace_masked_values(start_logits,
                                                       tokens_mask, -1e7)
        span_end_logits = util.replace_masked_values(end_logits, tokens_mask,
                                                     -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)
        span_start_probs = util.masked_softmax(span_start_logits, tokens_mask)
        span_end_probs = util.masked_softmax(span_end_logits, tokens_mask)
        output_dict = {
            "start_logits": start_logits,
            "end_logits": end_logits,
            "best_span": best_span
        }
        output_dict["start_probs"] = span_start_probs
        output_dict["end_probs"] = span_end_probs

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms.
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            self._span_start_accuracy(span_start_logits, start_positions)
            self._span_end_accuracy(span_end_logits, end_positions)
            self._span_accuracy(
                best_span,
                torch.cat([
                    start_positions.unsqueeze(-1),
                    end_positions.unsqueeze(-1)
                ], -1))

            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
            # Should we mask out invalid positions here?
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            output_dict["loss"] = total_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['exact_match'] = []
            output_dict['f1_score'] = []
            tokens_texts = []
            for i in range(batch_size):
                tokens_text = metadata[i]['tokens']
                tokens_texts.append(tokens_text)
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                predicted_start = predicted_span[0]
                predicted_end = predicted_span[1]
                predicted_tokens = tokens_text[predicted_start:(predicted_end +
                                                                1)]
                best_span_string = self.convert_tokens_to_string(
                    predicted_tokens)
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                exact_match = 0
                f1_score = 0
                if answer_texts:
                    exact_match, f1_score = self._squad_metrics(
                        best_span_string, answer_texts)
                output_dict['exact_match'].append(exact_match)
                output_dict['f1_score'].append(f1_score)
            output_dict['tokens_texts'] = tokens_texts

        if self._debug > 0:
            print(f"output_dict = {output_dict}")

        return output_dict

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c]
                          for c in text]).decode('utf-8', errors='replace')
        return text

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @classmethod
    def _load(cls,
              config: Params,
              serialization_dir: str,
              weights_file: str = None,
              cuda_device: int = -1,
              **kwargs) -> 'Model':
        model_params = config.get('model')
        model_params.update({"on_load": True})
        config.update({'model': model_params})
        return super()._load(config=config,
                             serialization_dir=serialization_dir,
                             weights_file=weights_file,
                             cuda_device=cuda_device,
                             **kwargs)
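A standalone illustration of the byte-level BPE decoding performed by `convert_tokens_to_string` above. The `bytes_to_unicode` import path is an assumption for recent versions of `transformers`:

from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode  # assumed import path

byte_decoder = {v: k for k, v in bytes_to_unicode().items()}

def tokens_to_string(tokens):
    # Each character of a byte-level BPE token stands for one raw byte;
    # "Ġ" is the byte-level encoding of a leading space.
    text = "".join(tokens)
    return bytearray([byte_decoder[c] for c in text]).decode("utf-8", errors="replace")

print(tokens_to_string(["The", "Ġquick", "Ġbrown", "Ġfox"]))  # The quick brown fox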
Example #8
class BidafPlusSelfAttention(Model):

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 mask_lstms: bool = True) -> None:
        super(BidafPlusSelfAttention, self).__init__(vocab)

        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._matrix_attention = TriLinearAttention(200)

        self._merge_atten = TimeDistributed(torch.nn.Linear(200 * 4, 200))

        self._residual_encoder = residual_encoder
        self._self_atten = TriLinearAttention(200)
        self._merge_self_atten = TimeDistributed(torch.nn.Linear(200 * 3, 200))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(torch.nn.Linear(200, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(200, 1))

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._official_em = Average()
        self._official_f1 = Average()
        if dropout > 0:
            self._dropout = VariationalDropout(p=dropout)
        else:
            raise ValueError("This model requires dropout > 0.")
        self._mask_lstms = mask_lstms

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._dropout(self._residual_encoder(self._dropout(final_merged_passage), passage_mask))
        self_atten_matrix = self._self_atten(residual_layer, residual_layer)

        mask = passage_mask.view(batch_size, passage_length, 1) * passage_mask.view(batch_size, 1, passage_length)

        # Zero out attention from each position to itself; create the identity directly on the
        # right device rather than hard-coding .cuda(), which fails on CPU runs.
        self_mask = torch.eye(passage_length, passage_length, device=passage_mask.device).view(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_atten_probs = util.last_dim_softmax(self_atten_matrix, mask)

        # Batch matrix multiplication:
        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_atten_vecs = torch.matmul(self_atten_probs, residual_layer)

        residual_layer = F.relu(self._merge_self_atten(torch.cat(
            [self_atten_vecs, residual_layer, residual_layer * self_atten_vecs], dim=-1)))

        # Out-of-place add: an in-place += would modify the saved ReLU output that autograd
        # needs for the backward pass.
        final_merged_passage = final_merged_passage + residual_layer

        final_merged_passage = self._dropout(final_merged_passage)

        start_rep = self._span_start_encoder(final_merged_passage, passage_lstm_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1), passage_lstm_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)

        best_span = self._get_best_span(span_start_logits, span_end_logits)


        output_dict = {"span_start_logits": span_start_logits,
                       "span_start_probs": span_start_probs,
                       "span_end_logits": span_end_logits,
                       "span_end_probs": span_end_probs,
                       "best_span": best_span}
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
        if metadata is not None:
            output_dict['best_span_str'] = []
            for i in range(batch_size):
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                exact_match = f1_score = 0
                if answer_texts:
                    exact_match = squad_eval.metric_max_over_ground_truths(
                            squad_eval.exact_match_score,
                            best_span_string,
                            answer_texts)
                    f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score,
                            best_span_string,
                            answer_texts)
                self._official_em(100 * exact_match)
                self._official_f1(100 * f1_score)
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': self._official_em.get_metric(reset),
                'f1': self._official_f1.get_metric(reset),
                }

    @staticmethod
    def _get_best_span(span_start_logits: Variable, span_end_logits: Variable) -> Variable:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = Variable(span_start_logits.data.new()
                                  .resize_(batch_size, 2).fill_(0)).long()

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'BidafPlusSelfAttention':
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)
        phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer"))
        residual_encoder = Seq2SeqEncoder.from_params(params.pop("residual_encoder"))
        span_start_encoder = Seq2SeqEncoder.from_params(params.pop("span_start_encoder"))
        span_end_encoder = Seq2SeqEncoder.from_params(params.pop("span_end_encoder"))
        initializer = InitializerApplicator.from_params(params.pop("initializer", []))
        dropout = params.pop('dropout', 0.2)

        # TODO: Remove the following when fully deprecated
        evaluation_json_file = params.pop('evaluation_json_file', None)
        if evaluation_json_file is not None:
            logger.warning("the 'evaluation_json_file' model parameter is deprecated, please remove")

        mask_lstms = params.pop('mask_lstms', True)
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   phrase_layer=phrase_layer,
                   residual_encoder=residual_encoder,
                   span_start_encoder=span_start_encoder,
                   span_end_encoder=span_end_encoder,
                   initializer=initializer,
                   dropout=dropout,
                   mask_lstms=mask_lstms)

class TransformerQA(Model):
    """
    This class implements a reading comprehension model patterned after the proposed model in
    https://arxiv.org/abs/1810.04805 (Devlin et al), with improvements borrowed from the SQuAD model in the
    transformers project.

    It predicts start tokens and end tokens with a linear layer on top of word piece embeddings.

    Note that the metrics that the model produces are calculated on a per-instance basis only. Since there could
    be more than one instance per question, these metrics are not the official numbers on the SQuAD task. To get
    official numbers, run the script in scripts/transformer_qa_eval.py.

    Parameters
    ----------
    vocab : ``Vocabulary``
    transformer_model_name : ``str``, optional (default=``bert-base-cased``)
        This model chooses the embedder according to this setting. You probably want to make sure this is set to
        the same thing as the reader.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 transformer_model_name: str = "bert-base-cased",
                 hidden_size=768,
                 **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._text_field_embedder = BasicTextFieldEmbedder({
            "tokens":
            PretrainedTransformerEmbedder(transformer_model_name,
                                          hidden_size=hidden_size,
                                          task="QA")
        })
        self._linear_layer = nn.Linear(
            self._text_field_embedder.get_output_dim(), 2)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._per_instance_metrics = SquadEmAndF1()

    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        answer_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question_with_context : Dict[str, torch.LongTensor]
            From a ``TextField``. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including [CLS] and
            [SEP]) has type id 1.
        context_span : ``torch.IntTensor``
            From a ``SpanField``. This marks the span of word pieces in ``question`` from which answers can come.
        answer_span : ``torch.IntTensor``, optional
            From a ``SpanField``. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output directory.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the ``metadata`` list should be the
            batch size, and each dictionary should have the keys ``id``, ``question``, ``context``,
            ``question_tokens``, ``context_tokens``, and ``answers``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        best_span_scores : torch.FloatTensor
            The score for each of the best spans.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._text_field_embedder(question_with_context)
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)

        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context),
            dtype=torch.bool)
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start:end + 1] = True

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       possible_answer_mask,
                                                       -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     possible_answer_mask,
                                                     -1e32)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }

        # Compute the loss for training.
        if answer_span is not None:
            span_start = answer_span[:, 0]
            span_end = answer_span[:, 1]
            span_mask = span_start != -1
            self._span_accuracy(best_spans, answer_span,
                                span_mask.unsqueeze(-1).expand_as(best_spans))

            start_loss = cross_entropy(span_start_logits,
                                       span_start,
                                       ignore_index=-1)
            if torch.any(start_loss > 1e9):
                logger.critical("Start loss too high (%r)", start_loss)
                logger.critical("span_start_logits: %r", span_start_logits)
                logger.critical("span_start: %r", span_start)
                assert False

            end_loss = cross_entropy(span_end_logits,
                                     span_end,
                                     ignore_index=-1)
            if torch.any(end_loss > 1e9):
                logger.critical("End loss too high (%r)", end_loss)
                logger.critical("span_end_logits: %r", span_end_logits)
                logger.critical("span_end: %r", span_end)
                assert False

            loss = (start_loss + end_loss) / 2

            self._span_start_accuracy(span_start_logits, span_start, span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)

            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            best_spans = best_spans.detach().cpu().numpy()

            output_dict["best_span_str"] = []
            context_tokens = []
            for metadata_entry, best_span in zip(metadata, best_spans):
                context_tokens_for_question = metadata_entry["context_tokens"]
                context_tokens.append(context_tokens_for_question)

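                # The best span indexes into the full model input; subtract the question length
                # plus the special tokens around it (1 before, 2 after) to get indices relative
                # to the context tokens alone.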
                best_span -= 1 + len(metadata_entry["question_tokens"]) + 2
                assert np.all(best_span >= 0)

                predicted_start, predicted_end = tuple(best_span)

                while (predicted_start >= 0
                       and context_tokens_for_question[predicted_start].idx is
                       None):
                    predicted_start -= 1
                if predicted_start < 0:
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                        f"'{best_span[0]}' to an offset in the original text.")
                    character_start = 0
                else:
                    character_start = context_tokens_for_question[
                        predicted_start].idx

                while (predicted_end < len(context_tokens_for_question) and
                       context_tokens_for_question[predicted_end].idx is None):
                    predicted_end += 1
                if predicted_end >= len(context_tokens_for_question):
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                        f"'{best_span[1]}' to an offset in the original text.")
                    character_end = len(metadata_entry["context"])
                else:
                    end_token = context_tokens_for_question[predicted_end]
                    character_end = end_token.idx + len(
                        sanitize_wordpiece(end_token.text))

                best_span_string = metadata_entry["context"][
                    character_start:character_end]
                output_dict["best_span_str"].append(best_span_string)

                answers = metadata_entry.get("answers")
                if len(answers) > 0:
                    self._per_instance_metrics(best_span_string, answers)
            output_dict["context_tokens"] = context_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._per_instance_metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "per_instance_em": exact_match,
            "per_instance_f1": f1_score,
        }
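A toy sketch of the context-span masking used above: only token positions inside each instance's inclusive context span stay eligible as answer positions:

import torch

span_start_logits = torch.randn(2, 6)
context_span = torch.tensor([[2, 4], [1, 3]])  # inclusive (start, end) per instance

possible_answer_mask = torch.zeros(2, 6, dtype=torch.bool)
for i, (start, end) in enumerate(context_span):
    possible_answer_mask[i, start : end + 1] = True

# Positions outside the context get a huge negative logit, so softmax/argmax ignore them.
masked_logits = span_start_logits.masked_fill(~possible_answer_mask, -1e32)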
Example #10
def main(
    gpu: int,
    qa_model_path: str,
    paragraphs_source: str,
    generated_decompositions_paths: Optional[str],
    data: str,
    output_predictions_file: str,
    output_metrics_file: str,
    overrides="{}",
):
    import_module_and_submodules("src")

    overrides_dict = {}
    overrides_dict.update(json.loads(overrides))
    archive = load_archive(qa_model_path, cuda_device=gpu, overrides=json.dumps(overrides_dict))
    predictor = Predictor.from_archive(archive)

    dataset_reader = StrategyQAReader(
        paragraphs_source=paragraphs_source,
        generated_decompositions_paths=generated_decompositions_paths,
    )

    accuracy = BooleanAccuracy()
    last_logged_scores_time = time.monotonic()

    logger.info("Reading the dataset:")
    logger.info("Reading file at %s", data)
    dataset = None
    with open(data, mode="r", encoding="utf-8") as dataset_file:
        dataset = json.load(dataset_file)

    output_dataset = []
    for json_obj in tqdm(dataset):
        item = dataset_reader.json_to_item(json_obj)
        decomposition = item["decomposition"]
        generated_decomposition = item["generated_decomposition"]
        gold_answer = torch.tensor(item["answer"]).view((1,))

        used_decomposition = deepcopy(
            generated_decomposition
            if "generated_decomposition" in paragraphs_source
            else decomposition
        )

        # Per instance:
        # Until the final step has an answer, find in each iteration all of the steps that
        # are required to answer the last step (directly or by proxy) and contain no
        # unresolved references. If that is not possible, return a score of zero for the
        # instance. Otherwise, retrieve paragraphs for these steps, pass each step together
        # with its paragraphs to the model to be answered, and then substitute the answer
        # into every step that references it.

        step_answers = [None for i in range(len(used_decomposition))]
        while True:
            reachability = get_reachability([step["question"] for step in used_decomposition])
            if reachability is None:
                break

            if step_answers[-1] is not None:
                break

            indices_of_interest = []
            if sum(reachability[-1]) != 0:
                for i, reachable in enumerate(reachability[-1]):
                    if reachable > 0 and sum(reachability[i]) == 0:
                        indices_of_interest.append(i)
            else:
                indices_of_interest.append(len(step_answers) - 1)

            paragraphs = dataset_reader.get_paragraphs(
                decomposition=[used_decomposition[i] for i in indices_of_interest],
            )
            if paragraphs is not None:
                paragraphs_per_step_of_interest = paragraphs["per_step"]
            else:
                paragraphs_per_step_of_interest = [[{"content": " "}] for i in indices_of_interest]

            for i in indices_of_interest:
                step_answers[i] = get_answer(
                    predictor=predictor,
                    question=used_decomposition[i]["question"],
                    paragraphs=paragraphs_per_step_of_interest[indices_of_interest.index(i)],
                    force_yes_no=i == len(step_answers) - 1,
                )  # Return the best non-empty answer

            for i, step in enumerate(used_decomposition):
                used_decomposition[i]["question"] = fill_in_references(
                    step["question"], step_answers
                )

        predicted_answer_str = step_answers[-1].lower() if step_answers[-1] is not None else None
        if predicted_answer_str == "yes" or predicted_answer_str == "no":
            # Valid answer, the metric should be updated accordingly
            predicted_answer = torch.tensor(predicted_answer_str == "yes").view((1,))
            accuracy(predicted_answer, gold_answer)
        else:
            # Invalid answer, the metric should be updated with a mistake
            accuracy(torch.logical_not(gold_answer), gold_answer)

        if time.monotonic() - last_logged_scores_time > 3:
            metrics_dict = {"accuracy": accuracy.get_metric()}
            logger.info(json.dumps(metrics_dict))
            last_logged_scores_time = time.monotonic()

        output_json_obj = deepcopy(json_obj)
        output_json_obj["decomposition"] = [step["question"] for step in used_decomposition]
        output_json_obj["step_answers"] = step_answers
        output_dataset.append(output_json_obj)

    if output_predictions_file is not None:
        with open(output_predictions_file, "w", encoding="utf-8") as f:
            json.dump(output_dataset, f, ensure_ascii=False, indent=4)

    metrics_dict = {"accuracy": accuracy.get_metric(reset=True)}
    if output_metrics_file is None:
        print(json.dumps(metrics_dict))
    else:
        with open(output_metrics_file, "w", encoding="utf-8") as f:
            json.dump(
                metrics_dict,
                f,
                ensure_ascii=False,
                indent=4,
            )
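The invalid-answer branch above guarantees a scored mistake by feeding the metric the negation of the gold label. A minimal illustration with `BooleanAccuracy`:

import torch
from allennlp.training.metrics import BooleanAccuracy

accuracy = BooleanAccuracy()
gold = torch.tensor([1.0])

accuracy(torch.tensor([1.0]), gold)  # a valid, correct "yes" prediction
accuracy(1.0 - gold, gold)           # an invalid answer, counted as a guaranteed mistake
print(accuracy.get_metric())         # 0.5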
Example #11
class ESIMCosine(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 projection_feedforward: FeedForward,
                 inference_encoder: Seq2SeqEncoder,
                 output_feedforward: FeedForwardPair,
                 dropout: float = 0.5,
                 margin: float = 1.25,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._encoder = encoder

        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._projection_feedforward = projection_feedforward

        self._inference_encoder = inference_encoder

        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
            self.rnn_input_dropout = InputVariationalDropout(dropout)
        else:
            self.dropout = None
            self.rnn_input_dropout = None

        self._output_feedforward = output_feedforward

        self._margin = margin

        self._accuracy = BooleanAccuracy()

        initializer(self)

    @overrides
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        # Shape: (batch_size, seq_length, embedding_dim)
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)

        mask_premise = get_text_field_mask(premise).float()
        mask_hypothesis = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        # Shape: (batch_size, seq_length, encoding_direction_num * encoding_hidden_dim)
        encoded_premise = self._encoder(embedded_premise, mask_premise)
        encoded_hypothesis = self._encoder(embedded_hypothesis,
                                           mask_hypothesis)

        # Shape: (batch_size, p_length, h_length)
        similarity_matrix = self._matrix_attention(encoded_premise,
                                                   encoded_hypothesis)

        # Shape: (batch_size, p_length, h_length)
        p2h_attention = masked_softmax(similarity_matrix, mask_hypothesis)
        # Shape: (batch_size, p_length, encoding_direction_num * encoding_hidden_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, h_length, p_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), mask_premise)
        # Shape: (batch_size, h_length, encoding_direction_num * encoding_hidden_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        # Shape: (batch_size, p_length, encoding_direction_num * encoding_hidden_dim * 4 + num_perspective * num_matching)
        enhanced_premise = torch.cat([
            encoded_premise, attended_hypothesis, encoded_premise -
            attended_hypothesis, encoded_premise * attended_hypothesis
        ],
                                     dim=-1)
        # Shape: (batch_size, h_length, encoding_direction_num * encoding_hidden_dim * 4 + num_perspective * num_matching)
        enhanced_hypothesis = torch.cat([
            encoded_hypothesis, attended_premise, encoded_hypothesis -
            attended_premise, encoded_hypothesis * attended_premise
        ],
                                        dim=-1)

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        # Shape: (batch_size, seq_length, projection_hidden_dim)
        projected_enhanced_premise = self._projection_feedforward(
            enhanced_premise)
        projected_enhanced_hypothesis = self._projection_feedforward(
            enhanced_hypothesis)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(
                projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(
                projected_enhanced_hypothesis)

        # Shape: (batch_size, seq_length, inference_direction_num * inference_hidden_dim)
        inferenced_premise = self._inference_encoder(
            projected_enhanced_premise, mask_premise)
        inferenced_hypothesis = self._inference_encoder(
            projected_enhanced_hypothesis, mask_hypothesis)

        # The pooling layer -- max and avg pooling.
        # Shape: (batch_size, inference_direction_num * inference_hidden_dim)
        pooled_premise_max, _ = replace_masked_values(
            inferenced_premise, mask_premise.unsqueeze(-1), -1e7).max(dim=1)
        pooled_hypothesis_max, _ = replace_masked_values(
            inferenced_hypothesis, mask_hypothesis.unsqueeze(-1),
            -1e7).max(dim=1)

        pooled_premise_avg = torch.sum(
            inferenced_premise * mask_premise.unsqueeze(-1),
            dim=1) / torch.sum(mask_premise, 1, keepdim=True)
        pooled_hypothesis_avg = torch.sum(
            inferenced_hypothesis * mask_hypothesis.unsqueeze(-1),
            dim=1) / torch.sum(mask_hypothesis, 1, keepdim=True)

        # Now concat
        # Shape: (batch_size, inference_direction_num * inference_hidden_dim * 2)
        pooled_premise_all = torch.cat(
            [pooled_premise_avg, pooled_premise_max], dim=1)
        pooled_hypothesis_all = torch.cat(
            [pooled_hypothesis_avg, pooled_hypothesis_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            pooled_premise_all = self.dropout(pooled_premise_all)
            pooled_hypothesis_all = self.dropout(pooled_hypothesis_all)

        # Shape: (batch_size, output_feedforward_hidden_dim)
        output_premise, output_hypothesis = self._output_feedforward(
            pooled_premise_all, pooled_hypothesis_all)

        distance = F.pairwise_distance(output_premise, output_hypothesis)
        prediction = distance < (self._margin / 2.0)
        output_dict = {'distance': distance, "prediction": prediction}

        if label is not None:
            """
            Contrastive loss function.
            Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
            """
            y = label.float()
            l1 = y * torch.pow(distance, 2) / 2.0
            l2 = (1 - y) * torch.pow(
                torch.clamp(self._margin - distance, min=0.0), 2) / 2.0
            loss = torch.mean(l1 + l2)

            self._accuracy(prediction, label.byte())

            output_dict["loss"] = loss

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {'accuracy': self._accuracy.get_metric(reset)}
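
For reference, the contrastive loss used in the forward pass above can be exercised in isolation. The following is a minimal sketch, not code from this model; the margin default and the toy distances/labels are illustrative assumptions.

import torch

def contrastive_loss(distance, label, margin=1.0):
    # label == 1 marks a similar pair, label == 0 a dissimilar pair.
    y = label.float()
    similar_term = y * distance.pow(2) / 2.0
    dissimilar_term = (1 - y) * torch.clamp(margin - distance, min=0.0).pow(2) / 2.0
    return (similar_term + dissimilar_term).mean()

# Toy values: a similar pair at distance 0.2 and a dissimilar pair at 0.9.
distance = torch.tensor([0.2, 0.9])
label = torch.tensor([1, 0])
print(contrastive_loss(distance, label))  # tensor(0.0125)

Similar pairs are pulled together quadratically, while dissimilar pairs only contribute loss while their distance is inside the margin.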
Exemplo n.º 12
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    attention_similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    initializer : ``InitializerApplicator``
        We will use this to initialize the parameters in the model, calling ``initializer(self)``.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    evaluation_json_file : ``str``, optional
        If given, we will load this JSON into memory and use it to compute official metrics
        against.  We need this separately from the validation dataset, because the official metrics
        use all of the annotations, while our dataset reader picks the most frequent one.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 mask_lstms: bool = True) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))
        initializer(self)

        # Bidaf has lots of layer dimensions which need to match up - these
        # aren't necessarily obvious from the configuration files, so we check
        # here.
        if modeling_layer.get_input_dim() != 4 * encoding_dim:
            raise ConfigurationError("The input dimension to the modeling_layer must be "
                                     "equal to 4 times the encoding dimension of the phrase_layer. "
                                     "Found {} and 4 * {} respectively.".format(modeling_layer.get_input_dim(),
                                                                                encoding_dim))
        if text_field_embedder.get_output_dim() != phrase_layer.get_input_dim():
            raise ConfigurationError("The output dimension of the text_field_embedder (embedding_dim + "
                                     "char_cnn) must match the input dimension of the phrase_encoder. "
                                     "Found {} and {}, respectively.".format(text_field_embedder.get_output_dim(),
                                                                             phrase_layer.get_input_dim()))

        if span_end_encoder.get_input_dim() != encoding_dim * 4 + modeling_dim * 3:
            raise ConfigurationError("The input dimension of the span_end_encoder should be equal to "
                                     "4 * phrase_layer.output_dim + 3 * modeling_layer.output_dim. "
                                     "Found {} and (4 * {} + 3 * {}) "
                                     "respectively.".format(span_end_encoder.get_input_dim(),
                                                            encoding_dim,
                                                            modeling_dim))

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self._get_best_span(span_start_logits, span_end_logits)

        output_dict = {"span_start_logits": span_start_logits,
                       "span_start_probs": span_start_probs,
                       "span_end_logits": span_end_logits,
                       "span_end_probs": span_end_probs,
                       "best_span": best_span}
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
        if metadata is not None:
            output_dict['best_span_str'] = []
            for i in range(batch_size):
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }

    @staticmethod
    def _get_best_span(span_start_logits: Variable, span_end_logits: Variable) -> Variable:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = Variable(span_start_logits.data.new()
                                  .resize_(batch_size, 2).fill_(0)).long()

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'BidirectionalAttentionFlow':
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab, embedder_params)
        num_highway_layers = params.pop("num_highway_layers")
        phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer"))
        similarity_function = SimilarityFunction.from_params(params.pop("similarity_function"))
        modeling_layer = Seq2SeqEncoder.from_params(params.pop("modeling_layer"))
        span_end_encoder = Seq2SeqEncoder.from_params(params.pop("span_end_encoder"))
        initializer = InitializerApplicator.from_params(params.pop("initializer", []))
        dropout = params.pop('dropout', 0.2)

        # TODO: Remove the following when fully deprecated
        evaluation_json_file = params.pop('evaluation_json_file', None)
        if evaluation_json_file is not None:
            logger.warning("the 'evaluation_json_file' model parameter is deprecated, please remove")

        mask_lstms = params.pop('mask_lstms', True)
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   num_highway_layers=num_highway_layers,
                   phrase_layer=phrase_layer,
                   attention_similarity_function=similarity_function,
                   modeling_layer=modeling_layer,
                   span_end_encoder=span_end_encoder,
                   initializer=initializer,
                   dropout=dropout,
                   mask_lstms=mask_lstms)
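
The constrained decoding in ``_get_best_span`` above pairs every end position with the best start position at or before it. A stand-alone sketch of the same search (the toy logits are illustrative assumptions):

import torch

def best_span(span_start_logits, span_end_logits):
    # For each end position j, pair it with the best start position <= j and
    # keep the (start, end) pair with the highest summed logits.
    batch_size, passage_length = span_start_logits.size()
    best = torch.zeros(batch_size, 2, dtype=torch.long)
    for b in range(batch_size):
        start_argmax = 0
        max_log_prob = float("-inf")
        for j in range(passage_length):
            if span_start_logits[b, j] > span_start_logits[b, start_argmax]:
                start_argmax = j
            score = span_start_logits[b, start_argmax] + span_end_logits[b, j]
            if score > max_log_prob:
                best[b, 0], best[b, 1] = start_argmax, j
                max_log_prob = score
    return best

starts = torch.tensor([[0.1, 2.0, 0.3]])
ends = torch.tensor([[0.5, 0.1, 1.5]])
print(best_span(starts, ends))  # tensor([[1, 2]])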
Exemplo n.º 13
class DocLevelmpeEsim(mpeEsim):
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            encoder: Seq2SeqEncoder,
            projection_feedforward: FeedForward,
            inference_encoder: Seq2SeqEncoder,
            output_feedforward: FeedForward,
            output_logit: FeedForward,
            final_feedforward: FeedForward,
            coverage_loss: CoverageLoss,
            similarity_function: SimilarityFunction = DotProductSimilarity(),
            dropout: float = 0.5,
            contextualize_pair_comparators: bool = False,
            pair_context_encoder: Seq2SeqEncoder = None,
            pair_feedforward: FeedForward = None,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        # Need to send it verbatim because otherwise FromParams doesn't work appropriately.
        super().__init__(
            vocab=vocab,
            text_field_embedder=text_field_embedder,
            encoder=encoder,
            similarity_function=similarity_function,
            projection_feedforward=projection_feedforward,
            inference_encoder=inference_encoder,
            output_feedforward=output_feedforward,
            output_logit=output_logit,
            final_feedforward=final_feedforward,
            contextualize_pair_comparators=contextualize_pair_comparators,
            coverage_loss=coverage_loss,
            pair_context_encoder=pair_context_encoder,
            pair_feedforward=pair_feedforward,
            dropout=dropout,
            initializer=initializer,
            regularizer=regularizer)
        self._answer_loss = torch.nn.BCELoss()
        self.max_sent_count = 120
        self.fc1 = torch.nn.Linear(self.max_sent_count, 10)
        self.fc2 = torch.nn.Linear(10, 5)
        self.fc3 = torch.nn.Linear(5, 1)
        self.out_sigmoid = torch.nn.Sigmoid()

        self._accuracy = BooleanAccuracy()

    @overrides
    def forward(
        self,  # type: ignore
        premises: Dict[str, torch.LongTensor],
        hypotheses: Dict[str, torch.LongTensor],
        paragraph: Dict[str, torch.LongTensor],
        answer_index: torch.LongTensor = None,
        relevance_presence_mask: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        hypothesis_list = unbind_tensor_dict(hypotheses, dim=1)

        label_logits = []
        premises_attentions = []
        premises_aggregation_attentions = []
        #coverage_losses = []
        for hypothesis in hypothesis_list:  # single hypothesis even to the parent class
            #print("super().forward",len(premises), len(hypothesis), len(paragraph))
            output_dict = super().forward(premises=premises,
                                          hypothesis=hypothesis,
                                          paragraph=paragraph)  #paragraph?
            individual_logit = output_dict["label_logits"][:, self._label2idx[
                "entailment"]]  # only useful key
            label_logits.append(individual_logit)
            #
            premises_attention = output_dict.get("premises_attention", None)
            premises_attentions.append(premises_attention)
            premises_aggregation_attention = output_dict.get(
                "premises_aggregation_attention", None)
            premises_aggregation_attentions.append(
                premises_aggregation_attention)
            #if relevance_presence_mask is not None:
            #coverage_loss = output_dict["coverage_loss"]
            #coverage_losses.append(coverage_loss)
            del output_dict, individual_logit, premises_attention, premises_aggregation_attention

        label_logits = torch.stack(label_logits, dim=-1)
        premises_attentions = torch.stack(premises_attentions, dim=1)
        premises_aggregation_attentions = torch.stack(
            premises_aggregation_attentions, dim=1)
        #if relevance_presence_mask is not None:
        #coverage_losses = torch.stack(coverage_losses, dim=0)

        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
        # @todo: check covariance of label_logits and label_probs
        if label_logits.shape[1] < self.max_sent_count:
            label_logits = torch.nn.functional.pad(
                input=label_logits,
                pad=(0, self.max_sent_count - label_logits.shape[1], 0, 0),
                mode='constant',
                value=0)

        single_output_logit = self.fc3(self.fc2(self.fc1(label_logits)))
        sigmoid_output = self.out_sigmoid(single_output_logit)
        #import pdb; pdb.set_trace()

        output_dict = {
            "label_logits": single_output_logit,
            "label_probs": sigmoid_output,
            "premises_attentions": premises_attentions,
            "premises_aggregation_attentions": premises_aggregation_attentions
        }

        if answer_index is not None:
            device = single_output_logit.device
            # BCELoss expects probabilities in [0, 1] against 0/1 targets, so
            # compare the sigmoid output against the raw labels rather than
            # passing both through a sigmoid.
            targets = answer_index.view(-1, 1).float().to(device)
            loss = self._answer_loss(sigmoid_output, targets)
            output_dict["loss"] = loss
            prediction = sigmoid_output > 0.5
            output_dict["novelty"] = prediction
            self._accuracy(prediction, targets.byte())
            del targets, loss

        del label_logits, label_probs, hypothesis_list
        # if answer_index is not None:
        # answer_loss
        # loss = self._answer_loss(label_logits, answer_index)
        # coverage loss
        # if relevance_presence_mask is not None:
        #     loss += coverage_losses.mean()
        # output_dict["loss"] = loss

        # self._accuracy(label_logits, answer_index)

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        accuracy_metric = self._accuracy.get_metric(reset)
        return {'accuracy': accuracy_metric}
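
The document-level head above right-pads the per-sentence entailment logits to a fixed sentence budget and reduces them to one probability per document. A minimal sketch of that reduction, assuming illustrative batch and sentence counts:

import torch

max_sent_count = 120                # fixed sentence budget (assumption)
label_logits = torch.randn(2, 37)   # (batch_size, num_sentences)
if label_logits.shape[1] < max_sent_count:
    label_logits = torch.nn.functional.pad(
        label_logits, (0, max_sent_count - label_logits.shape[1]),
        mode='constant', value=0)

fc1 = torch.nn.Linear(max_sent_count, 10)
fc2 = torch.nn.Linear(10, 5)
fc3 = torch.nn.Linear(5, 1)
doc_prob = torch.sigmoid(fc3(fc2(fc1(label_logits))))  # (batch_size, 1)
print(doc_prob.shape)  # torch.Size([2, 1])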
Exemplo n.º 15
class GMN(Model):
    def __init__(self, args, word_embeddings: TextFieldEmbedder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)

        # parameters
        self.args = args
        self.word_embeddings = word_embeddings

        # gate
        self.W_z = nn.Linear(self.args.embedding_size, 1, bias=False)
        self.U_z = nn.Linear(self.args.embedding_size, 1, bias=False)
        self.W_r = nn.Linear(self.args.embedding_size, 1, bias=False)
        self.U_r = nn.Linear(self.args.embedding_size, 1, bias=False)
        self.W = nn.Linear(self.args.embedding_size, 1, bias=False)
        self.U = nn.Linear(self.args.embedding_size, 1, bias=False)

        # layers
        self.event_embedding = EventEmbedding(args, self.word_embeddings)
        self.attention = Attention(self.args.embedding_size,
                                   score_function='mlp')
        self.sigmoid = Sigmoid()
        self.tanh = Tanh()
        self.score = Score(self.args.embedding_size,
                           self.args.embedding_size,
                           threshold=self.args.threshold)

        # metrics
        self.accuracy = BooleanAccuracy()
        self.f1_score = F1Measure(positive_label=1)
        self.loss_function = BCELoss()

    def gated_atten(self, vt_1, atten_input):
        """
        gated attention block
        :param vt_1: v_t-1
        :param atten_input: [h1, h2, ... ,h_n-1]
        :return: v_t
        """
        # [batch_size, 1, embedding_size]
        out_at, _ = self.attention(atten_input, vt_1)
        # [batch_size, embedding_size]
        h_e = torch.sum(out_at * atten_input, dim=1)
        # [batch_size, 1]
        z = (self.sigmoid(self.W_z(h_e.unsqueeze(1)) +
                          self.U_z(vt_1))).squeeze(1)
        # [batch_size, 1]
        r = (self.sigmoid(self.W_r(h_e.unsqueeze(1)) +
                          self.U_r(vt_1))).squeeze(1)
        # [batch_size, 1]
        h = self.tanh(
            self.W(h_e.unsqueeze(1)) +
            self.U((torch.mul(r, vt_1.squeeze(1))).unsqueeze(1))).squeeze(1)
        # [batch_size, 1, embedding_size]
        vt = (torch.mul(
            (1 - z), vt_1.squeeze(1)) + torch.mul(z, h)).unsqueeze(1)

        return vt

    @overrides
    def forward(self,
                trigger_0: Dict[str, torch.LongTensor],
                trigger_agent_0: Dict[str, torch.LongTensor],
                agent_attri_0: Dict[str, torch.LongTensor],
                trigger_object_0: Dict[str, torch.LongTensor],
                object_attri_0: Dict[str, torch.LongTensor],
                trigger_1: Dict[str, torch.LongTensor],
                trigger_agent_1: Dict[str, torch.LongTensor],
                agent_attri_1: Dict[str, torch.LongTensor],
                trigger_object_1: Dict[str, torch.LongTensor],
                object_attri_1: Dict[str, torch.LongTensor],
                trigger_2: Dict[str, torch.LongTensor],
                trigger_agent_2: Dict[str, torch.LongTensor],
                agent_attri_2: Dict[str, torch.LongTensor],
                trigger_object_2: Dict[str, torch.LongTensor],
                object_attri_2: Dict[str, torch.LongTensor],
                trigger_3: Dict[str, torch.LongTensor],
                trigger_agent_3: Dict[str, torch.LongTensor],
                agent_attri_3: Dict[str, torch.LongTensor],
                trigger_object_3: Dict[str, torch.LongTensor],
                object_attri_3: Dict[str, torch.LongTensor],
                trigger_4: Dict[str, torch.LongTensor],
                trigger_agent_4: Dict[str, torch.LongTensor],
                agent_attri_4: Dict[str, torch.LongTensor],
                trigger_object_4: Dict[str, torch.LongTensor],
                object_attri_4: Dict[str, torch.LongTensor],
                event_type: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:

        # tri, e: [batch_size, 1, embedding_size]
        tri0, e0 = self.event_embedding(trigger_0, trigger_agent_0,
                                        trigger_object_0)
        tri1, e1 = self.event_embedding(trigger_1, trigger_agent_1,
                                        trigger_object_1)
        tri2, e2 = self.event_embedding(trigger_2, trigger_agent_2,
                                        trigger_object_2)
        tri3, e3 = self.event_embedding(trigger_3, trigger_agent_3,
                                        trigger_object_3)
        tri4, e4 = self.event_embedding(trigger_4, trigger_agent_4,
                                        trigger_object_4)

        # [batch_size, seq_Len, embedding_size]
        e = (torch.stack([e0, e1, e2, e3, e4], dim=1)).squeeze(2)

        # [batch_size, 1, embedding_size]
        vt = tri4

        for i in range(self.args.hop_num):
            # [batch_size, 1, embedding_size]
            vt = self.gated_atten(vt, e)

        # [batch_size, embedding_size]
        x = vt.view(vt.size(0), -1)
        # [batch_size, 1] , [batch_size], [batch_size, label_size]
        score, logits, logits_f1 = self.score(x, tri4)

        output = {"logits": logits, "score": score}
        if label is not None:
            self.accuracy(logits, label)
            self.f1_score(logits_f1, label)
            output["loss"] = self.loss_function(score.squeeze(1),
                                                label.float())

        return output

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        accuracy = self.accuracy.get_metric(reset)
        precision, recall, f1_measure = self.f1_score.get_metric(reset)
        return {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_measure": f1_measure
        }
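
The ``gated_atten`` block above is a GRU-style update driven by scalar gates. A self-contained sketch of just the gating arithmetic, with illustrative sizes and randomly initialized weights:

import torch

embedding_size, batch_size = 4, 2
W_z = torch.nn.Linear(embedding_size, 1, bias=False)
U_z = torch.nn.Linear(embedding_size, 1, bias=False)
W_r = torch.nn.Linear(embedding_size, 1, bias=False)
U_r = torch.nn.Linear(embedding_size, 1, bias=False)
W = torch.nn.Linear(embedding_size, 1, bias=False)
U = torch.nn.Linear(embedding_size, 1, bias=False)

h_e = torch.randn(batch_size, embedding_size)  # attended event summary
v = torch.randn(batch_size, embedding_size)    # previous state v_{t-1}

z = torch.sigmoid(W_z(h_e) + U_z(v))  # update gate, shape (batch_size, 1)
r = torch.sigmoid(W_r(h_e) + U_r(v))  # reset gate, shape (batch_size, 1)
h = torch.tanh(W(h_e) + U(r * v))     # candidate, shape (batch_size, 1)
v_next = (1 - z) * v + z * h          # broadcasts back to (batch_size, embedding_size)
print(v_next.shape)                   # torch.Size([2, 4])

Because the gates are scalars per example, the update interpolates the whole previous state toward the candidate rather than gating each dimension independently.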
Exemplo n.º 16
class BidirectionalAttentionFlow_1(Model):
    """
    This class implements a Bayesian version of Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).
    """
    
    def __init__(self, vocab: Vocabulary, cf_a, preloaded_elmo = None) -> None:
        super(BidirectionalAttentionFlow_1, self).__init__(vocab, cf_a.regularizer)
        
        """
        Initialize some data structures 
        """
        self.cf_a = cf_a
        # Bayesian data models
        self.VBmodels = []
        self.LinearModels = []
        """
        ############## TEXT FIELD EMBEDDER with ELMO ####################
        text_field_embedder : ``TextFieldEmbedder``
            Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
        """
        if (cf_a.use_ELMO):
            if preloaded_elmo is not None:
                text_field_embedder = preloaded_elmo
            else:
                text_field_embedder = bidut.download_Elmo(cf_a.ELMO_num_layers, cf_a.ELMO_droput )
                print ("ELMO loaded from disk or downloaded")
        else:
            text_field_embedder = None
        
#        embedder_out_dim  = text_field_embedder.get_output_dim()
        self._text_field_embedder = text_field_embedder
        
        if(cf_a.Add_Linear_projection_ELMO):
            if (self.cf_a.VB_Linear_projection_ELMO):
                prior = Vil.Prior(**(cf_a.VB_Linear_projection_ELMO_prior))
                print ("----------------- Bayesian Linear Projection ELMO --------------")
                linear_projection_ELMO =  LinearVB(text_field_embedder.get_output_dim(), 200, prior = prior)
                self.VBmodels.append(linear_projection_ELMO)
            else:
                linear_projection_ELMO = torch.nn.Linear(text_field_embedder.get_output_dim(), 200)
        
            self._linear_projection_ELMO = linear_projection_ELMO
            
        """
        ############## Highway layers ####################
        num_highway_layers : ``int``
            The number of highway layers to use in between embedding the input and passing it through
            the phrase layer.
        """
        
        Input_dimension_highway = None
        if (cf_a.Add_Linear_projection_ELMO):
            Input_dimension_highway = 200
        else:
            Input_dimension_highway = text_field_embedder.get_output_dim()
            
        num_highway_layers = cf_a.num_highway_layers
        # Highway layer (optionally Bayesian)
        if (self.cf_a.VB_highway_layers):
            print ("----------------- Bayesian Highway network  --------------")
            prior = Vil.Prior(**(cf_a.VB_highway_layers_prior))
            highway_layer = HighwayVB(Input_dimension_highway,
                                                          num_highway_layers, prior = prior)
            self.VBmodels.append(highway_layer)
        else:
            
            highway_layer = Highway(Input_dimension_highway,
                                                          num_highway_layers)
        highway_layer = TimeDistributed(highway_layer)
        
        self._highway_layer = highway_layer
        
        """
        ############## Phrase layer ####################
        phrase_layer : ``Seq2SeqEncoder``
            The encoder (with its own internal stacking) that we will use in between embedding tokens
            and doing the bidirectional attention.
        """
        if cf_a.phrase_layer_dropout > 0:       ## Create dropout layer
            dropout_phrase_layer = torch.nn.Dropout(p=cf_a.phrase_layer_dropout)
        else:
            dropout_phrase_layer = lambda x: x
        
        phrase_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(Input_dimension_highway, hidden_size = cf_a.phrase_layer_hidden_size, 
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.phrase_layer_num_layers, dropout = cf_a.phrase_layer_dropout))
        
        phrase_encoding_out_dim = cf_a.phrase_layer_hidden_size * 2
        self._phrase_layer = phrase_layer
        self._dropout_phrase_layer = dropout_phrase_layer
        
        """
        ############## Matrix attention layer ####################
        similarity_function : ``SimilarityFunction``
            The similarity function that we will use when comparing encoded passage and question
            representations.
        """
        
        # Similarity function (optionally Bayesian)
        if (self.cf_a.VB_similarity_function):
            prior = Vil.Prior(**(cf_a.VB_similarity_function_prior))
            print ("----------------- Bayesian Similarity matrix --------------")
            similarity_function = LinearSimilarityVB(
                  combination = "x,y,x*y",
                  tensor_1_dim =  phrase_encoding_out_dim,
                  tensor_2_dim = phrase_encoding_out_dim, prior = prior)
            self.VBmodels.append(similarity_function)
        else:
            similarity_function = LinearSimilarity(
                  combination = "x,y,x*y",
                  tensor_1_dim =  phrase_encoding_out_dim,
                  tensor_2_dim = phrase_encoding_out_dim)
            
        matrix_attention = LegacyMatrixAttention(similarity_function)
        self._matrix_attention = matrix_attention
        
        """
        ############## Modelling Layer ####################
        modeling_layer : ``Seq2SeqEncoder``
            The encoder (with its own internal stacking) that we will use in between the bidirectional
            attention and predicting span start and end.
        """
        ## Create dropout layer
        if cf_a.modeling_passage_dropout > 0:       ## Create dropout layer
            dropout_modeling_passage = torch.nn.Dropout(p=cf_a.modeling_passage_dropout)
        else:
            dropout_modeling_passage = lambda x: x
        
        modeling_layer = PytorchSeq2SeqWrapper(torch.nn.LSTM(phrase_encoding_out_dim * 4, hidden_size = cf_a.modeling_passage_hidden_size, 
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.modeling_passage_num_layers, dropout = cf_a.modeling_passage_dropout))

        self._modeling_layer = modeling_layer
        self._dropout_modeling_passage = dropout_modeling_passage
        
        """
        ############## Span Start Representation #####################
        span_end_encoder : ``Seq2SeqEncoder``
            The encoder that we will use to incorporate span start predictions into the passage state
            before predicting span end.
        """
        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        
        # Linear layer to compute the span start
        if (self.cf_a.VB_span_start_predictor_linear):
            prior = Vil.Prior(**(cf_a.VB_span_start_predictor_linear_prior))
            print ("----------------- Bayesian Span Start Predictor--------------")
            span_start_predictor_linear =  LinearVB(span_start_input_dim, 1, prior = prior)
            self.VBmodels.append(span_start_predictor_linear)
        else:
            span_start_predictor_linear = torch.nn.Linear(span_start_input_dim, 1)
            
        self._span_start_predictor_linear = span_start_predictor_linear
        self._span_start_predictor = TimeDistributed(span_start_predictor_linear)

        """
        ############## Span End Representation #####################
        """
        
        ## Create dropout layer
        if cf_a.span_end_encoder_dropout > 0:
            dropout_span_end_encode = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
        else:
            dropout_span_end_encode = lambda x: x
        
        span_end_encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(encoding_dim * 4 + modeling_dim * 3, hidden_size = cf_a.modeling_span_end_hidden_size, 
                                                   batch_first=True, bidirectional = True,
                                                   num_layers = cf_a.modeling_span_end_num_layers, dropout = cf_a.span_end_encoder_dropout))
   
        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        
        self._span_end_encoder = span_end_encoder
        self._dropout_span_end_encode = dropout_span_end_encode
        
        if (self.cf_a.VB_span_end_predictor_linear):
            print ("----------------- Bayesian Span End Predictor--------------")
            prior = Vil.Prior(**(cf_a.VB_span_end_predictor_linear_prior))
            span_end_predictor_linear = LinearVB(span_end_input_dim, 1, prior = prior)
            self.VBmodels.append(span_end_predictor_linear) 
        else:
            span_end_predictor_linear = torch.nn.Linear(span_end_input_dim, 1)
        
        self._span_end_predictor_linear = span_end_predictor_linear
        self._span_end_predictor = TimeDistributed(span_end_predictor_linear)

        """
        Dropput last layers
        """
        if cf_a.spans_output_dropout > 0:
            dropout_spans_output = torch.nn.Dropout(p=cf_a.span_end_encoder_dropout)
        else:
            dropout_spans_output = lambda x: x
        
        self._dropout_spans_output = dropout_spans_output
        
        """
        Checkings and accuracy
        """
        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(Input_dimension_highway , phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        """
        mask_lstms : ``bool``, optional (default=True)
            If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
            with only a slight performance decrease, if any.  We haven't experimented much with this
            yet, but have confirmed that we still get very similar performance with much faster
            training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
            required when using masking with pytorch LSTMs.
        """
        self._mask_lstms = cf_a.mask_lstms

    
        """
        ################### Initialize parameters ##############################
        """
        #### THEY ARE ALL INITIALIZED WHEN INSTANTIATING THE COMPONENTS ###
    
        """
        ####################### OPTIMIZER ################
        """
        optimizer = pytut.get_optimizers(self, cf_a)
        self._optimizer = optimizer
        #### TODO: Learning rate scheduler ####
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')
    
    def forward_ensemble(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information = False) -> Dict[str, torch.Tensor]:
        """
        Sample 10 times and add them together
        """
        self.set_posterior_mean(True)
        most_likely_output = self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information)
        self.set_posterior_mean(False)
       
        subresults = [most_likely_output]
        for i in range(10):
            subresults.append(self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information))

        batch_size = len(subresults[0]["best_span"])

        best_span = bidut.merge_span_probs(subresults)
        
        output = {
                "best_span": best_span,
                "best_span_str": [],
                "models_output": subresults
        }
        if (get_sample_level_information):
            output["em_samples"] = []
            output["f1_samples"] = []
                
        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if (get_sample_level_information):
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)
                        
        if (get_sample_level_information):
            # Add information about the individual samples for future analysis.
            # The ensemble-averaged probabilities are the same for every instance,
            # so compute them once outside the loop; nll_loss expects
            # log-probabilities, hence the torch.log around the averaged probs.
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults)
            span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults)
            for i in range(batch_size):
                span_start_loss = nll_loss(torch.log(span_start_probs[[i], :]), span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(torch.log(span_end_probs[[i], :]), span_end.squeeze(-1)[[i]])

                output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        return output
    
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information = False,
                get_attentions = False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.
        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        
        """
        #################### Sample Bayesian weights ##################
        """
        self.sample_posterior()
        
        """
        ################## MASK COMPUTING ########################
        """
                
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None
        
        """
        ###################### EMBEDDING + HIGHWAY LAYER ########################
        """
#        self.cf_a.use_ELMO
        
        if(self.cf_a.Add_Linear_projection_ELMO):
            embedded_question = self._highway_layer(self._linear_projection_ELMO (self._text_field_embedder(question['character_ids'])["elmo_representations"][-1]))
            embedded_passage = self._highway_layer(self._linear_projection_ELMO(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1]))
        else:
            embedded_question = self._highway_layer(self._text_field_embedder(question['character_ids'])["elmo_representations"][-1])
            embedded_passage = self._highway_layer(self._text_field_embedder(passage['character_ids'])["elmo_representations"][-1])
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        
        """
        ###################### phrase_layer LAYER ########################
        """

        encoded_question = self._dropout_phrase_layer(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout_phrase_layer(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        """
        ###################### Attention LAYER ########################
        """
        
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout_modeling_passage(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)
        
        """
        ###################### Spans LAYER ########################
        """
        
        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout_spans_output(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout_span_end_encode(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout_spans_output(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        
        best_span = bidut.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            
            span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            loss = span_start_loss + span_end_loss

            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            
            output_dict["loss"] = loss
            output_dict["span_start_loss"] = span_start_loss
            output_dict["span_end_loss"] = span_end_loss
            
        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            if (get_sample_level_information):
                output_dict["em_samples"] = []
                output_dict["f1_samples"] = []
                
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if (get_sample_level_information):
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts)
                        output_dict["em_samples"].append(em_sample)
                        output_dict["f1_samples"].append(f1_sample)
                        
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            
        if (get_sample_level_information):
            # Add information about the individual samples for future analysis
            output_dict["span_start_sample_loss"] = []
            output_dict["span_end_sample_loss"] = []
            for i in range(batch_size):
                span_start_loss = nll_loss(util.masked_log_softmax(span_start_logits[[i],:], passage_mask[[i],:]), span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(util.masked_log_softmax(span_end_logits[[i],:], passage_mask[[i],:]), span_end.squeeze(-1)[[i]])
                
                output_dict["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                output_dict["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        if(get_attentions):
            output_dict["C2Q_attention"] = passage_question_attention
            output_dict["Q2C_attention"] = question_passage_attention
            output_dict["simmilarity"] = passage_question_similarity
            
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }
    
    def train_batch(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        
        """
        It is enough to just compute the total loss because the normal weights 
        do not depend on the KL Divergence
        """
        # Now we can just compute both losses which will build the dynamic graph
        
        output = self.forward(question, passage, span_start, span_end, metadata)
        data_loss = output["loss"]
        
        KL_div = self.get_KL_divergence()
        total_loss = self.combine_losses(data_loss, KL_div)
        
        self.zero_grad()     # zeroes the gradient buffers of all parameters
        total_loss.backward()
        
        if self._optimizer is None:
            # Plain SGD update when no optimizer has been set.
            parameters = filter(lambda p: p.requires_grad, self.parameters())
            with torch.no_grad():
                for f in parameters:
                    f.data.sub_(f.grad.data * self.lr)
        else:
            self._optimizer.step()
            self._optimizer.zero_grad()
            
        return output
    
    def fill_batch_training_information(self, training_logger,
                                        output_batch):
        """
        Function to fill the training_logger for each batch.
        training_logger: Dictionary that will hold all the training info
        output_batch: Output from training the batch
        """
        training_logger["train"]["span_start_loss_batch"].append(output_batch["span_start_loss"].detach().cpu().numpy())
        training_logger["train"]["span_end_loss_batch"].append(output_batch["span_end_loss"].detach().cpu().numpy())
        training_logger["train"]["loss_batch"].append(output_batch["loss"].detach().cpu().numpy())
        # Training metrics:
        metrics = self.get_metrics()
        training_logger["train"]["start_acc_batch"].append(metrics["start_acc"])
        training_logger["train"]["end_acc_batch"].append(metrics["end_acc"])
        training_logger["train"]["span_acc_batch"].append(metrics["span_acc"])
        training_logger["train"]["em_batch"].append(metrics["em"])
        training_logger["train"]["f1_batch"].append(metrics["f1"])
    
        
    def fill_epoch_training_information(self, training_logger,device,
                                        validation_iterable, num_batches_validation):
        """
        Fill the information per each epoch
        """
        Ntrials_CUDA = 100
        # Training Epoch final metrics
        metrics = self.get_metrics(reset = True)
        training_logger["train"]["start_acc"].append(metrics["start_acc"])
        training_logger["train"]["end_acc"].append(metrics["end_acc"])
        training_logger["train"]["span_acc"].append(metrics["span_acc"])
        training_logger["train"]["em"].append(metrics["em"])
        training_logger["train"]["f1"].append(metrics["f1"])
        
        self.set_posterior_mean(True)
        self.eval()
        
        data_loss_validation = 0
        with torch.no_grad():
            # Compute the validation metrics over the whole validation dataset, in batches.
            for j in range(num_batches_validation):
                tensor_dict = next(validation_iterable)
                
                trial_index = 0
                while True:
                    try:
                        tensor_dict = pytut.move_to_device(tensor_dict, device)  # Move the tensors to CUDA
                        output_batch = self.forward(**tensor_dict)
                        break
                    except RuntimeError as er:
                        print(er.args)
                        torch.cuda.empty_cache()
                        time.sleep(5)
                        torch.cuda.empty_cache()
                        trial_index += 1
                        if (trial_index == Ntrials_CUDA):
                            print("Too many failed attempts to allocate memory")
                            send_error_email(str(er.args))
                            sys.exit(0)
                
                data_loss_validation += output_batch["loss"].detach().cpu().numpy()

                ## Memory management !!
            if (self.cf_a.force_free_batch_memory):
                del tensor_dict["question"]; del tensor_dict["passage"]
                del tensor_dict
                del output_batch
                torch.cuda.empty_cache()
            if (self.cf_a.force_call_garbage_collector):
                gc.collect()

            data_loss_validation = data_loss_validation/num_batches_validation

        # Validation epoch final metrics
        metrics = self.get_metrics(reset = True)
        training_logger["validation"]["start_acc"].append(metrics["start_acc"])
        training_logger["validation"]["end_acc"].append(metrics["end_acc"])
        training_logger["validation"]["span_acc"].append(metrics["span_acc"])
        training_logger["validation"]["em"].append(metrics["em"])
        training_logger["validation"]["f1"].append(metrics["f1"])
        
        training_logger["validation"]["data_loss"].append(data_loss_validation)
        self.train()
        self.set_posterior_mean(False)
    
    def trim_model(self, mu_sigma_ratio = 2):
        
        total_size_w = []
        total_removed_w = []
        total_size_b = []
        total_removed_b = []
        
        if (self.cf_a.VB_Linear_projection_ELMO):
                VBmodel = self._linear_projection_ELMO
                size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel,  mu_sigma_ratio)
                total_size_w.append(size_w)
                total_removed_w.append(removed_w)
                total_size_b.append(size_b)
                total_removed_b.append(removed_b)
                
        if (self.cf_a.VB_highway_layers):
                VBmodel = self._highway_layer._module.VBmodels[0]
                size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel,  mu_sigma_ratio)
                total_size_w.append(size_w)
                total_removed_w.append(removed_w)
                total_size_b.append(size_b)
                total_removed_b.append(removed_b)

        if (self.cf_a.VB_similarity_function):
                VBmodel = self._matrix_attention._similarity_function
                size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel,  mu_sigma_ratio)
                total_size_w.append(size_w)
                total_removed_w.append(removed_w)
                total_size_b.append(size_b)
                total_removed_b.append(removed_b)

        if (self.cf_a.VB_span_start_predictor_linear):
                VBmodel = self._span_start_predictor_linear
                size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel,  mu_sigma_ratio)
                total_size_w.append(size_w)
                total_removed_w.append(removed_w)
                total_size_b.append(size_b)
                total_removed_b.append(removed_b)

        if (self.cf_a.VB_span_end_predictor_linear):
                VBmodel = self._span_end_predictor_linear
                size_w, removed_w, size_b, removed_b = Vil.trim_LinearVB_weights(VBmodel,  mu_sigma_ratio)
                total_size_w.append(size_w)
                total_removed_w.append(removed_w)
                total_size_b.append(size_b)
                total_removed_b.append(removed_b)
                
        
        return  total_size_w, total_removed_w, total_size_b, total_removed_b

    
    """
    BAYESIAN NECESSARY FUNCTIONS
    """
    sample_posterior = GeneralVBModel.sample_posterior
    get_KL_divergence = GeneralVBModel.get_KL_divergence
    set_posterior_mean = GeneralVBModel.set_posterior_mean
    combine_losses = GeneralVBModel.combine_losses
    
    def save_VB_weights(self):
        """
        Function that saves only the VB weights of the model.
        """
        pretrained_dict = ...  # dict of VB weights to keep; source elided
        model_dict = self.state_dict()

        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        self.load_state_dict(model_dict)
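
For reference, steps 1-3 above follow the standard PyTorch recipe for partially loading a state dict. Below is a minimal, self-contained sketch of that pattern; the function name and checkpoint path are illustrative assumptions, not part of the model above:

import torch
import torch.nn as nn

def load_partial_state_dict(model: nn.Module, checkpoint_path: str) -> None:
    # Load a checkpoint and keep only the entries whose keys exist in `model`.
    pretrained_dict = torch.load(checkpoint_path)
    model_dict = model.state_dict()
    # 1. filter out keys the target model does not have
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite those entries in the model's state dict
    model_dict.update(pretrained_dict)
    # 3. load the merged state dict back into the model
    model.load_state_dict(model_dict)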
Exemplo n.º 17
0
 def test_does_not_divide_by_zero_with_no_count(self, device: str):
     accuracy = BooleanAccuracy()
     self.assertAlmostEqual(accuracy.get_metric(), 0.0)
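
A minimal usage sketch of the behaviour this test pins down: with no accumulated counts, BooleanAccuracy.get_metric() returns 0.0 instead of dividing by zero (assuming the usual AllenNLP import path):

from allennlp.training.metrics import BooleanAccuracy

accuracy = BooleanAccuracy()
# Nothing has been scored yet, so a naive correct_count / total_count
# would raise ZeroDivisionError; the metric guards against this.
print(accuracy.get_metric())  # 0.0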
Exemplo n.º 18
0
class BertSpanPointerResolution(Model):
    """该模型同时预测mask位置以及span的起始位置"""
    def __init__(self,
                 vocab: Vocabulary,
                 model_name: str = None,
                 start_attention: Attention = None,
                 end_attention: Attention = None,
                 text_field_embedder: TextFieldEmbedder = None,
                 task_pretrained_file: str = None,
                 neg_sample_ratio: float = 0.0,
                 max_turn_len: int = 3,
                 start_token: str = "[CLS]",
                 end_token: str = "[SEP]",
                 index_name: str = "bert",
                 eps: float = 1e-8,
                 seed: int = 42,
                 loss_factor: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = None):
        super().__init__(vocab, regularizer)
        if model_name is None and text_field_embedder is None:
            raise ValueError(
                "`model_name` and `text_field_embedder` can't both be None."
            )
        # For a pure resolution task, returning the last layer's embedding representation is enough
        self._text_field_embedder = text_field_embedder or PretrainedChineseBertMismatchedEmbedder(
            model_name,
            return_all=False,
            output_hidden_states=False,
            max_turn_length=max_turn_len)

        seed_everything(seed)
        self._neg_sample_ratio = neg_sample_ratio
        self._start_token = start_token
        self._end_token = end_token
        self._index_name = index_name
        self._initializer = initializer

        linear_input_size = self._text_field_embedder.get_output_dim()
        # Attention-based approach
        self.start_attention = start_attention or BilinearAttention(
            vector_dim=linear_input_size, matrix_dim=linear_input_size)
        self.end_attention = end_attention or BilinearAttention(
            vector_dim=linear_input_size, matrix_dim=linear_input_size)
        # Metrics for the mask: mainly the F-score; we care most about the recall of label `1`
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._rewrite_em = RewriteEM(valid_keys="semr,nr_semr,re_semr")
        self._restore_score = RestorationScore(compute_restore_tokens=True)
        self._metrics = [
            TokenBasedBLEU(mode="1,2"),
            TokenBasedROUGE(mode="1r,2r")
        ]
        self._eps = eps
        self._loss_factor = loss_factor

        self._initializer(self.start_attention)
        self._initializer(self.end_attention)

        # Load weights pre-trained on a related task
        if task_pretrained_file is not None and os.path.isfile(
                task_pretrained_file):
            logger.info("loading related task pretrained weights...")
            self.load_state_dict(torch.load(task_pretrained_file),
                                 strict=False)

    def _calc_loss(self, span_start_logits: torch.Tensor,
                   span_end_logits: torch.Tensor, use_mask_label: torch.Tensor,
                   start_label: torch.Tensor, end_label: torch.Tensor,
                   best_spans: torch.Tensor):
        batch_size = start_label.size(0)
        # Standard loss
        loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-1)
        # --- Compute the losses for the start and end labels ---
        # Select the start/end results at positions where mask_label == 1
        # [B_mask, ]
        span_start_label = start_label.masked_select(
            use_mask_label.to(dtype=torch.bool))
        span_end_label = end_label.masked_select(
            use_mask_label.to(dtype=torch.bool))
        # Mask out the mostly-zero labels when computing the accuracy
        train_span_mask = (span_start_label != -1)

        # [B_mask, 2]
        answer_spans = torch.stack([span_start_label, span_end_label], dim=-1)
        self._span_accuracy(
            best_spans, answer_spans,
            train_span_mask.unsqueeze(-1).expand_as(best_spans))

        # -- Compute start_loss --
        start_losses = loss_fct(span_start_logits, span_start_label)
        # start_label_weight = self._calc_loss_weight(span_start_label)  # compute label weights
        start_loss = torch.sum(start_losses) / batch_size
        # Sanity-check the loss value
        big_constant = min(torch.finfo(start_loss.dtype).max, 1e9)
        if torch.any(start_loss > big_constant):
            logger.critical("Start loss too high (%r)", start_loss)
            logger.critical("span_start_logits: %r", span_start_logits)
            logger.critical("span_start: %r", span_start_label)
            assert False

        # -- Compute end_loss --
        end_losses = loss_fct(span_end_logits, span_end_label)
        # end_label_weight = self._calc_loss_weight(span_end_label)   # compute label weights
        end_loss = torch.sum(end_losses) / batch_size
        if torch.any(end_loss > big_constant):
            logger.critical("End loss too high (%r)", end_loss)
            logger.critical("span_end_logits: %r", span_end_logits)
            logger.critical("span_end: %r", span_end_label)
            assert False

        span_loss = (start_loss + end_loss) / 2

        self._span_start_accuracy(span_start_logits, span_start_label,
                                  train_span_mask)
        self._span_end_accuracy(span_end_logits, span_end_label,
                                train_span_mask)

        loss = span_loss
        return loss

    def _calc_loss_weight(self, label: torch.Tensor):
        label_mask = (label != 0).to(torch.float16)
        label_weight = label_mask * self._loss_factor + 1.0
        return label_weight

    def _get_rewrite_result(self, use_mask_label: torch.Tensor,
                            best_spans: torch.Tensor, query_lens: torch.Tensor,
                            context_lens: torch.Tensor,
                            metadata: List[Dict[str, Any]]):
        # Convert both labels to numpy
        # [B, query_len]
        use_mask_label = use_mask_label.detach().cpu().numpy()
        # [B_mask, 2]
        best_spans = best_spans.detach().cpu().numpy().tolist()

        predict_rewrite_results = []
        for cur_query_len, cur_context_len, cur_query_mask_labels, mdata in zip(
                query_lens, context_lens, use_mask_label, metadata):
            context_tokens = mdata['context_tokens']
            query_tokens = mdata['query_tokens']
            cur_rewrite_result = copy.deepcopy(query_tokens)
            already_insert_tokens = 0  # number of tokens inserted so far
            already_insert_min_start = cur_context_len  # smallest start position inserted so far
            already_insert_max_end = 0  # largest end position inserted so far
            # Iterate over the mask labels; for each label that is 1, compute the corresponding span string
            for i in range(cur_query_len):
                cur_mask_label = cur_query_mask_labels[i]
                # Only insert content when the predicted label is 1
                if cur_mask_label:
                    predict_start, predict_end = best_spans.pop(0)

                    # Skip if both are 0
                    if predict_start == 0 and predict_end == 0:
                        continue
                    # Skip if start is beyond the context length
                    if predict_start >= cur_context_len:
                        continue
                    # Skip spans that lie inside an already-inserted span
                    if predict_start >= already_insert_min_start and predict_end <= already_insert_max_end:
                        continue
                    # Correct the positions
                    if predict_start < 0 or context_tokens[
                            predict_start] == self._start_token:
                        predict_start = 1

                    if predict_end >= cur_context_len:
                        predict_end = cur_context_len - 1

                    # Extract the predicted span
                    predict_span_tokens = context_tokens[
                        predict_start:predict_end + 1]
                    # Update the smallest start and largest end inserted so far
                    if predict_start < already_insert_min_start:
                        already_insert_min_start = predict_start
                    if predict_end > already_insert_max_end:
                        already_insert_max_end = predict_end
                    # Trim the predicted span: keep only the tokens before the end_token
                    try:
                        index = predict_span_tokens.index(self._end_token)
                        predict_span_tokens = predict_span_tokens[:index]
                    except BaseException:
                        pass

                    # Compute the insertion position for the current span:
                    # inserting after the current position requires +1,
                    # inserting before it does not
                    cur_insert_index = i + already_insert_tokens
                    cur_rewrite_result = cur_rewrite_result[:cur_insert_index] + \
                        predict_span_tokens + cur_rewrite_result[cur_insert_index:]
                    # Record the number of inserted tokens
                    already_insert_tokens += len(predict_span_tokens)

            cur_rewrite_result = cur_rewrite_result[:-1]
            # Compute on strings
            # rather than on lists of tokens
            cur_rewrite_string = "".join(cur_rewrite_result)
            rewrite_tokens = mdata.get("rewrite_tokens", None)
            if rewrite_tokens is not None:
                rewrite_string = "".join(rewrite_tokens)
                # Strip the [SEP] token
                query_string = "".join(query_tokens[:-1])
                self._rewrite_em(cur_rewrite_string, rewrite_string,
                                 query_string)
                # Additional metrics
                for metric in self._metrics:
                    metric(cur_rewrite_result, rewrite_tokens)
                # Get restore_tokens and compute the corresponding metric
                restore_tokens = mdata.get("restore_tokens", None)
                self._restore_score(cur_rewrite_result,
                                    rewrite_tokens,
                                    queries=query_tokens[:-1],
                                    restore_tokens=restore_tokens)

            predict_rewrite_results.append("".join(cur_rewrite_result))
        return predict_rewrite_results

    @overrides
    def forward(self,
                context_ids: TextFieldTensors,
                query_ids: TextFieldTensors,
                context_lens: torch.Tensor,
                query_lens: torch.Tensor,
                mask_label: Optional[torch.Tensor] = None,
                start_label: Optional[torch.Tensor] = None,
                end_label: Optional[torch.Tensor] = None,
                metadata: List[Dict[str, Any]] = None):
        # concat the context and query to the encoder
        # get the indexers first
        indexers = context_ids.keys()
        dialogue_ids = {}

        # Get the lengths of the context and the query
        context_len = torch.max(context_lens).item()
        query_len = torch.max(query_lens).item()

        # [B, _len]
        context_mask = get_mask_from_sequence_lengths(context_lens,
                                                      context_len)
        query_mask = get_mask_from_sequence_lengths(query_lens, query_len)
        for indexer in indexers:
            # get the various variables of context and query
            dialogue_ids[indexer] = {}
            for key in context_ids[indexer].keys():
                context = context_ids[indexer][key]
                query = query_ids[indexer][key]
                # concat the context and query in the length dim
                dialogue = torch.cat([context, query], dim=1)
                dialogue_ids[indexer][key] = dialogue

        # get the outputs of the dialogue
        if isinstance(self._text_field_embedder, TextFieldEmbedder):
            embedder_outputs = self._text_field_embedder(dialogue_ids)
        else:
            embedder_outputs = self._text_field_embedder(
                **dialogue_ids[self._index_name])

        # get the outputs of the query and context
        # [B, _len, embed_size]
        context_last_layer = embedder_outputs[:, :context_len].contiguous()
        query_last_layer = embedder_outputs[:, context_len:].contiguous()

        # ------- Compute the span predictions -------
        # For each masked token position in the query we want the content that
        # should be inserted after it, i.e. the start and end positions of the
        # corresponding span in the context.
        # Accordingly, expand the context to [b, query_len, context_len, embed_size]
        context_last_layer = context_last_layer.unsqueeze(dim=1).expand(
            -1, query_len, -1, -1).contiguous()
        # [b, query_len, context_len]
        context_expand_mask = context_mask.unsqueeze(dim=1).expand(
            -1, query_len, -1).contiguous()

        # Concatenate the three parts above;
        # this covers every position in the query
        span_embed_size = context_last_layer.size(-1)

        if self.training and self._neg_sample_ratio > 0.0:
            # Sample from the positions where the mask is 0
            # [B*query_len, ]
            sample_mask_label = mask_label.view(-1)
            # Get the flattened length and the number of negative samples to draw
            mask_length = sample_mask_label.size(0)
            mask_sum = int(
                torch.sum(sample_mask_label).item() * self._neg_sample_ratio)
            mask_sum = max(10, mask_sum)
            # Indices of the negative samples to draw
            neg_indexes = torch.randint(low=0,
                                        high=mask_length,
                                        size=(mask_sum, ))
            # Clamp to the valid length range
            neg_indexes = neg_indexes[:mask_length]
            # Set the mask to 1 at the sampled negative positions
            sample_mask_label[neg_indexes] = 1
            # [B, query_len]
            use_mask_label = sample_mask_label.view(
                -1, query_len).to(dtype=torch.bool)
            # Filter out the padded part of the query, [B, query_len]
            use_mask_label = use_mask_label & query_mask
            span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # Select the usable content of the context
            # [B_mask, context_len, span_embed_size]
            span_context_matrix = context_last_layer.masked_select(
                span_mask).view(-1, context_len, span_embed_size).contiguous()
            # Select the usable vectors of the query
            span_query_vector = query_last_layer.masked_select(
                span_mask.squeeze(dim=-1)).view(-1,
                                                span_embed_size).contiguous()
            span_context_mask = context_expand_mask.masked_select(
                span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()
        else:
            use_mask_label = query_mask
            span_mask = use_mask_label.unsqueeze(dim=-1).unsqueeze(dim=-1)
            # Select the usable content of the context
            # [B_mask, context_len, span_embed_size]
            span_context_matrix = context_last_layer.masked_select(
                span_mask).view(-1, context_len, span_embed_size).contiguous()
            # Select the usable vectors of the query
            span_query_vector = query_last_layer.masked_select(
                span_mask.squeeze(dim=-1)).view(-1,
                                                span_embed_size).contiguous()
            span_context_mask = context_expand_mask.masked_select(
                span_mask.squeeze(dim=-1)).view(-1, context_len).contiguous()

        # Logits of the span over each context position
        # [B_mask, context_len]
        span_start_probs = self.start_attention(span_query_vector,
                                                span_context_matrix,
                                                span_context_mask)
        span_end_probs = self.end_attention(span_query_vector,
                                            span_context_matrix,
                                            span_context_mask)

        span_start_logits = torch.log(span_start_probs + self._eps)
        span_end_logits = torch.log(span_end_probs + self._eps)

        # [B_mask, 2]; in the last dimension the first entry is the start position, the second the end position
        best_spans = get_best_span(span_start_logits, span_end_logits)
        # Compute the score of each best_span
        best_span_scores = (
            torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1)) +
            torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1)))
        # [B_mask, ]
        best_span_scores = best_span_scores.squeeze(1)

        # Write the important information to the output
        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_spans": best_spans,
            "best_span_scores": best_span_scores
        }

        # If labels are present, use them to compute the loss
        if start_label is not None:
            loss = self._calc_loss(span_start_logits, span_end_logits,
                                   use_mask_label, start_label, end_label,
                                   best_spans)
            output_dict["loss"] = loss
        if metadata is not None:
            predict_rewrite_results = self._get_rewrite_result(
                use_mask_label, best_spans, query_lens, context_lens, metadata)
            output_dict['rewrite_results'] = predict_rewrite_results
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = {}
        metrics["span_acc"] = self._span_accuracy.get_metric(reset)
        for metric in self._metrics:
            metrics.update(metric.get_metric(reset))
        metrics.update(self._rewrite_em.get_metric(reset))
        metrics.update(self._restore_score.get_metric(reset))
        return metrics

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        new_output_dict = {}
        new_output_dict["rewrite_results"] = output_dict["rewrite_results"]
        return new_output_dict
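
The negative sampling performed inside forward above can be hard to read inline. Here is a stripped-down sketch of the same trick with toy shapes; the tensors and ratio are illustrative only, not values from the model:

import torch

neg_sample_ratio = 0.5
mask_label = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0]])  # [B, query_len]
sample_mask_label = mask_label.view(-1)                  # [B * query_len]
mask_length = sample_mask_label.size(0)
# Number of negatives to draw, with a floor of 10 as in the model above
mask_sum = max(10, int(sample_mask_label.sum().item() * neg_sample_ratio))
neg_indexes = torch.randint(low=0, high=mask_length, size=(mask_sum,))
neg_indexes = neg_indexes[:mask_length]  # clamp to the valid range
sample_mask_label[neg_indexes] = 1       # flip the sampled positions to 1
print(sample_mask_label.view_as(mask_label))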
Exemplo n.º 19
0
class AttentionLSTM(Model):
    def __init__(self,
                 args,
                 word_embeddings: TextFieldEmbedder,
                 vocab: Vocabulary,
                 domain_info: bool = True) -> None:
        super().__init__(vocab)

        # parameters
        self.args = args
        self.word_embeddings = word_embeddings
        self.domain = domain_info

        # layers
        self.event_embedding = EventEmbedding(args, self.word_embeddings)
        self.event_type_embedding = EventTypeEmbedding(args,
                                                       self.word_embeddings)
        self.lstm = LSTM(input_size=self.args.embedding_size,
                         hidden_size=self.args.hidden_size)
        self.W_c = Linear(self.args.embedding_size,
                          self.args.hidden_size,
                          bias=False)
        self.W_e = Linear(self.args.hidden_size,
                          self.args.hidden_size,
                          bias=False)
        self.relu = ReLU()
        self.linear = Linear(self.args.hidden_size, self.args.embedding_size)
        self.attention = Attention(self.args.hidden_size, score_function='mlp')
        self.score = Score(self.args.embedding_size,
                           self.args.embedding_size,
                           threshold=self.args.threshold)

        # metrics
        self.accuracy = BooleanAccuracy()
        self.f1_score = F1Measure(positive_label=1)
        self.loss_function = BCELoss()

    @overrides
    def forward(self,
                trigger_0: Dict[str, torch.LongTensor],
                trigger_agent_0: Dict[str, torch.LongTensor],
                agent_attri_0: Dict[str, torch.LongTensor],
                trigger_object_0: Dict[str, torch.LongTensor],
                object_attri_0: Dict[str, torch.LongTensor],
                trigger_1: Dict[str, torch.LongTensor],
                trigger_agent_1: Dict[str, torch.LongTensor],
                agent_attri_1: Dict[str, torch.LongTensor],
                trigger_object_1: Dict[str, torch.LongTensor],
                object_attri_1: Dict[str, torch.LongTensor],
                trigger_2: Dict[str, torch.LongTensor],
                trigger_agent_2: Dict[str, torch.LongTensor],
                agent_attri_2: Dict[str, torch.LongTensor],
                trigger_object_2: Dict[str, torch.LongTensor],
                object_attri_2: Dict[str, torch.LongTensor],
                trigger_3: Dict[str, torch.LongTensor],
                trigger_agent_3: Dict[str, torch.LongTensor],
                agent_attri_3: Dict[str, torch.LongTensor],
                trigger_object_3: Dict[str, torch.LongTensor],
                object_attri_3: Dict[str, torch.LongTensor],
                trigger_4: Dict[str, torch.LongTensor],
                trigger_agent_4: Dict[str, torch.LongTensor],
                agent_attri_4: Dict[str, torch.LongTensor],
                trigger_object_4: Dict[str, torch.LongTensor],
                object_attri_4: Dict[str, torch.LongTensor],
                event_type: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # tri, e: [batch_size, 1, embedding_size]
        tri0, e0 = self.event_embedding(trigger_0, trigger_agent_0,
                                        trigger_object_0)
        tri1, e1 = self.event_embedding(trigger_1, trigger_agent_1,
                                        trigger_object_1)
        tri2, e2 = self.event_embedding(trigger_2, trigger_agent_2,
                                        trigger_object_2)
        tri3, e3 = self.event_embedding(trigger_3, trigger_agent_3,
                                        trigger_object_3)
        tri4, e4 = self.event_embedding(trigger_4, trigger_agent_4,
                                        trigger_object_4)
        event_type = self.event_type_embedding(event_type)

        # [batch_size, seq_len, embedding_size]
        e = (torch.stack([e0, e1, e2, e3], dim=1)).squeeze(2)
        batch_size, seq_len, _ = e.size()

        # [batch_size, seq_len, embedding_size]
        event_types = (torch.stack(
            [event_type, event_type, event_type, event_type],
            dim=1)).squeeze(2)

        # [seq_len, batch_size, embedding_size]
        # transpose (not view) is needed here: view would reinterpret the
        # storage instead of swapping the batch and sequence axes
        e = e.transpose(0, 1).contiguous()
        lstm_out, (hn, _) = self.lstm(e)
        # [batch_size, seq_len, hidden_size]
        lstm_out = lstm_out.transpose(0, 1).contiguous()
        if self.domain:
            lstm_out = lstm_out + self.relu(
                self.W_c(event_types) + self.W_e(lstm_out))

        # [batch_size, 1, hidden_size]
        hn = hn.view(batch_size, 1, -1)

        # [batch_size, 1, hidden_size]
        out_atten, _ = self.attention(lstm_out, hn)
        # [batch_size, 1, embedding_size]
        out_atten = self.linear(out_atten)

        # [batch_size, 1] , [batch_size], [batch_size, label_size]
        score, logits, logits_f1 = self.score(out_atten, e4)

        output = {"logits": logits, "score": score}
        if label is not None:
            self.accuracy(logits, label)
            self.f1_score(logits_f1, label)
            output["loss"] = self.loss_function(score.squeeze(1),
                                                label.float())

        return output

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        accuracy = self.accuracy.get_metric(reset)
        precision, recall, f1_measure = self.f1_score.get_metric(reset)
        return {
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1_measure": f1_measure
        }
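
A note on the transposes in forward above: view reinterprets the underlying storage without moving any data, so it cannot be used to swap the batch and sequence axes; transpose is required. A tiny sketch of the difference (toy tensor, not model data):

import torch

x = torch.arange(6).reshape(2, 3)  # pretend [batch=2, seq=3]
print(x.view(3, 2))       # reinterprets storage row by row: rows get scrambled
print(x.transpose(0, 1))  # genuinely swaps the axes: column i becomes row i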
Exemplo n.º 20
0
class BERT_QA(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 dropout: float = 0.0,
                 max_span_length: int = 30,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._max_span_length = max_span_length

        self.qa_outputs = torch.nn.Linear(
            self._text_field_embedder.get_output_dim(), 2)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._span_qa_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            context: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # the `context` is the concatenation of `question` and `passage`, so we just use `context`
        batch_size, num_of_passage_tokens = context['tokens'].size()

        # BERT for QA is a fully connected linear layer on top of BERT producing 2 vectors of
        # start and end spans.
        embedded_passage = self._text_field_embedder(context)
        passage_length = embedded_passage.size(1)
        logits = self.qa_outputs(embedded_passage)
        start_logits, end_logits = logits.split(1, dim=-1)
        span_start_logits = start_logits.squeeze(-1)
        span_end_logits = end_logits.squeeze(-1)

        # Adding some masks with numerically stable values
        passage_mask = util.get_text_field_mask(passage).float()
        repeated_passage_mask = passage_mask.view(batch_size, passage_length)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_start_probs = util.masked_softmax(span_start_logits,
                                               repeated_passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)
        span_end_probs = util.masked_softmax(span_end_logits,
                                             repeated_passage_mask)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict: Dict[str, Any] = {}

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end],
                                                     -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on span qa and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])

                passage_words = metadata[i]["paragraph_words"]
                answer_offset = metadata[i]["answer_offset"]
                tok_to_word_index = metadata[i]["tok_to_word_index"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_position = tok_to_word_index[predicted_span[0] -
                                                   answer_offset]
                end_position = tok_to_word_index[predicted_span[1] -
                                                 answer_offset]
                best_span_str = " ".join(
                    passage_words[start_position:end_position + 1])
                output_dict["best_span_str"].append(best_span_str)
                answer_text = metadata[i].get("answer_text", [])
                if answer_text:
                    answer_text = [answer_text]
                    self._span_qa_metrics(best_span_str, answer_text)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._span_qa_metrics.get_metric(reset)
        return {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "em": exact_match,
            "f1": f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.
        span_log_mask = (torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0))
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
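
A toy invocation of get_best_span (it is a @staticmethod, so it can be called on the class; batch_size=1 and passage_length=4 here are illustrative):

import torch

start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.2]])
end_logits = torch.tensor([[0.0, 0.1, 1.5, 0.4]])
# Highest-scoring valid span combines start=1 (2.0) with end=2 (1.5)
print(BERT_QA.get_best_span(start_logits, end_logits))  # tensor([[1, 2]])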
Exemplo n.º 21
0
class ModelSQUAD(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 attention_similarity_function: SimilarityFunction,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 feed_forward: FeedForward,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(ModelSQUAD, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = MatrixAttention(attention_similarity_function)
        self._residual_encoder = residual_encoder
        self._span_end_encoder = span_end_encoder
        self._span_start_encoder = span_start_encoder
        self._feed_forward = feed_forward

        encoding_dim = phrase_layer.get_output_dim()
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(encoding_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(encoding_dim, 1))
        self._no_answer_predictor = TimeDistributed(
            torch.nn.Linear(encoding_dim, 1))

        self._self_matrix_attention = MatrixAttention(
            attention_similarity_function)
        self._linear_layer = TimeDistributed(
            torch.nn.Linear(4 * encoding_dim, encoding_dim))
        self._residual_linear_layer = TimeDistributed(
            torch.nn.Linear(3 * encoding_dim, encoding_dim))

        self._w_x = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_y = torch.nn.Parameter(torch.Tensor(encoding_dim))
        self._w_xy = torch.nn.Parameter(torch.Tensor(encoding_dim))
        std = math.sqrt(6 / (encoding_dim * 3 + 1))
        self._w_x.data.uniform_(-std, std)
        self._w_y.data.uniform_(-std, std)
        self._w_xy.data.uniform_(-std, std)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.LongTensor = None,
            span_end: torch.LongTensor = None,
            spans=None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        # Shape: (batch_size, 4, passage_length, embedding_dim)
        embedded_passage = self._text_field_embedder(passage)

        (batch_size, q_length, embedding_dim) = embedded_question.size()
        passage_length = embedded_passage.size(2)

        # reshape: (batch_size*4, -1, embedding_dim)
        embedded_passage = embedded_passage.view(-1, passage_length,
                                                 embedding_dim)
        embedded_passage = self._highway_layer(embedded_passage)

        embedded_question = embedded_question.unsqueeze(0).expand(
            4, -1, -1, -1).contiguous().view(-1, q_length, embedding_dim)
        question_mask = util.get_text_field_mask(question).float()
        question_mask = question_mask.unsqueeze(0).expand(
            4, -1, -1).contiguous().view(-1, q_length)

        passage_mask = util.get_text_field_mask(passage, 1).float()
        passage_mask = passage_mask.view(-1, passage_length)

        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        cuda_device = encoded_question.get_device()

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        question_attended_passage = relu(
            self._linear_layer(final_merged_passage))

        # TODO: attach residual self-attention layer
        # Shape: (batch_size, passage_length, encoding_dim)
        residual_passage = self._dropout(
            self._residual_encoder(self._dropout(question_attended_passage),
                                   passage_lstm_mask))
        mask = passage_mask.view(batch_size, passage_length,
                                 1) * passage_mask.view(
                                     batch_size, 1, passage_length)
        self_mask = Variable(
            torch.eye(passage_length,
                      passage_length).cuda(cuda_device)).view(
                          1, passage_length, passage_length)
        mask = mask * (1 - self_mask)
        # Shape: (batch_size, passage_length, passage_length)
        x_similarity = torch.matmul(residual_passage, self._w_x).unsqueeze(2)
        y_similarity = torch.matmul(residual_passage, self._w_y).unsqueeze(1)
        dot_similarity = torch.bmm(residual_passage * self._w_xy,
                                   residual_passage.transpose(1, 2))
        passage_self_similarity = dot_similarity + x_similarity + y_similarity
        # Shape: (batch_size, passage_length, passage_length)
        passage_self_attention = util.last_dim_softmax(passage_self_similarity,
                                                       mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_vectors = util.weighted_sum(residual_passage,
                                            passage_self_attention)
        # Shape: (batch_size, passage_length, encoding_dim * 3)
        merged_passage = torch.cat([
            residual_passage, passage_vectors,
            residual_passage * passage_vectors
        ],
                                   dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        self_attended_passage = relu(
            self._residual_linear_layer(merged_passage))

        # Shape: (batch_size, passage_length, encoding_dim)
        mixed_passage = question_attended_passage + self_attended_passage

        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_start = self._dropout(
            self._span_start_encoder(mixed_passage, passage_lstm_mask))
        span_start_logits = self._span_start_predictor(
            encoded_span_start).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, passage_length, encoding_dim * 2)
        concatenated_passage = torch.cat([mixed_passage, encoded_span_start],
                                         dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(concatenated_passage, passage_lstm_mask))
        span_end_logits = self._span_end_predictor(encoded_span_end).squeeze(
            -1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        # Shape: (batch_size, encoding_dim)
        v_1 = util.weighted_sum(encoded_span_start, span_start_probs)
        v_2 = util.weighted_sum(encoded_span_end, span_end_probs)

        no_span_logits = self._no_answer_predictor(
            self_attended_passage).squeeze(-1)
        no_span_probs = util.masked_softmax(no_span_logits, passage_mask)
        v_3 = util.weighted_sum(self_attended_passage, no_span_probs)
        # Shape: (batch_size, 1)
        z_score = self._feed_forward(torch.cat([v_1, v_2, v_3], dim=-1))
        # compute no-answer score

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }
        # create target tensor including no-answer label
        span_target = Variable(torch.ones(batch_size).long()).cuda(cuda_device)
        for b in range(batch_size):
            span_target[b].data[0] = span_start[
                b, 0].data[0] * passage_length + span_end[b, 0].data[0]
        span_target[span_target < 0] = passage_length**2
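        # Worked example (illustrative): with passage_length = 5, a gold span
        # (start=2, end=4) is encoded as 2 * 5 + 4 = 14 in the flattened
        # (passage_length ** 2 + 1)-way classification below, while any
        # negative (unanswerable) span maps to index 25 == passage_length ** 2,
        # the extra no-answer class whose logit is z_score.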

        # Shape: (batch_size, passage_length, passage_length)
        span_start_logits_tiled = span_start_logits.unsqueeze(1).expand(
            batch_size, passage_length, passage_length)
        span_end_logits_tiled = span_end_logits.unsqueeze(-1).expand(
            batch_size, passage_length, passage_length)
        span_logits = (span_start_logits_tiled + span_end_logits_tiled).view(
            batch_size, -1)
        answer_mask = torch.bmm(passage_mask.unsqueeze(-1),
                                passage_mask.unsqueeze(1)).view(
                                    batch_size, -1)
        no_answer_mask = Variable(torch.ones(batch_size, 1)).cuda(cuda_device)
        combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
        all_logits = torch.cat([span_logits, z_score], dim=-1)
        loss = nll_loss(util.masked_log_softmax(all_logits, combined_mask),
                        span_target)
        output_dict["loss"] = loss

        # Shape(batch_size, max_answers, num_span)
        #    max_answers = spans.size(1)
        #    span_logits = torch.bmm(span_start_logits.unsqueeze(-1), span_end_logits.unsqueeze(1)).view(batch_size, -1)
        #    answer_mask = torch.bmm(passage_mask.unsqueeze(-1), passage_mask.unsqueeze(1)).view(batch_size, -1)
        #    no_answer_mask = Variable(torch.ones(batch_size, 1)).cuda(cuda_device)
        #    combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
        #    # Shape: (batch_size, passage_length**2 + 1)
        #    all_logits = torch.cat([span_logits, z_score], dim=-1)
        #    # Shape: (batch_size, max_answers)
        #    spans_combined = spans[:, :, 0] * passage_length + spans[:, :, 1]
        #    spans_combined[spans_combined < 0] = passage_length*passage_length
        #
        #    all_modified_logits = []
        #    for b in range(batch_size):
        #        idxs = Variable(torch.LongTensor(range(passage_length**2 + 1))).cuda(cuda_device)
        #        for i in range(max_answers):
        #            idxs[spans_combined[b, i].data[0]].data = idxs[spans_combined[b, 0].data[0]].data
        #        idxs[passage_length**2].data[0] = passage_length**2

        #        modified_logits = Variable(torch.zeros(all_logits.size(-1))).cuda(cuda_device)
        #        modified_logits.index_add_(0, idxs, all_logits[b])
        #        all_modified_logits.append(modified_logits)

        #    all_modified_logits = torch.stack(all_modified_logits, dim=0)
        #    loss = nll_loss(util.masked_log_softmax(all_modified_logits, combined_mask), spans_combined[:, 0])
        #    output_dict["loss"] = loss

        if span_start is not None:
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: Variable,
                      span_end_logits: Variable) -> Variable:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = Variable(span_start_logits.data.new().resize_(
            batch_size, 2).fill_(0)).long()

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
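    # Minimal usage sketch for get_best_span (illustrative values, not from
    # the original code):
    #   span_start_logits = [[0.1, 2.0, 0.3]]
    #   span_end_logits   = [[0.2, 0.1, 1.5]]
    # The running start-argmax reaches index 1 at the second token, and
    # 2.0 + 1.5 is the best total, so get_best_span returns [[1, 2]]: the
    # (start, end) pair with start <= end maximizing the summed logits.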

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'ModelSQUAD':
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(
            vocab, embedder_params)
        num_highway_layers = params.pop_int("num_highway_layers")
        phrase_layer = Seq2SeqEncoder.from_params(params.pop("phrase_layer"))
        similarity_function = SimilarityFunction.from_params(
            params.pop("similarity_function"))
        residual_encoder = Seq2SeqEncoder.from_params(
            params.pop("residual_encoder"))
        span_start_encoder = Seq2SeqEncoder.from_params(
            params.pop("span_start_encoder"))
        span_end_encoder = Seq2SeqEncoder.from_params(
            params.pop("span_end_encoder"))
        feed_forward = FeedForward.from_params(params.pop("feed_forward"))
        dropout = params.pop_float('dropout', 0.2)

        initializer = InitializerApplicator.from_params(
            params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(
            params.pop('regularizer', []))

        mask_lstms = params.pop_bool('mask_lstms', True)
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   num_highway_layers=num_highway_layers,
                   phrase_layer=phrase_layer,
                   attention_similarity_function=similarity_function,
                   residual_encoder=residual_encoder,
                   span_start_encoder=span_start_encoder,
                   span_end_encoder=span_end_encoder,
                   feed_forward=feed_forward,
                   dropout=dropout,
                   mask_lstms=mask_lstms,
                   initializer=initializer,
                   regularizer=regularizer)
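# A config sketch matching the keys popped by from_params above (the key
# names come from the code; the values are illustrative placeholders only):
# {
#     "text_field_embedder": {...},
#     "num_highway_layers": 2,
#     "phrase_layer": {"type": "lstm", ...},
#     "similarity_function": {"type": "linear", ...},
#     "residual_encoder": {"type": "lstm", ...},
#     "span_start_encoder": {"type": "lstm", ...},
#     "span_end_encoder": {"type": "lstm", ...},
#     "feed_forward": {...},
#     "dropout": 0.2,
#     "mask_lstms": true
# }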
Example no. 22
class CMVDiscriminator(FeedForward):
    def __init__(self,
                 input_dim: int,
                 num_layers: int,
                 hidden_dims: Union[int, Sequence[int]],
                 activations: Union[Activation, Sequence[Activation]],
                 dropout: Union[float, Sequence[float]] = 0.0,
                 gate_bias: float = -2) -> None:

        super(CMVDiscriminator,
              self).__init__(input_dim, num_layers, hidden_dims, activations,
                             dropout)

        if not isinstance(hidden_dims, list):
            hidden_dims = [hidden_dims] * (num_layers - 1)
        input_dims = hidden_dims[1:]

        gate_layers = [None]  # leading None so this zips up with the linear layers later
        for layer_input_dim, layer_output_dim in zip(input_dims, hidden_dims):
            gate_layer = torch.nn.Linear(layer_input_dim, layer_output_dim)
            gate_layer.bias.data.fill_(gate_bias)

            gate_layers.append(gate_layer)

        self._gate_layers = torch.nn.ModuleList(gate_layers)

        # FeedForward requires an Activation, so we just use the identity
        self._output_feedforward = FeedForward(hidden_dims[-1], 1, 1,
                                               lambda x: x)

        self._accuracy = BooleanAccuracy()

    def _get_hidden(self, output):
        layers = list(
            zip(self._linear_layers, self._activations, self._dropout,
                self._gate_layers))
        layer, activation, dropout, _ = layers[0]
        output = dropout(activation(layer(output)))

        for layer, activation, dropout, gate in layers[1:]:
            gate_output = torch.sigmoid(gate(output))
            new_output = dropout(activation(layer(output)))

            output = torch.add(torch.mul(gate_output, new_output),
                               torch.mul(1 - gate_output, output))

        return output
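    # The loop above is a standard highway update (sketch): with
    # g = sigmoid(W_g h + b_g), output = g * f(W h) + (1 - g) * h, where the
    # negative gate_bias initializes g near 0 so that early in training the
    # input is mostly carried through unchanged.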

    def forward(self, real_output, fake_output=None):

        real_hidden = self._get_hidden(real_output)

        real_value = self._output_feedforward(real_hidden)
        labels = torch.ones(real_hidden.size(0))
        if torch.cuda.is_available() and real_value.is_cuda:
            idx = real_value.get_device()
            labels = labels.cuda(idx)

        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            real_value.view(-1), labels)

        predictions = torch.sigmoid(real_value) > 0.5

        if fake_output is not None:
            fake_hidden = self._get_hidden(fake_output)
            fake_value = self._output_feedforward(fake_hidden)
            fake_labels = torch.zeros(fake_hidden.size(0))

            if torch.cuda.is_available() and fake_value.is_cuda:
                idx = fake_value.get_device()
                fake_labels = fake_labels.cuda(idx)

            loss += torch.nn.functional.binary_cross_entropy_with_logits(
                fake_value.view(-1), fake_labels)

            predictions = torch.cat(
                [predictions, torch.sigmoid(fake_value) > 0.5])
            labels = torch.cat([labels, fake_labels])

        self._accuracy(predictions, labels.byte())

        return {'loss': loss, 'predictions': predictions, 'labels': labels}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {'accuracy': self._accuracy.get_metric(reset)}
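# Minimal usage sketch (dimensions and activation are assumptions, not part
# of the original example):
#   disc = CMVDiscriminator(input_dim=300, num_layers=3, hidden_dims=300,
#                           activations=Activation.by_name('relu')())
#   real = torch.randn(8, 300)   # e.g. encoded real arguments
#   fake = torch.randn(8, 300)   # e.g. generator outputs
#   out = disc(real, fake)       # {'loss': ..., 'predictions': ..., 'labels': ...}
#   disc.get_metrics()           # {'accuracy': ...}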
Example no. 23
class Seq2SeqTask(SequenceGenerationTask):
    """Sequence-to-sequence Task"""

    def __init__(self, path, max_seq_len, max_targ_v_size, name, **kw):
        """ """
        super().__init__(name, **kw)
        self.scorer2 = BooleanAccuracy()
        self.scorers.append(self.scorer2)
        self.val_metric = "%s_accuracy" % self.name
        self.val_metric_decreases = False
        self.max_seq_len = max_seq_len
        self._label_namespace = self.name + "_tokens"
        self.max_targ_v_size = max_targ_v_size
        self.target_indexer = {"words": SingleIdTokenIndexer(namespace=self._label_namespace)}
        self.files_by_split = {
            split: os.path.join(path, "%s.tsv" % split) for split in ["train", "val", "test"]
        }

        # The following is necessary since word-level tasks (e.g., MT) haven't been tested yet.
        if self._tokenizer_name != "SplitChars" and self._tokenizer_name != "dummy_tokenizer_name":
            raise NotImplementedError("For now, Seq2SeqTask only supports character-level tasks.")

    def load_data(self):
        # Data is exposed as iterable: no preloading
        pass

    def get_split_text(self, split: str):
        """
        Get split text as iterable of records.
        Split should be one of 'train', 'val', or 'test'.
        """
        return self.get_data_iter(self.files_by_split[split])

    def get_all_labels(self) -> List[str]:
        """ Build character vocabulary and return it as a list """
        token2freq = collections.Counter()
        for split in ["train", "val"]:
            for _, sequence in self.get_data_iter(self.files_by_split[split]):
                for token in sequence:
                    token2freq[token] += 1
        return [t for t, _ in token2freq.most_common(self.max_targ_v_size)]

    def get_data_iter(self, path):
        """ Load data """
        with codecs.open(path, "r", "utf-8", errors="ignore") as txt_fh:
            for row in txt_fh:
                row = row.strip().split("\t")
                if len(row) < 2 or not row[0] or not row[1]:
                    continue
                src_sent = tokenize_and_truncate(self._tokenizer_name, row[0], self.max_seq_len)
                # The guard above checks two columns, so the target comes from row[1].
                tgt_sent = tokenize_and_truncate(self._tokenizer_name, row[1], self.max_seq_len)
                yield (src_sent, tgt_sent)
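                # Example input line (illustrative): "abcde\tABCDE" would
                # yield a (source, target) pair of character-token sequences
                # under the SplitChars tokenizer.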

    def get_sentences(self) -> Iterable[Sequence[str]]:
        """ Yield sentences, used to compute vocabulary. """
        for split in self.files_by_split:
            # Don't use test set for vocab building.
            if split.startswith("test"):
                continue
            path = self.files_by_split[split]
            yield from self.get_data_iter(path)

    def count_examples(self):
        """ Compute here b/c we're streaming the sentences. """
        example_counts = {}
        for split, split_path in self.files_by_split.items():
            example_counts[split] = sum(
                1 for _ in codecs.open(split_path, "r", "utf-8", errors="ignore")
            )
        self.example_counts = example_counts

    def process_split(
        self, split, indexers, model_preprocessing_interface
    ) -> Iterable[Type[Instance]]:
        """ Process split text into a list of AllenNLP Instances. """

        def _make_instance(input_, target):
            d = {
                "inputs": sentence_to_text_field(
                    model_preprocessing_interface.boundary_token_fn(input_), indexers
                ),
                "targs": sentence_to_text_field(
                    model_preprocessing_interface.boundary_token_fn(target), self.target_indexer
                ),
            }
            return Instance(d)

        for sent1, sent2 in split:
            yield _make_instance(sent1, sent2)

    def get_metrics(self, reset=False):
        """Get metrics specific to the task"""
        avg_nll = self.scorer1.get_metric(reset)
        acc = self.scorer2.get_metric(reset)
        return {"perplexity": math.exp(avg_nll), "accuracy": acc}

    def update_metrics(self, logits, labels, tagmask=None):
        self.scorer2(logits.max(2)[1], labels, tagmask)
        return

    def get_prediction(self, voc_src, voc_trg, inputs, gold, output):
        tokenizer = get_tokenizer(self._tokenizer_name)

        input_string = tokenizer.detokenize([voc_src[token.item()] for token in inputs]).split(
            "<EOS>"
        )[0]
        gold_string = tokenizer.detokenize([voc_trg[token.item()] for token in gold]).split(
            "<EOS>"
        )[0]
        output_string = tokenizer.detokenize([voc_trg[token.item()] for token in output]).split(
            "<EOS>"
        )[0]

        return input_string, gold_string, output_string
Example no. 24
class BertMCQAModel(Model):
    """
    BERT-based multiple-choice QA model that scores answer choices, optionally
    by comparing ordered pairs of choices with a shared BERT encoder.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model: str = None,
                 requires_grad: bool = True,
                 top_layer_only: bool = True,
                 bert_weights_model: str = None,
                 per_choice_loss: bool = False,
                 layer_freeze_regexes: List[str] = None,
                 regularizer: Optional[RegularizerApplicator] = None,
                 use_comparative_bert: bool = True,
                 use_bilinear_classifier: bool = False,
                 train_comparison_layer: bool = False,
                 number_of_choices_compared: int = 0,
                 comparison_layer_hidden_size: int = -1,
                 comparison_layer_use_relu: bool = True) -> None:
        super().__init__(vocab, regularizer)

        self._use_comparative_bert = use_comparative_bert
        self._use_bilinear_classifier = use_bilinear_classifier
        self._train_comparison_layer = train_comparison_layer
        if train_comparison_layer:
            assert number_of_choices_compared > 1
            self._num_choices = number_of_choices_compared
            self._comparison_layer_hidden_size = comparison_layer_hidden_size
            self._comparison_layer_use_relu = comparison_layer_use_relu

        # Bert weights and config
        if bert_weights_model:
            logging.info(f"Loading BERT weights model from {bert_weights_model}")
            bert_model_loaded = load_archive(bert_weights_model)
            self._bert_model = bert_model_loaded.model._bert_model
        else:
            self._bert_model = BertModel.from_pretrained(pretrained_model)

        for param in self._bert_model.parameters():
            param.requires_grad = requires_grad
        #for name, param in self._bert_model.named_parameters():
        #    grad = requires_grad
        #    if layer_freeze_regexes and grad:
        #        grad = not any([bool(re.search(r, name)) for r in layer_freeze_regexes])
        #    param.requires_grad = grad

        bert_config = self._bert_model.config
        self._output_dim = bert_config.hidden_size
        self._dropout = torch.nn.Dropout(bert_config.hidden_dropout_prob)
        self._per_choice_loss = per_choice_loss

        # Bert Classifier selector
        final_output_dim = 1
        if not use_comparative_bert:
            if bert_weights_model and hasattr(bert_model_loaded.model, "_classifier"):
                self._classifier = bert_model_loaded.model._classifier
            else:
                self._classifier = Linear(self._output_dim, final_output_dim)
        else:
            if use_bilinear_classifier:
                self._classifier = Bilinear(self._output_dim, self._output_dim, final_output_dim)
            else:
                self._classifier = Linear(self._output_dim * 2, final_output_dim)
        self._classifier.apply(self._bert_model.init_bert_weights)

        # Comparison layer setup
        if self._train_comparison_layer:
            number_of_pairs = self._num_choices * (self._num_choices - 1)
            if self._comparison_layer_hidden_size == -1:
                self._comparison_layer_hidden_size = number_of_pairs * number_of_pairs

            self._comparison_layer_1 = Linear(number_of_pairs, self._comparison_layer_hidden_size)
            if self._comparison_layer_use_relu:
                self._comparison_layer_1_activation = torch.nn.LeakyReLU()
            else:
                self._comparison_layer_1_activation = torch.nn.Tanh()
            self._comparison_layer_2 = Linear(self._comparison_layer_hidden_size, self._num_choices)
            self._comparison_layer_2_activation = torch.nn.Softmax(dim=-1)
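            # Worked example (illustrative): with 4 answer choices there are
            # 4 * 3 = 12 ordered pairs, so this head maps 12 pairwise scores
            # -> a hidden layer (12 * 12 = 144 units by default) -> 4 choice
            # logits.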

        # Scalar mix, if necessary
        self._all_layers = not top_layer_only
        if self._all_layers:
            if bert_weights_model and hasattr(bert_model_loaded.model, "_scalar_mix") \
                    and bert_model_loaded.model._scalar_mix is not None:
                self._scalar_mix = bert_model_loaded.model._scalar_mix
            else:
                num_layers = bert_config.num_hidden_layers
                initial_scalar_parameters = num_layers * [0.0]
                initial_scalar_parameters[-1] = 5.0  # Starts with most mass on last layer
                self._scalar_mix = ScalarMix(bert_config.num_hidden_layers,
                                             initial_scalar_parameters=initial_scalar_parameters,
                                             do_layer_norm=False)
        else:
            self._scalar_mix = None

        # Accuracy and loss setup
        if self._train_comparison_layer:
            self._accuracy = CategoricalAccuracy()
            self._loss = torch.nn.CrossEntropyLoss()
        else:
            self._accuracy = BooleanAccuracy()
            self._loss = torch.nn.BCEWithLogitsLoss()
        self._debug = -1

    def _extract_last_token_pooled_output(self, encoded_layers, question_mask):
        """
        Extract the output vector for the last token in the sentence -
            similarly to how pooled_output is extracted for us when calling 'bert_model'.
        We need the question mask to find the last actual (non-padding) token
        :return:
        """

        if self._all_layers:
            encoded_layers = encoded_layers[-1]

        # A cool trick to extract the last "True" item in each row
        question_mask = question_mask.squeeze()
        # forward() already asserts batch_size == 1, so this should hold, but check anyway
        assert question_mask.dim() == 2
        shifted_matrix = question_mask.roll(-1, 1)
        shifted_matrix[:, -1] = 0
        last_item_indices = question_mask - shifted_matrix

        # TODO: This line, for some reason, didn't work as expected, but it is much better than the implementation that follows
        # last_token_tensor = encoded_layers[last_item_indices]

        num_pairs, token_number, hidden_size = encoded_layers.size()
        assert last_item_indices.size() == (num_pairs, token_number)
        # Don't worry, expand doesn't allocate new memory, it simply views the tensor differently
        expanded_last_item_indices = last_item_indices.unsqueeze(2).expand(num_pairs, token_number, hidden_size)
        last_token_tensor = encoded_layers.masked_select(expanded_last_item_indices.byte())
        last_token_tensor = last_token_tensor.reshape(num_pairs, hidden_size)

        pooled_output = self._bert_model.pooler.dense(last_token_tensor)
        pooled_output = self._bert_model.pooler.activation(pooled_output)

        return pooled_output
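    # Toy illustration of the last-token trick above (assumed mask values):
    #   question_mask row:            [1, 1, 1, 0]
    #   roll(-1, 1), last column 0:   [1, 1, 0, 0]
    #   difference:                   [0, 0, 1, 0]
    # i.e. exactly 1 at the last non-padding position, so masked_select pulls
    # out one hidden vector per pair.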

    def forward(self,
                question: Dict[str, torch.LongTensor],
                choice1_indexes: List[int] = None,
                choice2_indexes: List[int] = None,
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:

        self._debug -= 1
        input_ids = question['bert']

        # input_ids.size() == (batch_size, num_pairs, max_sentence_length)
        batch_size, num_pairs, _ = question['bert'].size()
        question_mask = (input_ids != 0).long()

        if self._train_comparison_layer:
            assert num_pairs == self._num_choices * (self._num_choices - 1)

        # Segment ids
        real_segment_ids = question['bert-type-ids'].clone()
        # Change the last 'SEP' to belong to the second answer (for symmetry)
        last_seps = (real_segment_ids.roll(-1) == 2) & (real_segment_ids == 1)
        real_segment_ids[last_seps] = 2
        # Update segment ids so that they are '1' for answers and '0' for the question
        real_segment_ids = (real_segment_ids == 0) | (real_segment_ids == 2)
        real_segment_ids = real_segment_ids.long()

        # TODO: How to extract last token pooled output if batch size != 1
        assert batch_size == 1

        # Run model
        encoded_layers, first_vectors_pooled_output = self._bert_model(input_ids=util.combine_initial_dims(input_ids),
                                            token_type_ids=util.combine_initial_dims(real_segment_ids),
                                            attention_mask=util.combine_initial_dims(question_mask),
                                            output_all_encoded_layers=self._all_layers)

        if self._use_comparative_bert:
            last_vectors_pooled_output = self._extract_last_token_pooled_output(encoded_layers, question_mask)
        else:
            last_vectors_pooled_output = None
        if self._all_layers:
            mixed_layer = self._scalar_mix(encoded_layers, question_mask)
            first_vectors_pooled_output = self._bert_model.pooler(mixed_layer)

        # Apply dropout
        first_vectors_pooled_output = self._dropout(first_vectors_pooled_output)
        if self._use_comparative_bert:
            last_vectors_pooled_output = self._dropout(last_vectors_pooled_output)

        # Classify
        if not self._use_comparative_bert:
            pair_label_logits = self._classifier(first_vectors_pooled_output)
        else:
            if self._use_bilinear_classifier:
                pair_label_logits = self._classifier(first_vectors_pooled_output, last_vectors_pooled_output)
            else:
                all_pooled_output = torch.cat((first_vectors_pooled_output, last_vectors_pooled_output), 1)
                pair_label_logits = self._classifier(all_pooled_output)

        pair_label_logits = pair_label_logits.view(-1, num_pairs)

        pair_label_probs = torch.sigmoid(pair_label_logits)

        output_dict = {}
        pair_label_probs_flat = pair_label_probs.squeeze(1)
        output_dict['pair_label_probs'] = pair_label_probs_flat.view(-1, num_pairs)
        output_dict['pair_label_logits'] = pair_label_logits
        output_dict['choice1_indexes'] = choice1_indexes
        output_dict['choice2_indexes'] = choice2_indexes

        if not self._train_comparison_layer:
            if label is not None:
                label = label.unsqueeze(1)
                label = label.expand(-1, num_pairs)
                relevant_pairs = (choice1_indexes == label) | (choice2_indexes == label)
                relevant_probs = pair_label_probs[relevant_pairs]
                choice1_is_the_label = (choice1_indexes == label)[relevant_pairs]
                # choice1_is_the_label = choice1_is_the_label.type_as(relevant_logits)

                loss = self._loss(relevant_probs, choice1_is_the_label.float())
                self._accuracy(relevant_probs >= 0.5, choice1_is_the_label)
                output_dict["loss"] = loss

            return output_dict
        else:
            choice_logits = self._comparison_layer_2(self._comparison_layer_1_activation(self._comparison_layer_1(
                pair_label_probs)))
            output_dict['choice_logits'] = choice_logits
            output_dict['choice_probs'] = torch.softmax(choice_logits, 1)
            output_dict['predicted_choice'] = torch.argmax(choice_logits, 1)

            if label is not None:
                loss = self._loss(choice_logits, label)
                self._accuracy(choice_logits, label)
                output_dict["loss"] = loss

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'EM': self._accuracy.get_metric(reset),
        }
Example no. 25
class MultiGranularityHierarchicalAttentionFusionNetworks(Model):
    def __init__(
            self,
            vocab: Vocabulary,
            elmo_embedder: TextFieldEmbedder,
            tokens_embedder: TextFieldEmbedder,
            features_embedder: TextFieldEmbedder,
            phrase_layer: Seq2SeqEncoder,
            projected_layer: Seq2SeqEncoder,
            contextual_passage: Seq2SeqEncoder,
            contextual_question: Seq2SeqEncoder,
            dropout: float = 0.2,
            regularizer: Optional[RegularizerApplicator] = None,
            initializer: InitializerApplicator = InitializerApplicator(),
    ):

        super(MultiGranularityHierarchicalAttentionFusionNetworks,
              self).__init__(vocab, regularizer)
        self.elmo_embedder = elmo_embedder
        self.tokens_embedder = tokens_embedder
        self.features_embedder = features_embedder
        self._phrase_layer = phrase_layer
        self._encoding_dim = self._phrase_layer.get_output_dim()
        self.projected_layer = torch.nn.Linear(self._encoding_dim + 1024,
                                               self._encoding_dim)
        self.fuse_p = FusionLayer(self._encoding_dim)
        self.fuse_q = FusionLayer(self._encoding_dim)
        self.fuse_s = FusionLayer(self._encoding_dim)
        self.projected_lstm = projected_layer
        self.contextual_layer_p = contextual_passage
        self.contextual_layer_q = contextual_question
        self.linear_self_align = torch.nn.Linear(self._encoding_dim, 1)
        # self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._self_attention = BilinearMatrixAttention(self._encoding_dim,
                                                       self._encoding_dim)
        self.bilinear_layer_s = BilinearSeqAtt(self._encoding_dim,
                                               self._encoding_dim)
        self.bilinear_layer_e = BilinearSeqAtt(self._encoding_dim,
                                               self._encoding_dim)
        self.yesno_predictor = FeedForward(self._encoding_dim,
                                           self._encoding_dim, 3)
        self.relu = torch.nn.ReLU()

        self._max_span_length = 30

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

        self._loss = torch.nn.CrossEntropyLoss()
        initializer(self)

    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, max_qa_count, max_q_len, _ = question[
            'token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(yesno_list, 0).view(total_qa_count)

        # GloVe and simple cnn char embedding, embedding dim = 100 + 100 = 200
        word_emb_ques = self.tokens_embedder(
            question,
            num_wrapping_dims=1).reshape(total_qa_count, max_q_len,
                                         self.tokens_embedder.get_output_dim())
        word_emb_pass = self.tokens_embedder(passage)

        # Elmo embedding, embedding dim = 1024
        elmo_ques = self.elmo_embedder(question, num_wrapping_dims=1).reshape(
            total_qa_count, max_q_len, self.elmo_embedder.get_output_dim())
        elmo_pass = self.elmo_embedder(passage)

        # Passage features embedding, embedding dim = 20 + 20 = 40
        pass_feat = self.features_embedder(passage)

        # GloVe + cnn + Elmo
        embedded_question = self._variational_dropout(
            torch.cat([word_emb_ques, elmo_ques], dim=2))
        embedded_passage = self._variational_dropout(
            torch.cat([word_emb_pass, elmo_pass], dim=2))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question,
                                                 num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        # Concatenate Elmo after encoded passage
        encode_passage = self._phrase_layer(embedded_passage, passage_mask)
        projected_passage = self.relu(
            self.projected_layer(torch.cat([encode_passage, elmo_pass],
                                           dim=2)))

        # Concatenate Elmo after encoded question
        encode_question = self._phrase_layer(embedded_question, question_mask)
        projected_question = self.relu(
            self.projected_layer(torch.cat([encode_question, elmo_ques],
                                           dim=2)))

        encoded_passage = self._variational_dropout(projected_passage)
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
            1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(
            total_qa_count, passage_length, self._encoding_dim)
        repeated_pass_feat = (pass_feat.unsqueeze(1).repeat(
            1, max_qa_count, 1, 1)).view(total_qa_count, passage_length, 40)
        encoded_question = self._variational_dropout(projected_question)

        # total_qa_count * max_q_len * passage_length
        # cnt * m * n
        s = torch.bmm(encoded_question,
                      repeated_encoded_passage.transpose(2, 1))
        alpha = util.masked_softmax(s,
                                    question_mask.unsqueeze(2).expand(
                                        s.size()),
                                    dim=1)
        # cnt * n * h
        aligned_p = torch.bmm(alpha.transpose(2, 1), encoded_question)

        # cnt * m * n
        beta = util.masked_softmax(s,
                                   repeated_passage_mask.unsqueeze(1).expand(
                                       s.size()),
                                   dim=2)
        # cnt * m * h
        aligned_q = torch.bmm(beta, repeated_encoded_passage)

        fused_p = self.fuse_p(repeated_encoded_passage, aligned_p)
        fused_q = self.fuse_q(encoded_question, aligned_q)

        # add manual features here
        q_aware_p = self._variational_dropout(
            self.projected_lstm(
                torch.cat([fused_p, repeated_pass_feat], dim=2),
                repeated_passage_mask))

        # cnt * n * n
        # self_p = torch.bmm(q_aware_p, q_aware_p.transpose(2, 1))
        # self_p = self.bilinear_self_align(q_aware_p)
        self_p = self._self_attention(q_aware_p, q_aware_p)
        mask = repeated_passage_mask.reshape(
            total_qa_count, passage_length, 1) * repeated_passage_mask.reshape(
                total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              device=self_p.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        lamb = util.masked_softmax(self_p, mask, dim=2)
        # lamb = util.masked_softmax(self_p, repeated_passage_mask, dim=2)
        # cnt * n * h
        self_aligned_p = torch.bmm(lamb, q_aware_p)

        # cnt * n * h
        fused_self_p = self.fuse_s(q_aware_p, self_aligned_p)
        contextual_p = self._variational_dropout(
            self.contextual_layer_p(fused_self_p, repeated_passage_mask))
        # contextual_p = self.contextual_layer_p(fused_self_p, repeated_passage_mask)

        contextual_q = self._variational_dropout(
            self.contextual_layer_q(fused_q, question_mask))
        # contextual_q = self.contextual_layer_q(fused_q, question_mask)
        # cnt * m
        gamma = util.masked_softmax(
            self.linear_self_align(contextual_q).squeeze(2),
            question_mask,
            dim=1)
        # cnt * h
        weighted_q = torch.bmm(gamma.unsqueeze(1), contextual_q).squeeze(1)

        span_start_logits = self.bilinear_layer_s(weighted_q, contextual_p)
        span_end_logits = self.bilinear_layer_e(weighted_q, contextual_p)

        # cnt * n * 1  cnt * 1 * h
        span_yesno_logits = self.yesno_predictor(
            torch.bmm(span_end_logits.unsqueeze(2), weighted_q.unsqueeze(1)))
        # span_yesno_logits = self.yesno_predictor(contextual_p)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits,
                                                       span_end_logits,
                                                       span_yesno_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss for training

        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
                                                    repeated_passage_mask),
                            span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask),
                             span_end.view(-1),
                             ignore_index=-1)
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end],
                                            -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(torch.nn.functional.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)

            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                yesno_pred = predicted_span[2]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_query_id_list.append(iid)
                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 against each set of N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad_eval.metric_max_over_ground_truths(
                                    squad_eval.f1_score, best_span_string,
                                    refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score, best_span_string,
                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
        return output_dict
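    # Leave-one-out F1, sketched (illustrative): with references [r0, r1, r2],
    # the prediction is scored with metric_max_over_ground_truths against
    # {r1, r2}, {r0, r2}, and {r0, r1} in turn, and the three scores are
    # averaged before being logged as the official F1.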

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        yesno_tags = [[
            self.vocab.get_token_from_index(x, namespace="yesno_labels")
            for x in yn_list
        ] for yn_list in output_dict.pop("yesno")]
        output_dict['yesno'] = yesno_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'yesno': self._span_yesno_accuracy.get_metric(reset),
            'f1': self._official_f1.get_metric(reset),
        }

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 3),
                                                     dtype=torch.long)
        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()

        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
        return best_word_span
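    # Decoding sketch (illustrative values): the first two columns of the
    # returned tensor hold the best (start, end) pair under the
    # max_span_length constraint; the third holds the argmax over the 3-way
    # yes/no logits taken at the predicted end token, e.g. [[12, 15, 0]] for
    # one dialog turn.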
Example no. 26
class BidafV4(Model):
    """
    MODIFICATION NOTE:
    This class is a modification of BiDAF. Here we try to see what happens to our results
    if we convert the question encoder into a simple term-frequency (bag-of-words) encoder which
    disregards word order. By doing so we analyze whether BiDAF can learn to solve SQuAD without
    having to encode the question sequentially. It has been shown in previous work that BiDAF and
    other models trained on SQuAD do not focus on question words as we would expect them to. For
    example, they will often focus on only a few of the question words and ignore the rest.

    ORIGINAL DOCSTRING:
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:

        super(BidafV4, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(),
                               4 * encoding_dim, "modeling layer input dim",
                               "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(),
                               phrase_layer.get_input_dim(),
                               "text field embedder output dim",
                               "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(),
                               4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim",
                               "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))

        # # v5:
        # # remember to set token embeddings in the CONFIG JSON
        # encoded_question = self._dropout(embedded_question)

        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length) -- SIMILARITY MATRIX
        similarity_matrix = self._matrix_attention(encoded_passage,
                                                   encoded_question)

        # Shape: (batch_size, passage_length, question_length) -- CONTEXT2QUERY
        passage_question_attention = util.last_dim_softmax(
            similarity_matrix, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Our custom query2context
        q2c_attention = util.masked_softmax(similarity_matrix,
                                            question_mask,
                                            dim=1).transpose(-1, -2)
        q2c_vecs = util.weighted_sum(encoded_passage, q2c_attention)

        # Now we try the various variants
        # v1:
        # tiled_question_passage_vector = util.weighted_sum(q2c_vecs, passage_question_attention)

        # v2:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], encoded_passage.shape[1]))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).transpose(-1, -2)

        # v3:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], 1))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).squeeze().unsqueeze(1).expand(batch_size, passage_length, encoding_dim)

        # v4:
        # Re-application of query2context attention
        new_similarity_matrix = self._matrix_attention(encoded_passage,
                                                       q2c_vecs)
        masked_similarity = util.replace_masked_values(
            new_similarity_matrix, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # ------- Original variant
        # # We replace masked values with something really negative here, so they don't affect the
        # # max below.
        # masked_similarity = util.replace_masked_values(similarity_matrix,
        #                                                question_mask.unsqueeze(1),
        #                                                -1e7)
        # # Shape: (batch_size, passage_length)
        # question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # # Shape: (batch_size, passage_length)
        # question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # # Shape: (batch_size, encoding_dim)
        # question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # # Shape: (batch_size, passage_length, encoding_dim)
        # tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
        #                                                                             passage_length,
        #                                                                             encoding_dim)

        # ------- END

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        # original beta combination function
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        # # v6:
        # final_merged_passage = torch.cat([tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v7:
        # final_merged_passage = torch.cat([passage_question_vectors],
        #                                  dim=-1)
        #
        # # v8:
        # final_merged_passage = torch.cat([passage_question_vectors,
        #                                   tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v9:
        # final_merged_passage = torch.cat([encoded_passage,
        #                                   passage_question_vectors,
        #                                   encoded_passage * passage_question_vectors],
        #                                  dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, span_end_encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
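A minimal usage sketch of the constrained decoding above (illustrative only: the method is a staticmethod on the enclosing model class, so the call is shown with a stand-in class name):

import torch

start_logits = torch.tensor([[0.1, 5.0, 0.2, 0.3]])  # best start is index 1
end_logits = torch.tensor([[4.0, 0.1, 0.2, 3.9]])    # unconstrained end argmax is 0
# best_span = ModelClass.get_best_span(start_logits, end_logits)
# -> tensor([[1, 3]]): the returned end can never precede the returned start,
#    so the high end logit at index 0 loses to the (start=1, end=3) pair.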
Example No. 27
class TweetJointly(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model_name: str = "bert-base-uncased",
        feedforward: Optional[FeedForward] = None,
        smoothing: bool = False,
        smooth_alpha: float = 0.7,
        sentiment_task: bool = False,
        sentiment_task_weight: float = 1.0,
        sentiment_classification_with_label: bool = True,
        sentiment_seq2vec: Optional[Seq2VecEncoder] = None,
        candidate_span_task: bool = False,
        candidate_span_task_weight: float = 1.0,
        candidate_delay: int = 30000,
        candidate_span_num: int = 5,
        candidate_classification_layer_units: int = 128,
        candidate_span_extractor: Optional[SpanExtractor] = None,
        candidate_span_with_logits: bool = False,
        dropout: Optional[float] = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        if "BERTweet" not in transformer_model_name:
            self._text_field_embedder = BasicTextFieldEmbedder({
                "tokens":
                PretrainedTransformerEmbedder(transformer_model_name)
            })
        else:
            self._text_field_embedder = BasicTextFieldEmbedder(
                {"tokens": TweetBertEmbedder(transformer_model_name)})
        # span start & end task
        if feedforward is None:
            self._linear_layer = nn.Sequential(
                nn.Linear(self._text_field_embedder.get_output_dim(), 128),
                nn.ReLU(),
                nn.Linear(128, 2),
            )
        else:
            self._linear_layer = feedforward
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._jaccard = Jaccard()
        self._candidate_delay = candidate_delay
        self._delay = 0

        self._smoothing = smoothing
        self._smooth_alpha = smooth_alpha
        if smoothing:
            self._loss = nn.KLDivLoss(reduction="batchmean")
        else:
            self._loss = nn.CrossEntropyLoss()

        # sentiment task
        self._sentiment_task = sentiment_task
        if self._sentiment_task:
            self._sentiment_classification_accuracy = CategoricalAccuracy()
            self._sentiment_loss_log = LossLog()
            self.register_buffer("sentiment_task_weight",
                                 torch.tensor(sentiment_task_weight))
            self._sentiment_classification_with_label = (
                sentiment_classification_with_label)
            if sentiment_seq2vec is None:
                raise ConfigurationError(
                    "sentiment task is True, we need a sentiment seq2vec encoder"
                )
            else:
                self._sentiment_encoder = sentiment_seq2vec
                self._sentiment_linear = nn.Linear(
                    self._sentiment_encoder.get_output_dim(),
                    vocab.get_vocab_size("labels"),
                )

        # candidate span task
        self._candidate_span_task = candidate_span_task
        if candidate_span_task:
            assert candidate_span_num > 0
            assert candidate_span_task_weight > 0
            assert candidate_classification_layer_units > 0
            self._candidate_span_num = candidate_span_num
            self.register_buffer("candidate_span_task_weight",
                                 torch.tensor(candidate_span_task_weight))
            self._candidate_classification_layer_units = (
                candidate_classification_layer_units)
            self._span_classification_accuracy = CategoricalAccuracy()
            self._candidate_loss_log = LossLog()
            self._candidate_span_linear = nn.Linear(
                self._text_field_embedder.get_output_dim(),
                self._candidate_classification_layer_units,
            )

            if candidate_span_extractor is None:
                self._candidate_span_extractor = EndpointSpanExtractor(
                    input_dim=self._candidate_classification_layer_units)
            else:
                self._candidate_span_extractor = candidate_span_extractor

            if candidate_span_with_logits:
                self._candidate_with_logits = True
                self._candidate_span_vec_linear = nn.Linear(
                    self._candidate_span_extractor.get_output_dim() + 1, 1)
            else:
                self._candidate_with_logits = False
                self._candidate_span_vec_linear = nn.Linear(
                    self._candidate_span_extractor.get_output_dim(), 1)

            self._candidate_jaccard = Jaccard()

        if sentiment_task or candidate_span_task:
            self._base_loss_log = LossLog()
        else:
            self._base_loss_log = None

        if dropout is not None:
            self._dropout = nn.Dropout(dropout)
        else:
            self._dropout = None

    def forward(  # type: ignore
        self,
        text: Dict[str, Dict[str, torch.LongTensor]],
        sentiment: torch.IntTensor,
        text_with_sentiment: Dict[str, Dict[str, torch.LongTensor]],
        text_span: torch.IntTensor,
        selected_text_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        # batch_size * text_length * hidden_dims
        embedded_question = self._text_field_embedder(text_with_sentiment)
        if self._dropout is not None:
            embedded_question = self._dropout(embedded_question)
        self._delay += int(embedded_question.size(0))
        # span start & span end task
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)

        possible_answer_mask = torch.zeros_like(
            util.get_token_ids_from_text_field_tensors(
                text_with_sentiment)).bool()
        for i, (start, end) in enumerate(text_span):
            possible_answer_mask[i, start:end + 1] = True

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       possible_answer_mask,
                                                       -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     possible_answer_mask,
                                                     -1e32)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }

        loss = torch.tensor(0.0).to(embedded_question.device)
        # sentiment task
        if self._sentiment_task:
            if self._sentiment_classification_with_label:
                global_context_vec = self._sentiment_encoder(embedded_question)
            else:
                embedded_only_text = self._text_field_embedder(text)
                if self._dropout is not None:
                    embedded_only_text = self._dropout(embedded_only_text)
                global_context_vec = self._sentiment_encoder(
                    embedded_only_text)
            sentiment_logits = self._sentiment_linear(global_context_vec)
            sentiment_probs = torch.softmax(sentiment_logits, dim=-1)

            self._sentiment_classification_accuracy(sentiment_probs, sentiment)
            sentiment_loss = cross_entropy(sentiment_logits, sentiment)
            self._sentiment_loss_log(sentiment_loss)
            loss.add_(self.sentiment_task_weight * sentiment_loss)

            predict_sentiment_idx = sentiment_probs.argmax(dim=-1)
            sentiment_predicts = []
            for i in predict_sentiment_idx.tolist():
                sentiment_predicts.append(
                    self.vocab.get_token_from_index(i, "labels"))
            output_dict["sentiment_logits"] = sentiment_logits
            output_dict["sentiment_probs"] = sentiment_probs
            output_dict["sentiment_predicts"] = sentiment_predicts

        # span classification
        if self._candidate_span_task and (self._delay >=
                                          self._candidate_delay):
            # shape: (batch_size, passage_length, embedding_dim)
            text_features_for_candidate = self._candidate_span_linear(
                embedded_question)
            text_features_for_candidate = torch.relu(
                text_features_for_candidate)
            with torch.no_grad():
                # batch_size * candidate_num * 2
                candidate_span = get_candidate_span(span_start_probs,
                                                    span_end_probs,
                                                    self._candidate_span_num)
                candidate_span_list = candidate_span.tolist()
                output_dict["candidate_spans"] = candidate_span_list
            if selected_text_span is not None:
                candidate_span, candidate_span_label = self.candidate_span_with_labels(
                    candidate_span, selected_text_span)
            else:
                candidate_span_label = None
            # shape: (batch_size, candidate_num, span_extractor_output_dim)
            span_feature_vec = self._candidate_span_extractor(
                text_features_for_candidate, candidate_span)

            if self._candidate_with_logits:
                candidate_span_start_logits = torch.gather(
                    span_start_logits, 1, candidate_span[:, :, 0])
                candidate_span_end_logits = torch.gather(
                    span_end_logits, 1, candidate_span[:, :, 1])
                candidate_span_sum_logits = (candidate_span_start_logits +
                                             candidate_span_end_logits)
                span_feature_vec = torch.cat(
                    (span_feature_vec, candidate_span_sum_logits.unsqueeze(2)),
                    -1)
            # batch_size * candidate_num
            span_classification_logits = self._candidate_span_vec_linear(
                span_feature_vec).squeeze(-1)
            span_classification_probs = torch.softmax(
                span_classification_logits, -1)
            output_dict[
                "span_classification_probs"] = span_classification_probs
            candidate_best_span_idx = span_classification_probs.argmax(dim=-1)
            view_idx = (
                candidate_best_span_idx +
                torch.arange(0, end=candidate_best_span_idx.shape[0]).to(
                    candidate_best_span_idx.device) * self._candidate_span_num)
            candidate_span_view = candidate_span.view(-1, 2)
            candidate_best_spans = candidate_span_view.index_select(
                0, view_idx)
            output_dict["candidate_best_spans"] = candidate_best_spans.tolist()

            if selected_text_span is not None:
                self._span_classification_accuracy(span_classification_probs,
                                                   candidate_span_label)
                candidate_span_loss = cross_entropy(span_classification_logits,
                                                    candidate_span_label)
                self._candidate_loss_log(candidate_span_loss)
                weighted_loss = self.candidate_span_task_weight * candidate_span_loss
                if candidate_span_loss > 1e2:
                    print(f"candidate loss: {candidate_span_loss}")
                    print(
                        f"span_classification_logits: {span_classification_logits}"
                    )
                    print(f"candidate_span_label: {candidate_span_label}")
                loss.add_(weighted_loss)

            candidate_best_spans = candidate_best_spans.detach().cpu().numpy()
            output_dict["best_candidate_span_str"] = []
            for metadata_entry, best_span in zip(metadata,
                                                 candidate_best_spans):
                text_with_sentiment_tokens = metadata_entry[
                    "text_with_sentiment_tokens"]
                predicted_start, predicted_end = tuple(best_span)
                if predicted_end >= len(text_with_sentiment_tokens):
                    predicted_end = len(text_with_sentiment_tokens) - 1
                best_span_string = self.span_tokens_to_text(
                    metadata_entry["text"],
                    text_with_sentiment_tokens,
                    predicted_start,
                    predicted_end,
                )
                output_dict["best_candidate_span_str"].append(best_span_string)
                answers = metadata_entry.get("selected_text", "")
                if len(answers) > 0:
                    self._candidate_jaccard(best_span_string, answers)

        # Compute the loss for training.
        if selected_text_span is not None:
            span_start = selected_text_span[:, 0]
            span_end = selected_text_span[:, 1]
            span_mask = span_start != -1
            self._span_accuracy(
                best_spans,
                selected_text_span,
                span_mask.unsqueeze(-1).expand_as(best_spans),
            )
            if not self._smoothing:
                start_loss = cross_entropy(span_start_logits,
                                           span_start,
                                           ignore_index=-1)
                if torch.any(start_loss > 1e9):
                    logger.critical("Start loss too high (%r)", start_loss)
                    logger.critical("span_start_logits: %r", span_start_logits)
                    logger.critical("span_start: %r", span_start)
                    logger.critical("text_with_sentiment: %r",
                                    text_with_sentiment)
                    assert False

                end_loss = cross_entropy(span_end_logits,
                                         span_end,
                                         ignore_index=-1)
                if torch.any(end_loss > 1e9):
                    logger.critical("End loss too high (%r)", end_loss)
                    logger.critical("span_end_logits: %r", span_end_logits)
                    logger.critical("span_end: %r", span_end)
                    assert False
            else:
                sequence_length = span_start_logits.size(1)
                device = span_start.device
                start_distance = get_sequence_distance_from_span_endpoint(
                    sequence_length, span_start)
                start_smooth_probs = torch.exp(
                    start_distance *
                    torch.log(torch.tensor(self._smooth_alpha).to(device)))
                start_smooth_probs = start_smooth_probs * possible_answer_mask
                start_smooth_probs = start_smooth_probs / start_smooth_probs.sum(
                    -1, keepdim=True)
                span_start_log_probs = span_start_logits - torch.logsumexp(
                    span_start_logits, dim=-1, keepdim=True)  # stable log-softmax
                end_distance = get_sequence_distance_from_span_endpoint(
                    sequence_length, span_end)
                end_smooth_probs = torch.exp(
                    end_distance *
                    torch.log(torch.tensor(self._smooth_alpha).to(device)))
                end_smooth_probs = end_smooth_probs * possible_answer_mask
                end_smooth_probs = end_smooth_probs / end_smooth_probs.sum(
                    -1, keepdim=True)
                span_end_log_probs = span_end_logits - torch.logsumexp(
                    span_end_logits, dim=-1, keepdim=True)  # stable log-softmax
                # print(end_smooth_probs)
                # print(start_smooth_probs)
                # print(span_end_log_probs)
                # print(span_start_log_probs)
                start_loss = self._loss(span_start_log_probs,
                                        start_smooth_probs)
                end_loss = self._loss(span_end_log_probs, end_smooth_probs)

            span_start_end_loss = (start_loss + end_loss) / 2
            if self._base_loss_log is not None:
                self._base_loss_log(span_start_end_loss)
            loss.add_(span_start_end_loss)
            self._span_start_accuracy(span_start_logits, span_start, span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)

            output_dict["loss"] = loss

        # compute best span jaccard
        best_spans = best_spans.detach().cpu().numpy()
        output_dict["best_span_str"] = []

        for metadata_entry, best_span in zip(metadata, best_spans):
            text_with_sentiment_tokens = metadata_entry[
                "text_with_sentiment_tokens"]

            predicted_start, predicted_end = tuple(best_span)
            best_span_string = self.span_tokens_to_text(
                metadata_entry["text"],
                text_with_sentiment_tokens,
                predicted_start,
                predicted_end,
            )
            output_dict["best_span_str"].append(best_span_string)

            answers = metadata_entry.get("selected_text", "")
            if len(answers) > 0:
                self._jaccard(best_span_string, answers)

        return output_dict

    # @staticmethod
    # def candidate_span_with_labels(
    #     candidate_span: torch.Tensor, selected_text_span: torch.Tensor
    # ) -> Tuple[torch.Tensor, torch.Tensor]:
    #     correct_span_idx = (candidate_span == selected_text_span.unsqueeze(1)).prod(-1)
    #     candidate_span_adjust = torch.where(
    #         ~(correct_span_idx.unsqueeze(-1) == 1),
    #         candidate_span,
    #         selected_text_span.unsqueeze(1),
    #     )
    #     candidate_span_label = correct_span_idx.argmax(-1)
    #     return candidate_span_adjust, candidate_span_label

    @staticmethod
    def candidate_span_with_labels(
            candidate_span: torch.Tensor, selected_text_span: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        candidate_span_label = batch_span_jaccard(
            candidate_span, selected_text_span).max(-1).indices
        return candidate_span, candidate_span_label
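    # batch_span_jaccard is not defined in this listing; a plausible sketch of
    # what it computes, given its use above (treat this as an assumption, not
    # the original implementation): token-level Jaccard overlap between each
    # candidate span and the gold span, so argmax picks the closest candidate.
    #
    #   def batch_span_jaccard(candidate_span, gold_span):
    #       # candidate_span: (batch, num_candidates, 2); gold_span: (batch, 2);
    #       # both use inclusive token indices.
    #       c_start, c_end = candidate_span[..., 0], candidate_span[..., 1]
    #       g_start, g_end = gold_span[:, :1], gold_span[:, 1:]
    #       inter = (torch.min(c_end, g_end) -
    #                torch.max(c_start, g_start) + 1).clamp(min=0)
    #       union = (c_end - c_start + 1) + (g_end - g_start + 1) - inter
    #       return inter.float() / union.float()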

    @staticmethod
    def get_candidate_span_mask(candidate_span: torch.Tensor,
                                passage_length: int) -> torch.Tensor:
        device = candidate_span.device
        batch_size, candidate_num = candidate_span.size()[:-1]
        candidate_span_mask = torch.zeros(batch_size, candidate_num,
                                          passage_length).to(device)
        for i in range(batch_size):
            for j in range(candidate_num):
                span_start, span_end = candidate_span[i][j]
                candidate_span_mask[i][j][span_start:span_end + 1] = 1
        return candidate_span_mask
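    # An equivalent vectorized form of the two loops above (a sketch; same
    # semantics, no Python-level iteration):
    #
    #   positions = torch.arange(passage_length, device=candidate_span.device)
    #   candidate_span_mask = ((positions >= candidate_span[..., :1]) &
    #                          (positions <= candidate_span[..., 1:])).float()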

    @staticmethod
    def span_tokens_to_text(source_text, tokens, span_start, span_end):
        text_with_sentiment_tokens = tokens
        predicted_start = span_start
        predicted_end = span_end

        while (predicted_start >= 0
               and text_with_sentiment_tokens[predicted_start].idx is None):
            predicted_start -= 1
        if predicted_start < 0:
            logger.warning(
                f"Could not map the token '{text_with_sentiment_tokens[span_start].text}' at index "
                f"'{span_start}' to an offset in the original text.")
            character_start = 0
        else:
            character_start = text_with_sentiment_tokens[predicted_start].idx

        while (predicted_end < len(text_with_sentiment_tokens)
               and text_with_sentiment_tokens[predicted_end].idx is None):
            predicted_end += 1  # skip forward past unmappable tokens; the guard below handles overrun

        if predicted_end >= len(text_with_sentiment_tokens):
            print(text_with_sentiment_tokens)
            print(len(text_with_sentiment_tokens))
            print(span_end)
            print(predicted_end)
            logger.warning(
                f"Could not map the token '{text_with_sentiment_tokens[span_end].text}' at index "
                f"'{span_end}' to an offset in the original text.")
            character_end = len(source_text)
        else:
            end_token = text_with_sentiment_tokens[predicted_end]
            if end_token.idx == 0:
                character_end = (end_token.idx +
                                 len(sanitize_wordpiece(end_token.text)) + 1)
            else:
                character_end = end_token.idx + len(
                    sanitize_wordpiece(end_token.text))

        best_span_string = source_text[character_start:character_end].strip()
        return best_span_string
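    # A toy walk-through of span_tokens_to_text (illustrative; Token stands in
    # for the AllenNLP token type, whose idx attribute is a character offset,
    # or None for special tokens with no position in the source text):
    #
    #   source = "hello world"
    #   tokens = [Token("<s>", None), Token("hello", idx=0),
    #             Token("world", idx=6), Token("</s>", None)]
    #   span_tokens_to_text(source, tokens, 1, 2)
    #   -> character range [0, 6 + len("world")) -> "hello world"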

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        jaccard = self._jaccard.get_metric(reset)
        metrics = {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "jaccard": jaccard,
        }
        if self._candidate_span_task:
            metrics[
                "candidate_span_acc"] = self._span_classification_accuracy.get_metric(
                    reset)
            metrics["candidate_jaccard"] = self._candidate_jaccard.get_metric(
                reset)
            metrics["candidate_loss"] = self._candidate_loss_log.get_metric(
                reset)
        if self._sentiment_task:
            metrics[
                "sentiment_acc"] = self._sentiment_classification_accuracy.get_metric(
                    reset)
            metrics["sentiment_loss"] = self._sentiment_loss_log.get_metric(
                reset)
        if self._base_loss_log is not None:
            metrics["base_loss"] = self._base_loss_log.get_metric(reset)
        return metrics
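The smoothing branch in forward builds its KL targets from a helper that is not shown in this listing. Below is a self-contained sketch of the intended target distribution, under the assumption that get_sequence_distance_from_span_endpoint returns the absolute token distance to the gold endpoint:

import torch

def get_sequence_distance_from_span_endpoint(seq_len, endpoint):
    # |position - gold endpoint| for every position; shape (batch, seq_len).
    # This helper is reconstructed from how forward uses it, not copied.
    positions = torch.arange(seq_len).float().unsqueeze(0)
    return (positions - endpoint.float().unsqueeze(1)).abs()

span_start = torch.tensor([2])                 # gold start index, batch of one
alpha, seq_len = 0.7, 5
dist = get_sequence_distance_from_span_endpoint(seq_len, span_start)
target = alpha ** dist                         # == exp(dist * log(alpha))
target = target / target.sum(-1, keepdim=True)
print(target)  # peaks at index 2 and decays by a factor of 0.7 per token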
Example No. 28
class EdgeProbingTask(Task):
    """ Generic class for fine-grained edge probing.

    Acts as a classifier, but with multiple targets for each input text.

    Targets are of the form (span1, span2, label), where span1 and span2 are
    half-open token intervals [i, j).

    Subclass this for each dataset, or use register_task with appropriate
    kwargs.
    """
    @property
    def _tokenizer_suffix(self):
        """ Suffix to make sure we use the correct source files,
        based on the given tokenizer.
        """
        if self.tokenizer_name:
            return ".retokenized." + self.tokenizer_name
        else:
            return ""

    def tokenizer_is_supported(self, tokenizer_name):
        """ Check if the tokenizer is supported for this task. """
        # Assume all tokenizers supported; if retokenized data not found
        # for this particular task, we'll just crash on file loading.
        return True

    def __init__(
        self,
        path: str,
        max_seq_len: int,
        name: str,
        label_file: str = None,
        files_by_split: Dict[str, str] = None,
        single_sided: bool = False,
        **kw,
    ):
        """Construct an edge probing task.

        path, max_seq_len, and name are passed by the code in preprocess.py;
        remaining arguments should be provided by a subclass constructor or via
        @register_task.

        Args:
            path: data directory
            max_seq_len: maximum sequence length (currently ignored)
            name: task name
            label_file: relative path to labels file
            files_by_split: split name ('train', 'val', 'test') mapped to
                relative filenames (e.g. 'train': 'train.json')
            single_sided: if true, only use span1.
        """
        super().__init__(name, **kw)

        assert label_file is not None
        assert files_by_split is not None
        self._files_by_split = {
            split: os.path.join(path, fname) + self._tokenizer_suffix
            for split, fname in files_by_split.items()
        }
        self.path = path
        self.label_file = os.path.join(self.path, label_file)
        self.max_seq_len = max_seq_len
        self.single_sided = single_sided

        # Placeholders; see self.load_data()
        self._iters_by_split = None
        self.all_labels = None
        self.n_classes = None

        # see add_task_label_namespace in preprocess.py
        self._label_namespace = self.name + "_labels"

        # Scorers
        self.mcc_scorer = FastMatthews()
        self.acc_scorer = BooleanAccuracy()  # binary accuracy
        self.f1_scorer = F1Measure(positive_label=1)  # binary F1 overall
        self.val_metric = "%s_f1" % self.name  # TODO: switch to MCC?
        self.val_metric_decreases = False

    def get_all_labels(self) -> List[str]:
        return self.all_labels

    @classmethod
    def _stream_records(cls, filename):
        skip_ctr = 0
        total_ctr = 0
        for record in utils.load_json_data(filename):
            total_ctr += 1
            # Skip records with empty targets.
            # TODO(ian): don't do this if generating negatives!
            if not record.get("targets", None):
                skip_ctr += 1
                continue
            yield record
        log.info(
            "Read=%d, Skip=%d, Total=%d from %s",
            total_ctr - skip_ctr,
            skip_ctr,
            total_ctr,
            filename,
        )

    @staticmethod
    def merge_preds(record: Dict, preds: Dict) -> Dict:
        """ Merge predictions into record, in-place.

        List-valued predictions should align to targets,
        and are attached to the corresponding target entry.

        Non-list predictions are attached to the top-level record.
        """
        record["preds"] = {}
        for target in record["targets"]:
            target["preds"] = {}
        for key, val in preds.items():
            if isinstance(val, list):
                assert len(val) == len(record["targets"])
                for i, target in enumerate(record["targets"]):
                    target["preds"][key] = val[i]
            else:
                # non-list predictions, attach to top-level preds
                record["preds"][key] = val
        return record
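    # A toy run of merge_preds (illustrative):
    #
    #   record = {"targets": [{"span1": [0, 2]}, {"span1": [2, 3]}]}
    #   preds = {"proba": [0.9, 0.2], "loss": 0.31}
    #   merge_preds(record, preds)
    #   # list-valued "proba" aligns to targets:
    #   #   targets[0]["preds"]["proba"] == 0.9
    #   #   targets[1]["preds"]["proba"] == 0.2
    #   # scalar "loss" attaches at the top level: record["preds"]["loss"] == 0.31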

    def load_data(self):
        self.all_labels = list(utils.load_lines(self.label_file))
        self.n_classes = len(self.all_labels)
        iters_by_split = collections.OrderedDict()
        for split, filename in self._files_by_split.items():
            #  # Lazy-load using RepeatableIterator.
            #  loader = functools.partial(utils.load_json_data,
            #                             filename=filename)
            #  iter = serialize.RepeatableIterator(loader)
            split_iter = list(self._stream_records(filename))
            iters_by_split[split] = split_iter
        self._iters_by_split = iters_by_split

    def get_split_text(self, split: str):
        """ Get split text as iterable of records.

        Split should be one of 'train', 'val', or 'test'.
        """
        return self._iters_by_split[split]

    @classmethod
    def get_num_examples(cls, split_text):
        """ Return number of examples in the result of get_split_text.

        Subclass can override this if data is not stored in column format.
        """
        return len(split_text)

    @classmethod
    def _make_span_field(cls, s, text_field, offset=1):
        return SpanField(s[0] + offset, s[1] - 1 + offset, text_field)
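    # e.g. with offset=1, a half-open span [2, 5) over the raw tokens becomes
    # the inclusive SpanField (3, 5): +offset accounts for the prepended
    # boundary token, and the -1 converts the exclusive end to an inclusive one.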

    def make_instance(self, record, idx, indexers,
                      model_preprocessing_interface) -> Instance:
        """Convert a single record to an AllenNLP Instance."""
        tokens = record["text"].split()  # already space-tokenized by Moses
        tokens = model_preprocessing_interface.boundary_token_fn(
            tokens)  # apply model-appropriate variants of [cls] and [sep].
        text_field = sentence_to_text_field(tokens, indexers)

        d = {}
        d["idx"] = MetadataField(idx)

        d["input1"] = text_field

        d["span1s"] = ListField([
            self._make_span_field(t["span1"], text_field, 1)
            for t in record["targets"]
        ])
        if not self.single_sided:
            d["span2s"] = ListField([
                self._make_span_field(t["span2"], text_field, 1)
                for t in record["targets"]
            ])

        # Always use multilabel targets, so be sure each label is a list.
        labels = [
            utils.wrap_singleton_string(t["label"]) for t in record["targets"]
        ]
        d["labels"] = ListField([
            MultiLabelField(label_set,
                            label_namespace=self._label_namespace,
                            skip_indexing=False) for label_set in labels
        ])
        return Instance(d)

    def process_split(
            self, records, indexers,
            model_preprocessing_interface) -> Iterable[Instance]:
        """ Process split text into a list of AllenNLP Instances. """
        def _map_fn(r, idx):
            return self.make_instance(r, idx, indexers,
                                      model_preprocessing_interface)

        return map(_map_fn, records, itertools.count())

    def get_sentences(self) -> Iterable[Sequence[str]]:
        """ Yield sentences, used to compute vocabulary. """
        for split, split_iter in self._iters_by_split.items():
            # Don't use test set for vocab building.
            if split.startswith("test"):
                continue
            for record in split_iter:
                yield record["text"].split()

    def get_metrics(self, reset=False):
        """Get metrics specific to the task"""
        metrics = {}
        metrics["mcc"] = self.mcc_scorer.get_metric(reset)
        metrics["acc"] = self.acc_scorer.get_metric(reset)
        precision, recall, f1 = self.f1_scorer.get_metric(reset)
        metrics["precision"] = precision
        metrics["recall"] = recall
        metrics["f1"] = f1
        return metrics
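The class docstring points to register_task for one-off datasets. A hedged sketch of what a registration might look like (the decorator and its exact signature live elsewhere in this codebase and are assumed here):

# @register_task("edges-spr2",
#                rel_path="edges/spr2/",
#                label_file="labels.txt",
#                files_by_split={"train": "train.json",
#                                "val": "dev.json",
#                                "test": "test.json"})
# class SPR2EdgeProbingTask(EdgeProbingTask):
#     pass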
Example No. 29
class QaNetSemantic(Model):
    """
    This class implements Adams Wei Yu's `QANet Model <https://openreview.net/forum?id=B14TlG-RW>`_
    for machine reading comprehension published at ICLR 2018.

    The overall architecture of QANet is very similar to BiDAF. The main difference is that QANet
    replaces the RNN encoder with CNN + self-attention. There are also some minor differences in the
    modeling layer and output layer.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the passage-question attention.
    matrix_attention_layer : ``MatrixAttention``
        The matrix attention function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    dropout_prob : ``float``, optional (default=0.1)
        If greater than 0, we will apply dropout with this probability between layers.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 matrix_attention_layer: MatrixAttention,
                 modeling_layer: Seq2SeqEncoder,
                 dropout_prob: float = 0.1,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        text_embed_dim = text_field_embedder.get_output_dim()
        encoding_in_dim = phrase_layer.get_input_dim()
        encoding_out_dim = phrase_layer.get_output_dim()
        modeling_in_dim = modeling_layer.get_input_dim()
        modeling_out_dim = modeling_layer.get_output_dim()

        self._text_field_embedder = text_field_embedder

        self._embedding_proj_layer = torch.nn.Linear(text_embed_dim,
                                                     encoding_in_dim)
        self._highway_layer = Highway(encoding_in_dim, num_highway_layers)

        self._encoding_proj_layer = torch.nn.Linear(encoding_in_dim,
                                                    encoding_in_dim)
        self._phrase_layer = phrase_layer

        self._matrix_attention = matrix_attention_layer

        self._modeling_proj_layer = torch.nn.Linear(encoding_out_dim * 4,
                                                    modeling_in_dim)
        self._modeling_layer = modeling_layer

        self._span_start_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)
        self._span_end_predictor = torch.nn.Linear(modeling_out_dim * 2, 1)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._metrics = SquadEmAndF1()
        self._dropout = torch.nn.Dropout(
            p=dropout_prob) if dropout_prob > 0 else lambda x: x

        # evaluation

        # BLEU
        self._bleu_score_types_to_use = ["BLEU1", "BLEU2", "BLEU3", "BLEU4"]
        self._bleu_scores = {
            x: Average()
            for x in self._bleu_score_types_to_use
        }

        # ROUGE using pyrouge
        self._rouge_score_types_to_use = ['rouge-n', 'rouge-l', 'rouge-w']

        # if we have rouge-n as a metric we actually get n scores: rouge-1, rouge-2, ..., rouge-n
        max_rouge_n = 4
        rouge_n_metrics = []
        if "rouge-n" in self._rouge_score_types_to_use:
            rouge_n_metrics = [
                "rouge-{0}".format(x) for x in range(1, max_rouge_n + 1)
            ]

        rouge_scores_names = rouge_n_metrics + [
            y for y in self._rouge_score_types_to_use if y != 'rouge-n'
        ]
        self._rouge_scores = {x: Average() for x in rouge_scores_names}
        self._rouge_evaluator = rouge.Rouge(
            metrics=self._rouge_score_types_to_use,
            max_n=max_rouge_n,
            limit_length=True,
            length_limit=100,
            length_limit_type='words',
            apply_avg=False,
            apply_best=False,
            alpha=0.5,  # Default F1_score
            weight_factor=1.2,
            stemming=True)

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            passage_sem_views_q: torch.IntTensor = None,
            passage_sem_views_k: torch.IntTensor = None,
            question_sem_views_q: torch.IntTensor = None,
            question_sem_views_k: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        passage_sem_views_q : ``torch.IntTensor``, optional
            Paragraph semantic views features for multihead attention Query (Q)
        passage_sem_views_k : ``torch.IntTensor``, optional
            Paragraph semantic views features for multihead attention Key (K)
        question_sem_views_q : ``torch.IntTensor``, optional
            Question semantic views features for multihead attention Query (Q)
        question_sem_views_k : ``torch.IntTensor``, optional
            Question semantic views features for multihead attention Key (K)

        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()

        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(
            self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(
            self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(
            embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(
            embedded_passage)

        if isinstance(self._phrase_layer, QaNetSemanticEncoder):
            passage_sem_views_q = passage_sem_views_q.long()
            passage_sem_views_k = passage_sem_views_k.long()
            question_sem_views_q = question_sem_views_q.long()
            question_sem_views_k = question_sem_views_k.long()

            encoded_passage = self._dropout(
                self._phrase_layer(projected_embedded_passage,
                                   passage_sem_views_q, passage_sem_views_k,
                                   passage_mask))
            encoded_question = self._dropout(
                self._phrase_layer(projected_embedded_question,
                                   question_sem_views_q, question_sem_views_k,
                                   question_mask))
        else:
            encoded_passage = self._dropout(
                self._phrase_layer(projected_embedded_passage, passage_mask))
            encoded_question = self._dropout(
                self._phrase_layer(projected_embedded_question, question_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True)

        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention,
                                             question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage,
                                                    attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat([
                encoded_passage, passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * passage_passage_vectors
            ],
                      dim=-1))

        modeled_passage_list = [
            self._modeling_proj_layer(merged_passage_attention_vectors)
        ]

        for _ in range(3):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)
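        # modeled_passage_list now holds [projected_input, M0, M1, M2]; as in
        # QANet, the span-start scores use (M0, M1) and the span-end scores
        # use (M0, M2) below.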

        # Shape: (batch_size, passage_length, modeling_dim * 2))
        span_start_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            try:
                loss = nll_loss(
                    util.masked_log_softmax(span_start_logits, passage_mask),
                    span_start.squeeze(-1))
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(-1))
                loss += nll_loss(
                    util.masked_log_softmax(span_end_logits, passage_mask),
                    span_end.squeeze(-1))
                self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
                self._span_accuracy(best_span,
                                    torch.stack([span_start, span_end], -1))
                output_dict["loss"] = loss
            except Exception as e:
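                # Swallow loss-computation failures (presumably malformed gold
                # spans) and just log them; such a batch contributes no loss.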
                logging.exception(e)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []

            all_reference_answers_text = []
            all_best_spans = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())

                # This variant joins the predicted span's tokens directly
                # rather than slicing the original passage with character
                # offsets (compare the offset-based models further below).
                start_span = predicted_span[0]
                end_span = predicted_span[1]
                best_span_tokens = metadata[i]['passage_tokens'][
                    start_span:end_span + 1]
                best_span_string = " ".join(best_span_tokens)
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
                    all_best_spans.append(best_span_string)
                    all_reference_answers_text.append(answer_texts)

            if not self.training and all_best_spans:
                self.calculate_rouge(all_best_spans,
                                     all_reference_answers_text)

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def calculate_rouge(self, predictions, references):
        metrics_with_per_item_scores = self._rouge_evaluator.get_scores(
            predictions, references)
        for metric, results in sorted(metrics_with_per_item_scores.items(),
                                      key=lambda x: x[0]):
            for hypothesis_id, results_per_ref in enumerate(results):
                # we report the max f-score over the reference answers
                curr_item_rouge_f = max(results_per_ref['f'])
                self._rouge_scores[metric](curr_item_rouge_f)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._metrics.get_metric(reset)
        metrics = {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

        # # report bleu scores
        # for k,v in self._bleu_scores.items():
        #     metrics[k] = v.get_metric(reset)

        for k, v in self._rouge_scores.items():
            metrics[k] = v.get_metric(reset)

        return metrics
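
# Illustrative sketch (added here; not part of the original example): the
# attention-over-attention step above composes passage->question attention
# with question->passage attention into a passage->passage distribution
# whose rows still sum to one. Toy shapes: batch=1, passage=3, question=2.
_sim = torch.randn(1, 3, 2)
_p2q = torch.softmax(_sim, dim=-1)                  # (1, 3, 2)
_q2p = torch.softmax(_sim.transpose(1, 2), dim=-1)  # (1, 2, 3)
_a_over_a = torch.bmm(_p2q, _q2p)                   # (1, 3, 3)
assert torch.allclose(_a_over_a.sum(-1), torch.ones(1, 3))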
Exemplo n.º 30
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(
            Highway(text_field_embedder.get_output_dim(), num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(
            torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(
            torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(),
                               4 * encoding_dim, "modeling layer input dim",
                               "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(),
                               phrase_layer.get_input_dim(),
                               "text field embedder output dim",
                               "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(),
                               4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim",
                               "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, span_end_encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['best_span_indices'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                output_dict['best_span_indices'].append(
                    [start_offset, end_offset])
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        # We call the inputs "logits" - they could either be unnormalized logits or normalized log
        # probabilities.  A log_softmax operation is a constant shifting of the entire logit
        # vector, so taking an argmax over either one gives the same result.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        device = span_start_logits.device
        # (batch_size, passage_length, passage_length)
        span_log_probs = span_start_logits.unsqueeze(
            2) + span_end_logits.unsqueeze(1)
        # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
        # the span ends before it starts.
        span_log_mask = torch.triu(
            torch.ones((passage_length, passage_length),
                       device=device)).log().unsqueeze(0)
        valid_span_log_probs = span_log_probs + span_log_mask

        # Here we take the span matrix and flatten it, then find the best span using argmax.  We
        # can recover the start and end indices from this flattened list using simple modular
        # arithmetic.
        # (batch_size, passage_length * passage_length)
        best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
        span_start_indices = best_spans // passage_length
        span_end_indices = best_spans % passage_length
        return torch.stack([span_start_indices, span_end_indices], dim=-1)
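
# Quick illustrative check (a sketch, not from the original example): the
# constrained argmax in get_best_span never returns a span that ends before
# it starts. Unconstrained argmaxes on these toy logits would give start=1,
# end=0 (invalid); the upper-triangular mask yields (1, 3) instead.
_starts = torch.tensor([[0.1, 2.0, 0.3, 0.0]])
_ends = torch.tensor([[3.0, 0.2, 0.1, 1.5]])
assert BidirectionalAttentionFlow.get_best_span(_starts, _ends).tolist() == [[1, 3]]
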
class EvidenceExtraction(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 question_encoder: Seq2SeqEncoder,
                 passage_encoder: Seq2SeqEncoder,
                 r: float = 0.8,
                 dropout: float = 0.1,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(EvidenceExtraction, self).__init__(vocab, regularizer)

        self._embedder = embedder

        self._question_encoder = question_encoder
        self._passage_encoder = passage_encoder

        # size: 2H
        encoding_dim = question_encoder.get_output_dim()

        self._gru_cell = nn.GRUCell(2 * encoding_dim, encoding_dim)

        self._gate = nn.Linear(2 * encoding_dim, 2 * encoding_dim)

        self._match_layer_1 = nn.Linear(2 * encoding_dim, encoding_dim)
        self._match_layer_2 = nn.Linear(encoding_dim, 1)

        self._question_attention_for_passage = Attention(
            NonlinearSimilarity(encoding_dim))
        self._question_attention_for_question = Attention(
            NonlinearSimilarity(encoding_dim))
        self._passage_attention_for_answer = Attention(
            NonlinearSimilarity(encoding_dim), normalize=False)
        self._passage_attention_for_ranking = Attention(
            NonlinearSimilarity(encoding_dim))

        self._passage_self_attention = Attention(
            NonlinearSimilarity(encoding_dim))
        self._self_gru_cell = nn.GRUCell(2 * encoding_dim, encoding_dim)
        # The gate must match the width of the gated input (2 * encoding_dim).
        self._self_gate = nn.Linear(2 * encoding_dim, 2 * encoding_dim)

        self._answer_net = nn.GRUCell(encoding_dim, encoding_dim)

        self._v_r_Q = nn.Parameter(torch.rand(encoding_dim))
        self._r = r

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()

        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x

        initializer(self)

    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            #passages_length: torch.LongTensor = None,
            #correct_passage: torch.LongTensor = None,
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata=None) -> Dict[str, torch.Tensor]:

        # shape: B x Tq x E
        embedded_question = self._embedder(question)
        embedded_passage = self._embedder(passage)

        batch_size = embedded_question.size(0)
        total_passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        # shape: B x T x 2H
        encoded_question = self._dropout(
            self._question_encoder(embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._passage_encoder(embedded_passage, passage_mask))
        passage_mask = passage_mask.float()
        question_mask = question_mask.float()

        encoding_dim = encoded_question.size(-1)

        # shape: B x 2H
        gru_hidden = torch.zeros(batch_size, encoding_dim,
                                 device=encoded_passage.device)

        question_awared_passage = []
        for timestep in range(total_passage_length):
            # shape: B x Tq = attention(B x 2H, B x Tq x 2H)
            attn_weights = self._question_attention_for_passage(
                encoded_passage[:, timestep, :], encoded_question,
                question_mask)
            # shape: B x 2H = weighted_sum(B x Tq x 2H, B x Tq)
            attended_question = util.weighted_sum(encoded_question,
                                                  attn_weights)
            # shape: B x 4H
            passage_question_combined = torch.cat(
                [encoded_passage[:, timestep, :], attended_question], dim=-1)
            # shape: B x 4H
            gate = torch.sigmoid(self._gate(passage_question_combined))
            gru_input = gate * passage_question_combined
            # shape: B x 2H
            gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
            question_awared_passage.append(gru_hidden)

        # shape: B x T x 2H
        # question aware passage representation v_P
        question_awared_passage = torch.stack(question_awared_passage, dim=1)

        self_attended_passage = []
        for timestep in range(total_passage_length):
            attn_weights = self._passage_self_attention(
                question_awared_passage[:, timestep, :],
                question_awared_passage, passage_mask)
            attended_passage = util.weighted_sum(question_awared_passage,
                                                 attn_weights)
            input_combined = torch.cat(
                [question_awared_passage[:, timestep, :], attended_passage],
                dim=-1)
            gate = torch.sigmoid(self._self_gate(input_combined))
            gru_input = gate * input_combined
            # Use the dedicated self-attention GRU cell, not self._gru_cell.
            gru_hidden = self._dropout(
                self._self_gru_cell(gru_input, gru_hidden))
            self_attended_passage.append(gru_hidden)

        self_attended_passage = torch.stack(self_attended_passage, dim=1)

        # compute question vector r_Q
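        # The learned parameter self._v_r_Q serves as the attention query,
        # much like the pointer-network initial-state trick in R-Net.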
        # shape: B x T = attention(B x 2H, B x T x 2H)
        v_r_Q_tiled = self._v_r_Q.unsqueeze(0).expand(batch_size, encoding_dim)
        attn_weights = self._question_attention_for_question(
            v_r_Q_tiled, encoded_question, question_mask)
        # shape: B x 2H
        r_Q = util.weighted_sum(encoded_question, attn_weights)
        # shape: B x T = attention(B x 2H, B x T x 2H)
        span_start_logits = self._passage_attention_for_answer(
            r_Q, self_attended_passage, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_start_log_probs = util.masked_log_softmax(span_start_logits,
                                                       passage_mask)
        # shape: B x 2H
        c_t = util.weighted_sum(self_attended_passage, span_start_probs)
        # shape: B x 2H
        h_1 = self._dropout(self._answer_net(c_t, r_Q))

        span_end_logits = self._passage_attention_for_answer(
            h_1, self_attended_passage, passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_end_log_probs = util.masked_log_softmax(span_end_logits,
                                                     passage_mask)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        #num_passages = passages_length.size(1)
        #acc = Variable(torch.zeros(batch_size, num_passages + 1)).cuda(cuda_device).long()

        #acc[:, 1:num_passages+1] = torch.cumsum(passages_length, dim=1)

        #g_batch = []
        #for b in range(batch_size):
        #    g = []
        #    for i in range(num_passages):
        #        if acc[b, i+1].data[0] > acc[b, i].data[0]:
        #            attn_weights = self._passage_attention_for_ranking(r_Q[b:b+1], question_awared_passage[b:b+1, acc[b, i].data[0]: acc[b, i+1].data[0], :], passage_mask[b:b+1, acc[b, i].data[0]: acc[b, i+1].data[0]])
        #            r_P = util.weighted_sum(question_awared_passage[b:b+1, acc[b, i].data[0]:acc[b, i+1].data[0], :], attn_weights)
        #            question_passage_combined = torch.cat([r_Q[b:b+1], r_P], dim=-1)
        #            gi = self._dropout(self._match_layer_2(F.tanh(self._dropout(self._match_layer_1(question_passage_combined)))))
        #            g.append(gi)
        #        else:
        #            g.append(Variable(torch.zeros(1, 1)).cuda(cuda_device))
        #    g = torch.cat(g, dim=1)
        #    g_batch.append(g)

        #t2 = time.time()
        #g = torch.cat(g_batch, dim=0)
        output_dict = {}
        if span_start is not None:
            AP_loss = F.nll_loss(span_start_log_probs, span_start.squeeze(-1)) +\
                F.nll_loss(span_end_log_probs, span_end.squeeze(-1))
            #PR_loss = F.nll_loss(passage_log_probs, correct_passage.squeeze(-1))
            #loss = self._r * AP_loss + self._r * PR_loss
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict['loss'] = AP_loss

        _, max_start = torch.max(span_start_probs, dim=1)
        _, max_end = torch.max(span_end_probs, dim=1)
        #t3 = time.time()
        output_dict['span_start_idx'] = max_start
        output_dict['span_end_idx'] = max_end
        #t4 = time.time()
        #global ITE
        #ITE += 1
        #if (ITE % 100 == 0):
        #    print(" gold %i:%i|predicted %i:%i" %(span_start.squeeze(-1)[0], span_end.squeeze(-1)[0], max_start.data[0], max_end.data[0]))
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens

        #t5 = time.time()
        #print("Total: %.5f" % (t5-t0))
        #print("Batch processing 1: %.5f" % (t2-t1))
        #print("Batch processing 2: %.5f" % (t4-t3))
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score
        }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
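        # Linear-time constrained decoding: keep a running argmax over the
        # start logits and pair it with each candidate end position, so the
        # returned span always satisfies end >= start.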
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]
                val2 = span_end_logits[b, j]
                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span

    @classmethod
    def from_params(cls, vocab: Vocabulary,
                    params: Params) -> 'EvidenceExtraction':
        embedder_params = params.pop("text_field_embedder")
        embedder = TextFieldEmbedder.from_params(vocab, embedder_params)
        question_encoder = Seq2SeqEncoder.from_params(
            params.pop("question_encoder"))
        passage_encoder = Seq2SeqEncoder.from_params(
            params.pop("passage_encoder"))
        dropout = params.pop_float('dropout', 0.1)
        r = params.pop_float('r', 0.8)
        #cuda = params.pop_int('cuda', 0)

        initializer = InitializerApplicator.from_params(
            params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(
            params.pop('regularizer', []))

        return cls(
            vocab=vocab,
            embedder=embedder,
            question_encoder=question_encoder,
            passage_encoder=passage_encoder,
            r=r,
            dropout=dropout,
            #cuda=cuda,
            initializer=initializer,
            regularizer=regularizer)
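
# Illustrative sketch (added; not from the original example) of the gated
# attention input used in EvidenceExtraction: a sigmoid gate rescales each
# feature of [passage_state; attended_question] before the GRU step.
_H = 4  # hypothetical hidden size
_gate_layer = nn.Linear(2 * _H, 2 * _H)
_cell = nn.GRUCell(2 * _H, _H)
_x = torch.randn(1, 2 * _H)  # [passage_state; attended_question]
_h = _cell(torch.sigmoid(_gate_layer(_x)) * _x, torch.zeros(1, _H))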
Exemplo n.º 32
class DialogQA(Model):
    """
    This class implements a modified version of BiDAF
    (with self attention and residual layer, from Clark and Gardner ACL 17 paper) model as used in
    Question Answering in Context (EMNLP 2018) paper [https://arxiv.org/pdf/1808.07036.pdf].

    In this set-up, a single instance is a dialog: a list of question-answer pairs.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    span_start_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span end predictions into the passage state.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    num_context_answers : ``int``, optional (default=0)
        If greater than 0, the model will consider previous question answering context.
    max_span_length: ``int``, optional (default=0)
        Maximum token length of the output span.
    max_turn_length: ``int``, optional (default=12)
        Maximum length of an interaction.
    """

    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 num_context_answers: int = 0,
                 marker_embedding_dim: int = 10,
                 max_span_length: int = 30,
                 max_turn_length: int = 12) -> None:
        super().__init__(vocab)
        self._num_context_answers = num_context_answers
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._marker_embedding_dim = marker_embedding_dim
        self._encoding_dim = phrase_layer.get_output_dim()

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._merge_atten = TimeDistributed(torch.nn.Linear(self._encoding_dim * 4, self._encoding_dim))

        self._residual_encoder = residual_encoder

        if num_context_answers > 0:
            self._question_num_marker = torch.nn.Embedding(max_turn_length,
                                                           marker_embedding_dim * num_context_answers)
            self._prev_ans_marker = torch.nn.Embedding((num_context_answers * 4) + 1, marker_embedding_dim)
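            # (num_context_answers * 4) + 1 marker ids: four marker tags per
            # previous answer plus one shared 'O' tag.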

        self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        self._followup_lin = torch.nn.Linear(self._encoding_dim, 3)
        self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3,
                                                                     self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_yesno_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 3))
        self._span_followup_predictor = TimeDistributed(self._followup_lin)

        check_dimensions_match(phrase_layer.get_input_dim(),
                               text_field_embedder.get_output_dim() +
                               marker_embedding_dim * num_context_answers,
                               "phrase layer input dim",
                               "embedding dim + marker dim * num context answers")

        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_followup_accuracy = CategoricalAccuracy()

        self._span_gt_yesno_accuracy = CategoricalAccuracy()
        self._span_gt_followup_accuracy = CategoricalAccuracy()

        self._span_accuracy = BooleanAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor of shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens belonging to the
            previous answer in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac.
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two turns earlier in the passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three turns earlier in the passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following entries.
        Each entry is a nested list: the outer list iterates over dialogs, the inner list over the
        questions within a dialog.

        qid : List[List[str]]
            A list of lists of question ids.
        followup : List[List[int]]
            A list of lists of continuation marker prediction indices.
            (y: yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of lists of affirmation marker prediction indices.
            (y: yes, x: not a yes/no question, n: no)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
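        # Dialogs are padded up to max_qa_count QA pairs; padded pairs carry a
        # followup label of -1, so this mask keeps only the real QA pairs.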
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
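        # Zero the diagonal below so a passage token never attends to itself.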
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # Shape: (batch_size * max_qa_count, passage_length)
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # Select the yes/no and follow-up logits at the gold span-end token.
            # After view(-1), the three logits for QA pair i at passage position
            # j sit at flat indices i * passage_length * 3 + j * 3 + {0, 1, 2}.
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).detach().cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Leave each reference out in turn: compute F1 against the
                        # remaining N-1 references, then average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        yesno_tags = [[self.vocab.get_token_from_index(x, namespace="yesno_labels") for x in yn_list]
                      for yn_list in output_dict.pop("yesno")]
        followup_tags = [[self.vocab.get_token_from_index(x, namespace="followup_labels") for x in followup_list]
                         for followup_list in output_dict.pop("followup")]
        output_dict['yesno'] = yesno_tags
        output_dict['followup'] = followup_tags
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'yesno': self._span_yesno_accuracy.get_metric(reset),
                'followup': self._span_followup_accuracy.get_metric(reset),
                'f1': self._official_f1.get_metric(reset), }

    @staticmethod
    def _get_best_span_yesno_followup(span_start_logits: torch.Tensor,
                                      span_end_logits: torch.Tensor,
                                      span_yesno_logits: torch.Tensor,
                                      span_followup_logits: torch.Tensor,
                                      max_span_length: int) -> torch.Tensor:
        # Returns the index of the highest-scoring span no longer than ``max_span_length``
        # tokens, along with the yes/no and follow-up predictions taken at the predicted
        # span end token.
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size

        best_word_span = span_start_logits.new_zeros((batch_size, 4), dtype=torch.long)

        span_start_logits = span_start_logits.data.cpu().numpy()
        span_end_logits = span_end_logits.data.cpu().numpy()
        span_yesno_logits = span_yesno_logits.data.cpu().numpy()
        span_followup_logits = span_followup_logits.data.cpu().numpy()
        for b_i in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b_i, span_start_argmax[b_i]]
                if val1 < span_start_logits[b_i, j]:
                    span_start_argmax[b_i] = j
                    val1 = span_start_logits[b_i, j]
                val2 = span_end_logits[b_i, j]
                if val1 + val2 > max_span_log_prob[b_i]:
                    if j - span_start_argmax[b_i] > max_span_length:
                        continue
                    best_word_span[b_i, 0] = span_start_argmax[b_i]
                    best_word_span[b_i, 1] = j
                    max_span_log_prob[b_i] = val1 + val2
        for b_i in range(batch_size):
            j = best_word_span[b_i, 1]
            yesno_pred = np.argmax(span_yesno_logits[b_i, j])
            followup_pred = np.argmax(span_followup_logits[b_i, j])
            best_word_span[b_i, 2] = int(yesno_pred)
            best_word_span[b_i, 3] = int(followup_pred)
        return best_word_span
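
# Editor's sketch (not from the original example): a toy call to the span search
# above. The shapes and values here are illustrative assumptions only.
import torch

start_logits = torch.tensor([[0.1, 2.0, 0.3]])   # (batch=1, passage_length=3)
end_logits = torch.tensor([[0.2, 0.1, 1.5]])
yesno_logits = torch.randn(1, 3, 3)              # three classes per token
followup_logits = torch.randn(1, 3, 3)

best = DialogQA._get_best_span_yesno_followup(
    start_logits, end_logits, yesno_logits, followup_logits, max_span_length=30)
# best[0, :2] is the (start, end) token pair, here (1, 2); best[0, 2] and
# best[0, 3] are the yes/no and follow-up classes read off at the end token.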
Example No. 33
class BidirectionalAttentionFlow(Model):
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            char_field_embedder: TextFieldEmbedder,
            # num_highway_layers: int,
            phrase_layer: Seq2SeqEncoder,
            char_rnn: Seq2SeqEncoder,
            hops: int,
            hidden_dim: int,
            dropout: float = 0.2,
            mask_lstms: bool = True,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._char_field_embedder = char_field_embedder
        self._features_embedder = nn.Embedding(2, 5)
        # self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim() + 5 * 3,
        #                                               num_highway_layers))
        self._phrase_layer = phrase_layer
        self._encoding_dim = phrase_layer.get_output_dim()
        # self._stacked_brnn = PytorchSeq2SeqWrapper(
        #     StackedBidirectionalLstm(input_size=self._encoding_dim, hidden_size=hidden_dim,
        #                              num_layers=3, recurrent_dropout_probability=0.2))
        self._char_rnn = char_rnn

        self.hops = hops

        self.interactive_aligners = nn.ModuleList()
        self.interactive_SFUs = nn.ModuleList()
        self.self_aligners = nn.ModuleList()
        self.self_SFUs = nn.ModuleList()
        self.aggregate_rnns = nn.ModuleList()
        for i in range(hops):
            # interactive aligner
            self.interactive_aligners.append(
                layers.SeqAttnMatch(self._encoding_dim))
            self.interactive_SFUs.append(
                layers.SFU(self._encoding_dim, 3 * self._encoding_dim))
            # self aligner
            self.self_aligners.append(layers.SelfAttnMatch(self._encoding_dim))
            self.self_SFUs.append(
                layers.SFU(self._encoding_dim, 3 * self._encoding_dim))
            # aggregating
            self.aggregate_rnns.append(
                PytorchSeq2SeqWrapper(
                    nn.LSTM(input_size=self._encoding_dim,
                            hidden_size=hidden_dim,
                            num_layers=1,
                            dropout=0.2,
                            bidirectional=True,
                            batch_first=True)))

        # Memory-based Answer Pointer
        self.mem_ans_ptr = layers.MemoryAnsPointer(x_size=self._encoding_dim,
                                                   y_size=self._encoding_dim,
                                                   hidden_size=hidden_dim,
                                                   hop=hops,
                                                   dropout_rate=0.2,
                                                   normalize=True)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno: torch.IntTensor = None,
            question_tf: torch.FloatTensor = None,
            passage_tf: torch.FloatTensor = None,
            q_em_cased: torch.IntTensor = None,
            p_em_cased: torch.IntTensor = None,
            q_em_uncased: torch.IntTensor = None,
            p_em_uncased: torch.IntTensor = None,
            q_in_lemma: torch.IntTensor = None,
            p_in_lemma: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        x1_c_emb = self._dropout(self._char_field_embedder(passage))
        x2_c_emb = self._dropout(self._char_field_embedder(question))

        # embedded_question = torch.cat([self._dropout(self._text_field_embedder(question)),
        #                                self._features_embedder(q_em_cased),
        #                                self._features_embedder(q_em_uncased),
        #                                self._features_embedder(q_in_lemma),
        #                                question_tf.unsqueeze(2)], dim=2)
        # embedded_passage = torch.cat([self._dropout(self._text_field_embedder(passage)),
        #                               self._features_embedder(p_em_cased),
        #                               self._features_embedder(p_em_uncased),
        #                               self._features_embedder(p_in_lemma),
        #                               passage_tf.unsqueeze(2)], dim=2)
        token_emb_q = self._dropout(self._text_field_embedder(question))
        token_emb_c = self._dropout(self._text_field_embedder(passage))
        token_emb_question, q_ner_and_pos = torch.split(token_emb_q, [300, 40],
                                                        dim=2)
        token_emb_passage, p_ner_and_pos = torch.split(token_emb_c, [300, 40],
                                                       dim=2)
        question_word_features = torch.cat([
            q_ner_and_pos,
            self._features_embedder(q_em_cased),
            self._features_embedder(q_em_uncased),
            self._features_embedder(q_in_lemma),
            question_tf.unsqueeze(2)
        ], dim=2)
        passage_word_features = torch.cat([
            p_ner_and_pos,
            self._features_embedder(p_em_cased),
            self._features_embedder(p_em_uncased),
            self._features_embedder(p_in_lemma),
            passage_tf.unsqueeze(2)
        ], dim=2)

        # embedded_question = self._highway_layer(embedded_q)
        # embedded_passage = self._highway_layer(embedded_p)

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

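        # Character-level features: fold (batch, seq_len, num_chars, char_dim)
        # into (batch * seq_len, num_chars, char_dim), run the character RNN over
        # each token's characters, and keep the last timestep as that token's
        # character feature.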
        char_features_c = self._char_rnn(
            x1_c_emb.reshape((x1_c_emb.size(0) * x1_c_emb.size(1),
                              x1_c_emb.size(2), x1_c_emb.size(3))),
            passage_lstm_mask.unsqueeze(2).repeat(
                1, 1, x1_c_emb.size(2)).reshape(
                    (x1_c_emb.size(0) * x1_c_emb.size(1),
                     x1_c_emb.size(2)))).reshape(
                         (x1_c_emb.size(0), x1_c_emb.size(1), x1_c_emb.size(2),
                          -1))[:, :, -1, :]
        char_features_q = self._char_rnn(
            x2_c_emb.reshape((x2_c_emb.size(0) * x2_c_emb.size(1),
                              x2_c_emb.size(2), x2_c_emb.size(3))),
            question_lstm_mask.unsqueeze(2).repeat(
                1, 1, x2_c_emb.size(2)).reshape(
                    (x2_c_emb.size(0) * x2_c_emb.size(1),
                     x2_c_emb.size(2)))).reshape(
                         (x2_c_emb.size(0), x2_c_emb.size(1), x2_c_emb.size(2),
                          -1))[:, :, -1, :]

        # token_emb_q, char_emb_q, question_word_features = torch.split(embedded_question, [300, 300, 56], dim=2)
        # token_emb_c, char_emb_c, passage_word_features = torch.split(embedded_passage, [300, 300, 56], dim=2)

        # char_features_q = self._char_rnn(char_emb_q, question_lstm_mask)
        # char_features_c = self._char_rnn(char_emb_c, passage_lstm_mask)

        emb_question = torch.cat(
            [token_emb_question, char_features_q, question_word_features],
            dim=2)
        emb_passage = torch.cat(
            [token_emb_passage, char_features_c, passage_word_features], dim=2)

        encoded_question = self._dropout(
            self._phrase_layer(emb_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(emb_passage, passage_lstm_mask))

        batch_size = encoded_question.size(0)
        passage_length = encoded_passage.size(1)

        encoding_dim = encoded_question.size(-1)

        # c_check = self._stacked_brnn(encoded_passage, passage_lstm_mask)
        # q = self._stacked_brnn(encoded_question, question_lstm_mask)
        c_check = encoded_passage
        q = encoded_question
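        # Iterative reattention: each hop aligns the passage against the question
        # (interactive aligner + SFU), then against itself (self aligner + SFU),
        # and aggregates the result with a BiLSTM.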
        for i in range(self.hops):
            q_tilde = self.interactive_aligners[i].forward(
                c_check, q, question_mask)
            c_bar = self.interactive_SFUs[i].forward(
                c_check,
                torch.cat([q_tilde, c_check * q_tilde, c_check - q_tilde], 2))
            c_tilde = self.self_aligners[i].forward(c_bar, passage_mask)
            c_hat = self.self_SFUs[i].forward(
                c_bar, torch.cat([c_tilde, c_bar * c_tilde, c_bar - c_tilde],
                                 2))
            c_check = self.aggregate_rnns[i].forward(c_hat, passage_mask)

        # Predict
        start_scores, end_scores, yesno_scores = self.mem_ans_ptr.forward(
            c_check, q, passage_mask, question_mask)

        best_span, yesno_predict, loc = self.get_best_span(
            start_scores, end_scores, yesno_scores)

        output_dict = {
            "span_start_logits": start_scores,
            "span_end_logits": end_scores,
            "best_span": best_span
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(start_scores, span_start.squeeze(-1))
            self._span_start_accuracy(start_scores, span_start.squeeze(-1))
            loss += nll_loss(end_scores, span_end.squeeze(-1))
            self._span_end_accuracy(end_scores, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))

            gold_span_end_loc = []
            span_end = span_end.view(batch_size).squeeze().data.cpu().numpy()
            for i in range(batch_size):
                gold_span_end_loc.append(
                    max(span_end[i] + i * passage_length, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(0, gold_span_end_loc)
            loss += nll_loss(_yesno, yesno.view(-1), ignore_index=-1)

            pred_span_end_loc = []
            for i in range(batch_size):
                pred_span_end_loc.append(max(loc[i], 0))
            predicted_end = span_start.new(pred_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(0, predicted_end)
            self._span_yesno_accuracy(_yesno, yesno.squeeze(-1))

            output_dict['loss'] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['yesno'] = yesno_predict
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
            'start_acc': self._span_start_accuracy.get_metric(reset),
            'end_acc': self._span_end_accuracy.get_metric(reset),
            'span_acc': self._span_accuracy.get_metric(reset),
            "yesno": self._span_yesno_accuracy.get_metric(reset),
            'em': exact_match,
            'f1': f1_score,
        }

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        yesno_tags = [
            self.vocab.get_token_from_index(x, namespace="yesno_labels")
            for x in output_dict.pop("yesno")
        ]
        output_dict['yesno'] = yesno_tags
        return output_dict

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor,
                      yesno_scores: torch.Tensor):
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError(
                "Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2),
                                                     dtype=torch.long)
        yesno_predict = span_start_logits.new_zeros(batch_size,
                                                    dtype=torch.long)
        loc = yesno_scores.new_zeros(batch_size, dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()
        yesno_logits = yesno_scores.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
                    yesno_predict[b] = int(np.argmax(yesno_logits[b, j]))
                    loc[b] = j + passage_length * b
        return best_word_span, yesno_predict, loc
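
# Editor's sketch (not from the original example): how the flat ``loc`` index
# returned above selects per-example yes/no scores. Values here are
# illustrative assumptions only.
import torch

batch_size, passage_length = 2, 4
yesno_scores = torch.randn(batch_size, passage_length, 3)
# Suppose the predicted spans end at token 1 (example 0) and token 3 (example 1):
loc = torch.tensor([1 + 0 * passage_length, 3 + 1 * passage_length])
picked = yesno_scores.view(-1, 3).index_select(0, loc)
assert picked.shape == (2, 3)  # one row of class scores per example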
Example No. 34
class BidirectionalAttentionFlow(Model):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).

    The basic layout is pretty simple: encode words as a combination of word embeddings and a
    character-level encoder, pass the word representations through a bi-LSTM/GRU, use a matrix of
    attentions to put question information into the passage word representations (this is the only
    part that is at all non-standard), pass this through another few layers of bi-LSTMs/GRUs, and
    do a softmax over span start and span end.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model.
    num_highway_layers : ``int``
        The number of highway layers to use in between embedding the input and passing it through
        the phrase layer.
    phrase_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and doing the bidirectional attention.
    similarity_function : ``SimilarityFunction``
        The similarity function that we will use when comparing encoded passage and question
        representations.
    modeling_layer : ``Seq2SeqEncoder``
        The encoder (with its own internal stacking) that we will use in between the bidirectional
        attention and predicting span start and end.
    span_end_encoder : ``Seq2SeqEncoder``
        The encoder that we will use to incorporate span start predictions into the passage state
        before predicting span end.
    dropout : ``float``, optional (default=0.2)
        If greater than 0, we will apply dropout with this probability after all encoders (pytorch
        LSTMs do not apply dropout to their last layer).
    mask_lstms : ``bool``, optional (default=True)
        If ``False``, we will skip passing the mask to the LSTM layers.  This gives a ~2x speedup,
        with only a slight performance decrease, if any.  We haven't experimented much with this
        yet, but have confirmed that we still get very similar performance with much faster
        training times.  We still use the mask for all softmaxes, but avoid the shuffling that's
        required when using masking with pytorch LSTMs.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidirectionalAttentionFlow, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

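        # BiDAF's query-aware passage representation: the encoded passage, the
        # attended question vectors, and their elementwise interactions,
        # concatenated along the feature dimension.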
        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
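        # Replace masked positions with a very negative value so the span
        # decoder below can never pick a padded token as a boundary.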
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._squad_metrics.get_metric(reset)
        return {
                'start_acc': self._span_start_accuracy.get_metric(reset),
                'end_acc': self._span_end_accuracy.get_metric(reset),
                'span_acc': self._span_accuracy.get_metric(reset),
                'em': exact_match,
                'f1': f1_score,
                }

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
        return best_word_span
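
# Editor's sketch (not from the original example): decoding a span from toy
# logits with the helper above. Values are illustrative assumptions only.
import torch

start = torch.tensor([[0.2, 3.0, 0.1, 0.5]])
end = torch.tensor([[0.1, 0.4, 2.5, 0.3]])
span = BidirectionalAttentionFlow.get_best_span(start, end)
assert span.tolist() == [[1, 2]]  # best start at 1, best compatible end at 2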
Example No. 35
class BidafV2(Model):
    """
    A modified version of the official BiDAF model with support for SQuAD v2
    (questions that may have no answer in the passage).
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 num_highway_layers: int,
                 phrase_layer: Seq2SeqEncoder,
                 metric: Metric,
                 similarity_function: SimilarityFunction,
                 modeling_layer: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 mask_lstms: bool = True,
                 no_answer: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BidafV2, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._highway_layer = TimeDistributed(Highway(text_field_embedder.get_output_dim(),
                                                      num_highway_layers))
        self._phrase_layer = phrase_layer
        self._matrix_attention = LegacyMatrixAttention(similarity_function)
        self._modeling_layer = modeling_layer
        self._span_end_encoder = span_end_encoder

        encoding_dim = phrase_layer.get_output_dim()
        modeling_dim = modeling_layer.get_output_dim()
        span_start_input_dim = encoding_dim * 4 + modeling_dim
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(span_start_input_dim, 1))

        span_end_encoding_dim = span_end_encoder.get_output_dim()
        span_end_input_dim = encoding_dim * 4 + span_end_encoding_dim
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(span_end_input_dim, 1))

        # Bidaf has lots of layer dimensions which need to match up - these aren't necessarily
        # obvious from the configuration files, so we check here.
        check_dimensions_match(modeling_layer.get_input_dim(), 4 * encoding_dim,
                               "modeling layer input dim", "4 * encoding dim")
        check_dimensions_match(text_field_embedder.get_output_dim(), phrase_layer.get_input_dim(),
                               "text field embedder output dim", "phrase layer input dim")
        check_dimensions_match(span_end_encoder.get_input_dim(), 4 * encoding_dim + 3 * modeling_dim,
                               "span end encoder input dim", "4 * encoding dim + 3 * modeling dim")

        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = metric
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        else:
            self._dropout = lambda x: x
        self._mask_lstms = mask_lstms
        self._threshold = torch.nn.Parameter(torch.zeros(1, 1))
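        # Learned scalar logit that acts as the "no answer" threshold appended
        # to each passage in ``forward``.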
        self._no_answer = no_answer

        initializer(self)

    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.
        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        # Add a no-answer option: append a learned threshold logit as an extra
        # (virtual) passage position.
        if self._no_answer:
            # Shape: (batch_size, passage_length + 1)
            passage_eval_mask = torch.cat([passage_mask, passage_mask.new_ones((batch_size, 1))], dim=-1)
            # Shape: (batch_size, 1)
            threshold = self._threshold.expand(batch_size, 1)
            # Shape: (batch_size, passage_length + 1)
            span_start_logits = torch.cat([span_start_logits, threshold], dim=-1)
            span_end_logits = torch.cat([span_end_logits, threshold], dim=-1)
        else:
            passage_eval_mask = passage_mask
        span_start_logits = util.replace_masked_values(span_start_logits, passage_eval_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_eval_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits, self._no_answer)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None and span_end is not None:
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            # In case there is no answer, convert span_start and span_end from -1 to passage_length
            if self._no_answer:
                span_start = span_start.clone()
                span_end = span_end.clone()
                for i in range(batch_size):
                    if span_start[i][0] == -1:
                        span_start[i][0] = passage_length
                        span_end[i][0] = passage_length

            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_eval_mask), span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_eval_mask), span_end.squeeze(-1))
            output_dict["loss"] = loss


        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                if predicted_span[0] < 0:
                    best_span_string = ''
                else:
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Output metrics include em, f1, no_em, span_acc
        """
        ret: Dict[str, Any] = {}
        if self._no_answer:
            ret = self._squad_metrics.get_metric(reset)
        else:
            exact_match, f1_score = self._squad_metrics.get_metric(reset)
            ret['em'] = exact_match
            ret['f1'] = f1_score
        ret['span_acc'] = self._span_accuracy.get_metric(reset)
        return ret

    @staticmethod
    def get_best_span(span_start_logits: torch.Tensor,
                      span_end_logits: torch.Tensor,
                      no_answer: bool = False) -> torch.Tensor:
        """
        Output best span (st, ed) where span_start_logits[st] + span_end_logits[ed] (st<=ed)
        is maximized
        if no_answer set to True, span_start_logits[-1] + span_end_logits[-1] will be checked
        seprately, if this value is max, return (-1, -1)
        """
        if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
            raise ValueError("Input shapes must be (batch_size, passage_length)")
        batch_size, passage_length = span_start_logits.size()
        if no_answer:
            passage_length = passage_length - 1
        max_span_log_prob = [-1e20] * batch_size
        span_start_argmax = [0] * batch_size
        best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)

        span_start_logits = span_start_logits.detach().cpu().numpy()
        span_end_logits = span_end_logits.detach().cpu().numpy()

        for b in range(batch_size):  # pylint: disable=invalid-name
            for j in range(passage_length):
                val1 = span_start_logits[b, span_start_argmax[b]]
                if val1 < span_start_logits[b, j]:
                    span_start_argmax[b] = j
                    val1 = span_start_logits[b, j]

                val2 = span_end_logits[b, j]

                if val1 + val2 > max_span_log_prob[b]:
                    best_word_span[b, 0] = span_start_argmax[b]
                    best_word_span[b, 1] = j
                    max_span_log_prob[b] = val1 + val2
            if no_answer and max_span_log_prob[b] < span_start_logits[b, -1] + span_end_logits[b, -1]:
                best_word_span[b, 0] = -1
                best_word_span[b, 1] = -1
                max_span_log_prob[b] = span_start_logits[b, -1] + span_end_logits[b, -1]

        return best_word_span
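
# Editor's sketch (not from the original example): the no-answer path of
# ``get_best_span``. The final logit plays the role of the appended threshold;
# values are illustrative assumptions only.
import torch

start = torch.tensor([[0.1, 0.2, 0.1, 5.0]])  # 3 passage tokens + threshold
end = torch.tensor([[0.1, 0.1, 0.2, 5.0]])
span = BidafV2.get_best_span(start, end, no_answer=True)
assert span.tolist() == [[-1, -1]]  # the threshold wins: predict "no answer"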