Example No. 1
    @staticmethod
    def _compute_antecedent_gold_labels(top_span_labels: torch.IntTensor,
                                        antecedent_labels: torch.IntTensor):
        """
        Generates a binary indicator for every pair of spans. This label is one if and
        only if the pair of spans belong to the same cluster. The labels are augmented
        with a dummy antecedent at the zeroth position, which represents the prediction
        that a span does not have any antecedent.

        Parameters
        ----------
        top_span_labels : ``torch.IntTensor``, required.
            The cluster id label for every span. The id is arbitrary,
            as we just care about the clustering. Has shape (batch_size, num_spans_to_keep).
        antecedent_labels : ``torch.IntTensor``, required.
            The cluster id label for every antecedent span. The id is arbitrary,
            as we just care about the clustering. Has shape
            (batch_size, num_spans_to_keep, max_antecedents).

        Returns
        -------
        pairwise_labels_with_dummy_label : ``torch.FloatTensor``
            A binary tensor representing whether a given pair of spans belong to
            the same cluster in the gold clustering.
            Has shape (batch_size, num_spans_to_keep, max_antecedents + 1).

        """
        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        target_labels = top_span_labels.expand_as(antecedent_labels)
        same_cluster_indicator = (target_labels == antecedent_labels).float()
        non_dummy_indicator = (target_labels >= 0).float()
        pairwise_labels = same_cluster_indicator * non_dummy_indicator

        # Shape: (batch_size, num_spans_to_keep, 1)
        dummy_labels = (1 - pairwise_labels).prod(-1, keepdim=True)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        pairwise_labels_with_dummy_label = torch.cat([dummy_labels, pairwise_labels], -1)
        return pairwise_labels_with_dummy_label
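
A minimal sketch of the same pairwise-label logic on toy tensors (plain PyTorch; the cluster ids, shapes, and values below are invented for illustration, with -1 marking padded antecedent slots):

import torch

# Three kept spans with cluster ids 0, 1, 0; shape (1, 3, 1).
top_span_labels = torch.tensor([[0, 1, 0]]).unsqueeze(-1)
# Candidate antecedent cluster ids per span; -1 is padding. Shape (1, 3, 2).
antecedent_labels = torch.tensor([[[-1, -1],
                                   [0, -1],
                                   [1, 0]]])

target_labels = top_span_labels.expand_as(antecedent_labels)
same_cluster_indicator = (target_labels == antecedent_labels).float()
non_dummy_indicator = (target_labels >= 0).float()
pairwise_labels = same_cluster_indicator * non_dummy_indicator
dummy_labels = (1 - pairwise_labels).prod(-1, keepdim=True)

print(torch.cat([dummy_labels, pairwise_labels], -1))
# tensor([[[1., 0., 0.],   span 0: no coreferent antecedent -> dummy
#          [1., 0., 0.],   span 1: its antecedent is in a different cluster -> dummy
#          [0., 0., 1.]]]) span 2: corefers with its second candidate (span 0)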
Example No. 2
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            target_word: torch.IntTensor,
            gold_label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """

        Parameters
        ----------
        tokens:

        target_word:
            (batch_size, 2)
        gold_label:
            (batch_size)
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:

        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        # Shape: (batch_size, sentence_length, embedding_size)
        tokens_embeddings = self._lexical_dropout(
            self._text_field_embedder(tokens))

        # Shape: (batch_size, sentence_length)
        tokens_mask = util.get_text_field_mask(tokens).float()

        # Shape: (batch_size, sentence_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            tokens_embeddings, tokens_mask)

        # Shape: (batch_size, 2 * encoding_dim)
        target_word_embeddings = self._target_word_extractor(
            contextualized_embeddings, target_word)

        # Shape: (batch_size, 1)
        complex_word_logits = self._complex_word_scorer(target_word_embeddings)

        complex_word_predictions = complex_word_logits > 0.5

        output_dict = {
            "logits": complex_word_logits,
            "predictions": complex_word_predictions
        }

        if gold_label is not None:
            output_dict["loss"] = self._loss(complex_word_logits,
                                             gold_label.unsqueeze(-1).float())

            self._metric(complex_word_predictions, gold_label)

        return output_dict
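
A hedged side note on the thresholding above: if ``_complex_word_scorer`` ends in a plain linear layer and ``self._loss`` is ``torch.nn.BCEWithLogitsLoss`` (neither is shown in this snippet, so this is an assumption), then a probability of 0.5 corresponds to a logit of 0, and comparing raw logits against 0.5 is a stricter cut-off than probability 0.5. The toy lines below show the decision boundaries:

import torch

# Invented logit values for illustration only.
logits = torch.tensor([[-1.2], [0.3], [2.0]])
probs = torch.sigmoid(logits)

print(probs > 0.5)    # tensor([[False], [ True], [ True]])
print(logits > 0.0)   # same boundary as probs > 0.5
print(logits > 0.5)   # equivalent to probs > ~0.62, i.e. a stricter cut-off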
Example No. 3
    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            # passage: Dict[str, torch.LongTensor],
            passages_list: Dict[str, torch.LongTensor],
            passages_length: torch.LongTensor = None,
            correct_passage: torch.LongTensor = None,
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        # shape: B x N x T x E
        embedded_passage_list = self._embedder(passages_list)
        # unpack: batch size (B), number of passages (N), max passage length (T), embedding size (E)
        (batch_size, num_passages, max_p,
         embedding_size) = embedded_passage_list.size()

        # shape: B x Tq x E
        embedded_question = self._embedder(question)
        embedded_passage = embedded_passage_list.view(batch_size, -1,
                                                      embedding_size)
        # embedded_passage = self._embedder(passage)

        # batch_size = embedded_question.size(0)
        total_passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question)
        # passage_mask = util.get_text_field_mask(passage)
        passage_list_mask = util.get_text_field_mask(passages_list, 1)
        passage_mask = passage_list_mask.view(batch_size, -1)
        # shape: B x T x 2H
        encoded_question = self._dropout(
            self._question_encoder(embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._passage_encoder(embedded_passage, passage_mask))
        passage_mask = passage_mask.float()
        question_mask = question_mask.float()

        encoding_dim = encoded_question.size(-1)
        #encoded_passage_list = self._dropout(self._passage_encoder(embedded_passage_list, passage_list_mask))
        # shape: B x 2H
        if encoded_passage.is_cuda:
            cuda_device = encoded_passage.get_device()
            gru_hidden = Variable(
                torch.zeros(batch_size, encoding_dim).cuda(cuda_device))
        else:
            gru_hidden = Variable(torch.zeros(batch_size, encoding_dim))

        question_awared_passage = []
        for timestep in range(total_passage_length):
            u_t_P = encoded_passage[:, timestep, :]
            # shape: B x Tq = attention(B x 2H, B x Tq x 2H)
            attn_weights = self._question_attention_for_passage(
                u_t_P, encoded_question, question_mask)
            # shape: B x 2H = weighted_sum(B x Tq x 2H, B x Tq)
            attended_question = util.weighted_sum(encoded_question,
                                                  attn_weights)
            # shape: B x 4H
            passage_question_combined = torch.cat([u_t_P, attended_question],
                                                  dim=-1)
            # shape: B x 4H
            gate = F.sigmoid(self._gate(passage_question_combined))
            gru_input = gate * passage_question_combined
            # shape: B x 2H
            gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
            question_awared_passage.append(gru_hidden)

        # shape: B x T x 2H
        # question aware passage representation v_P
        question_awared_passage = torch.stack(question_awared_passage, dim=1)

        # compute question vector r_Q
        # shape: B x T = attention(B x 2H, B x T x 2H)
        v_r_Q_tiled = self._v_r_Q.unsqueeze(0).expand(batch_size, encoding_dim)
        attn_weights = self._question_attention_for_question(
            v_r_Q_tiled, encoded_question, question_mask)
        # shape: B x 2H
        r_Q = util.weighted_sum(encoded_question, attn_weights)
        # shape: B x T = attention(B x 2H, B x T x 2H)
        span_start_logits = self._passage_attention_for_answer(
            r_Q, question_awared_passage, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_start_log_probs = util.masked_log_softmax(span_start_logits,
                                                       passage_mask)
        # shape: B x 2H
        c_t = util.weighted_sum(question_awared_passage, span_start_probs)
        # shape: B x 2H
        h_1 = self._dropout(self._answer_net(c_t, r_Q))

        span_end_logits = self._passage_attention_for_answer(
            h_1, question_awared_passage, passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_end_log_probs = util.masked_log_softmax(span_end_logits,
                                                     passage_mask)

        #num_passages = passages_length.size(1)
        #cum_passages = torch.cumsum(passages_length, dim=1)
        g = []
        for i in range(num_passages):
            attn_weights = self._passage_attention_for_ranking(
                r_Q, question_awared_passage[:, i * max_p:(i + 1) * max_p, :],
                passage_mask[:, i * max_p:(i + 1) * max_p])
            r_P = util.weighted_sum(
                question_awared_passage[:, i * max_p:(i + 1) * max_p, :],
                attn_weights)
            question_passage_combined = torch.cat([r_Q, r_P], dim=-1)
            gi = self._dropout(
                self._match_layer_2(
                    F.tanh(self._match_layer_1(question_passage_combined))))
            g.append(gi)
        # compute r_P
        # shape: B x T = attention(B x 2H, B x T x 2H)
        #attn_weights = self._passage_attention_for_ranking(r_Q, question_awared_passage, passage_mask)
        # shape: B x 2H
        #r_P = util.weighted_sum(question_awared_passage, attn_weights)
        # shape: B x 4H
        #question_passage_combined = torch.cat([r_Q, r_P], dim=-1)
        # shape: B x 10
        #g = self._dropout(self._match_layer_2(F.tanh(self._match_layer_1(question_passage_combined))))
        #cum_passages = torch.cumsum(passages_length, dim=1)
        #for b in range(batch_size):
        #    for i in range(num_passages):
        #        attn_weights = self._passage_attention_for_ranking(r_Q[b], question_awared_passage
        cumsum = torch.cumsum(passage_mask.long(), dim=1)
        padded_span_start = None
        padded_span_end = None
        if span_start is not None:
            # Remap the gold span indices from positions in the unpadded
            # concatenation of passages to positions in the padded sequence.
            padded_span_start = span_start.clone()
            padded_span_end = span_end.clone()
            for b in range(batch_size):
                padded_span_start[b] = (cumsum[b] == span_start[b] +
                                        1).nonzero()[0][0]
                padded_span_end[b] = (cumsum[b] == span_end[b] +
                                      1).nonzero()[0][0]

        g = torch.cat(g, dim=1)
        passage_log_probs = F.log_softmax(g, dim=-1)

        output_dict = {}
        if span_start is not None:
            AP_loss = F.nll_loss(span_start_log_probs, padded_span_start.squeeze(-1)) +\
                F.nll_loss(span_end_log_probs, padded_span_end.squeeze(-1))
            PR_loss = F.nll_loss(passage_log_probs,
                                 correct_passage.squeeze(-1))
            loss = self._r * AP_loss + self._r * PR_loss
            output_dict['loss'] = loss
        _, max_start = torch.max(span_start_probs, dim=1)
        _, max_end = torch.max(span_end_probs, dim=1)
        # Unpad: map the predicted indices from the padded sequence back to
        # positions in the unpadded concatenation of passages.
        for b in range(batch_size):
            max_start.data[b] = cumsum.data[b, max_start.data[b]] - 1
            max_end.data[b] = cumsum.data[b, max_end.data[b]] - 1
        output_dict['span_start_idx'] = max_start
        output_dict['span_end_idx'] = max_end
        global ITE
        ITE += 1
        if span_start is not None and ITE % 100 == 0:
            print(" gold %i:%i|predicted %i:%i" %
                  (span_start.squeeze(-1)[0], span_end.squeeze(-1)[0],
                   max_start.cpu().data[0], max_end.cpu().data[0]))

        return output_dict
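
A minimal sketch of a single step of the gated-attention GRU loop above (the question-aware passage update); the attention weights are random stand-ins for the learned attention module, and every size below is invented:

import torch
import torch.nn as nn

B, Tq, H2 = 2, 5, 16                       # batch size, question length, 2H
encoded_question = torch.randn(B, Tq, H2)
u_t_P = torch.randn(B, H2)                 # passage encoding at one timestep
gate = nn.Linear(2 * H2, 2 * H2)           # B x 4H -> B x 4H
gru_cell = nn.GRUCell(2 * H2, H2)          # input size 4H, hidden size 2H
gru_hidden = torch.zeros(B, H2)

# Attention over the question conditioned on the current passage timestep
# (random weights stand in for the learned attention).
attn_weights = torch.softmax(torch.randn(B, Tq), dim=-1)
attended_question = torch.bmm(attn_weights.unsqueeze(1),
                              encoded_question).squeeze(1)      # B x 2H
combined = torch.cat([u_t_P, attended_question], dim=-1)        # B x 4H
gru_input = torch.sigmoid(gate(combined)) * combined            # gated input
gru_hidden = gru_cell(gru_input, gru_hidden)                    # B x 2H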
Example No. 4
    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        p1_answer_marker: torch.IntTensor = None,
        p2_answer_marker: torch.IntTensor = None,
        p3_answer_marker: torch.IntTensor = None,
        yesno_list: torch.IntTensor = None,
        followup_list: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens belonging to the
            previous answer in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two dialog turns back in the passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three dialog turns back in the passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (yes / no / not a yes-no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (follow up / maybe follow up / don't follow up).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following entries. Each entry is a nested list:
        the outer list iterates over dialogs in the batch, the inner list over questions within a dialog.

        qid : List[List[str]]
            A list of lists of question ids.
        followup : List[List[int]]
            A list of lists of predicted continuation-marker indices.
            (y: yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of lists of predicted affirmation-marker indices.
            (y: yes, x: not a yes/no question, n: no)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        token_character_ids = question["token_characters"]["token_characters"]
        batch_size, max_qa_count, max_q_len, _ = token_character_ids.size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question,
                                                      num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(
            total_qa_count, max_q_len,
            self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(
            self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1)
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage)

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(
                max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(
                1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(
                batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(
                total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(
                question_num_ind)
            embedded_question = torch.cat(
                [embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = (embedded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1,
                1).view(total_qa_count, passage_length,
                        self._text_field_embedder.get_output_dim()))
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count,
                                                     passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat(
                [repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(
                    total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat(
                    [repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(
                        total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(
                        p3_answer_marker)
                    repeated_embedded_passage = torch.cat(
                        [repeated_embedded_passage, p3_answer_marker_emb],
                        dim=-1)

            repeated_encoded_passage = self._variational_dropout(
                self._phrase_layer(repeated_embedded_passage,
                                   repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(
                self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
                1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(
                total_qa_count, passage_length, self._encoding_dim)

        encoded_question = self._variational_dropout(
            self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(
            repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(total_qa_count, passage_length, self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                repeated_encoded_passage,
                passage_question_vectors,
                repeated_encoded_passage * passage_question_vectors,
                repeated_encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1,
        )

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(
            self._residual_encoder(final_merged_passage,
                                   repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer,
                                                     residual_layer)

        mask = repeated_passage_mask.reshape(
            total_qa_count, passage_length, 1) * repeated_passage_mask.reshape(
                total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              dtype=torch.bool,
                              device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask & ~self_mask

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs,
                                           residual_layer)
        self_attention_vecs = torch.cat([
            self_attention_vecs, residual_layer,
            residual_layer * self_attention_vecs
        ],
                                        dim=-1)
        residual_layer = F.relu(
            self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage,
                                             repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(
            torch.cat([final_merged_passage, start_rep], dim=-1),
            repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(
            -1)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(
            span_start_logits,
            span_end_logits,
            span_yesno_logits,
            span_followup_logits,
            self._max_span_length,
        )

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits,
                                        repeated_passage_mask),
                span_start.view(-1),
                ignore_index=-1,
            )
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits,
                                        repeated_passage_mask),
                span_end.view(-1),
                ignore_index=-1,
            )
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(
                best_span[:, 0:2],
                torch.stack([span_start, span_end],
                            -1).view(total_qa_count, 2),
                mask=qa_mask.unsqueeze(1).expand(-1, 2),
            )
            # Select the yes/no and follow-up logits at the gold answer-end positions to compute their losses.
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1),
                             followup_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)
            self._span_followup_accuracy(_followup,
                                         followup_list.view(-1),
                                         mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and prepare the output dictionary.
        output_dict["best_span_str"] = []
        output_dict["qid"] = []
        output_dict["followup"] = []
        output_dict["yesno"] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]["original_passage"]
            offsets = metadata[i]["token_offsets"]
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 against each leave-one-out set of N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad.metric_max_over_ground_truths(
                                    squad.f1_score, best_span_string, refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad.metric_max_over_ground_truths(
                            squad.f1_score, best_span_string, answer_texts)
                self._official_f1(100 * f1_score)
            output_dict["qid"].append(per_dialog_query_id_list)
            output_dict["best_span_str"].append(per_dialog_best_span_list)
            output_dict["yesno"].append(per_dialog_yesno_list)
            output_dict["followup"].append(per_dialog_followup_list)
        return output_dict
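
A small sketch of the self-attention masking used above: a pairwise mask built from the passage mask, with the diagonal removed so a token never attends to itself (all values below are invented):

import torch

B, T = 1, 4
passage_mask = torch.tensor([[True, True, True, False]])    # last position is padding
scores = torch.randn(B, T, T)                                # raw self-attention scores

pairwise_mask = passage_mask.reshape(B, T, 1) & passage_mask.reshape(B, 1, T)
self_mask = torch.eye(T, dtype=torch.bool).reshape(1, T, T)
pairwise_mask = pairwise_mask & ~self_mask                   # drop the diagonal

scores = scores.masked_fill(~pairwise_mask, -1e7)
attention_probs = torch.softmax(scores, dim=-1)              # rows attend only to valid, non-self tokens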
Example No. 5
    def forward(
        self,  # type: ignore
        premise: Dict[str, torch.LongTensor],
        hypothesis: Dict[str, torch.LongTensor],
        # premise_verbs: torch.LongTensor,
        # hypothesis_verbs: torch.LongTensor,
        label: torch.IntTensor = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._sentence_encoder:
            embedded_premise = self._sentence_encoder(embedded_premise,
                                                      premise_mask)
            embedded_hypothesis = self._sentence_encoder(
                embedded_hypothesis, hypothesis_mask)

        # Compute SRL encoding. Shape: (batch_size, premise|hypothesis-length, SRL encoding size)
        # srl_premise = self._srl_model.forward(premise, premise_verbs)["encoded_text"]
        # srl_hypothesis = self._srl_model.forward(hypothesis, hypothesis_verbs)["encoded_text"]

        # Do not backpropagate through the SRL representation.
        # srl_embedded_premise = torch.cat([embedded_premise, srl_premise.detach()], dim=-1)
        # srl_embedded_hypothesis = torch.cat([embedded_hypothesis, srl_hypothesis.detach()], dim=-1)

        # masked_premise = embedded_premise - float('inf') * (1 - premise_mask.unsqueeze(-1))
        # masked_hypothesis = embedded_hypothesis - float('inf') * (1 - hypothesis_mask.unsqueeze(-1))
        masked_premise = embedded_premise * premise_mask.unsqueeze(-1)
        masked_hypothesis = embedded_hypothesis * hypothesis_mask.unsqueeze(-1)

        # Sum pooling along the time dimension (a masked max-pooling variant
        # is left commented out below).
        # compared_premise, _ = masked_premise.max(dim=1)
        # compared_hypothesis, _ = masked_hypothesis.max(dim=1)
        compared_premise = masked_premise.sum(dim=1)
        compared_hypothesis = masked_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([
            compared_premise, compared_hypothesis,
            torch.abs(compared_premise - compared_hypothesis),
            compared_premise * compared_hypothesis
        ],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
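
A toy sketch of the masked sum-pooling and the ``[p; h; |p - h|; p * h]`` aggregation used above (all sizes and values are invented):

import torch

B, T, D = 2, 4, 6
embedded_premise = torch.randn(B, T, D)
embedded_hypothesis = torch.randn(B, T, D)
premise_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 0, 0]]).float()
hypothesis_mask = torch.tensor([[1, 1, 1, 1],
                                [1, 1, 1, 0]]).float()

compared_premise = (embedded_premise * premise_mask.unsqueeze(-1)).sum(dim=1)
compared_hypothesis = (embedded_hypothesis * hypothesis_mask.unsqueeze(-1)).sum(dim=1)

aggregate_input = torch.cat([compared_premise, compared_hypothesis,
                             torch.abs(compared_premise - compared_hypothesis),
                             compared_premise * compared_hypothesis], dim=-1)   # (B, 4D)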
Example No. 6
    def forward(
        self,  # type: ignore
        tokens: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata to persist

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()

        encoder_output = self._encoder(embedded_text, mask)

        encoded_repr = []
        for aggregation in self._aggregations:
            if aggregation == "meanpool":
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoder_output * broadcast_mask
                encoded_text = masked_mean(context_vectors,
                                           broadcast_mask,
                                           dim=1,
                                           keepdim=False)
            elif aggregation == 'maxpool':
                broadcast_mask = mask.unsqueeze(-1).float()
                context_vectors = encoder_output * broadcast_mask
                encoded_text = masked_max(context_vectors,
                                          broadcast_mask,
                                          dim=1)
            elif aggregation == 'final_state':
                is_bi = self._encoder.is_bidirectional()
                encoded_text = get_final_encoder_states(
                    encoder_output, mask, is_bi)
            else:
                raise ValueError(f"Unsupported aggregation: {aggregation}")
            encoded_repr.append(encoded_text)

        encoded_repr = torch.cat(encoded_repr, 1)

        if self.dropout:
            encoded_repr = self.dropout(encoded_repr)

        output_hidden = self._output_feedforward(encoded_repr)
        label_logits = self._classification_layer(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
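
A plain-PyTorch sketch of the masked mean-pooling that the "meanpool" branch above delegates to ``masked_mean`` (sizes are invented, and the epsilon guard is an assumption rather than the library's exact behaviour):

import torch

B, T, D = 2, 5, 3
encoder_output = torch.randn(B, T, D)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]]).float()

broadcast_mask = mask.unsqueeze(-1)                        # (B, T, 1)
summed = (encoder_output * broadcast_mask).sum(dim=1)      # (B, D)
encoded_text = summed / broadcast_mask.sum(dim=1).clamp(min=1e-13)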
Example No. 7
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            span_starts: torch.IntTensor,
            span_ends: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        span_starts : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 1), representing the start indices of
            candidate spans for mentions. Comes from a ``ListField[IndexField]`` of indices into
            the text of the document.
        span_ends : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 1), representing the end indices of
            candidate spans for mentions. Comes from a ``ListField[IndexField]`` of indices into
            the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = span_starts.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans, 1)
        span_mask = (span_starts >= 0).float()
        # IndexFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        span_starts = F.relu(span_starts.float()).long()
        span_ends = F.relu(span_ends.float()).long()
        # Shape: (batch_size, num_spans, 2)
        span_indices = torch.cat([span_starts, span_ends], -1)

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, span_indices)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, span_indices)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask.squeeze(-1), num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        top_span_starts = util.batched_index_select(span_starts,
                                                    top_span_indices,
                                                    flat_top_span_indices)
        top_span_ends = util.batched_index_select(span_ends, top_span_indices,
                                                  flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)
        # Compute final predictions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = torch.cat([top_span_starts, top_span_ends], -1)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": predicted_antecedents
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.last_dim_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = (coreference_log_probs +
                                            gold_antecedent_labels.log())
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood
        return output_dict
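
A small sketch of how the valid antecedent indices, offsets, and log mask described in the comments above can be generated; it mirrors the behaviour those comments describe and is not necessarily the exact body of ``_generate_valid_antecedents``:

import torch

num_spans_to_keep, max_antecedents = 5, 3

# Span i may only take spans i-1 .. i-max_antecedents as antecedents.
target_indices = torch.arange(num_spans_to_keep).unsqueeze(1)                   # (5, 1)
valid_antecedent_offsets = torch.arange(1, max_antecedents + 1).unsqueeze(0)    # (1, 3)
raw_indices = target_indices - valid_antecedent_offsets                         # (5, 3)

# Negative entries are invalid: their log mask is -inf and their index is clamped to 0.
valid_antecedent_log_mask = (raw_indices >= 0).float().unsqueeze(0).log()       # (1, 5, 3)
valid_antecedent_indices = raw_indices.clamp(min=0)                             # (5, 3)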
Example No. 8
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # For each span we marginalise over all antecedent predictions that are
            # consistent with the gold clustering: span i predicts a single antecedent j,
            # but there might be several prior mentions k in the same coreference cluster
            # that would all be valid antecedents. The loss for span i is therefore the
            # negative log of the total probability assigned to its valid antecedents
            # (or to the dummy antecedent, if it has none). This is a valid objective for
            # clustering, as we don't mind which antecedent is predicted, so long as it is
            # in the same gold coreference cluster.
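            # In symbols (a sketch of the objective; GOLD(i) denotes the antecedents of
            # span i that are in its gold cluster, or just the dummy antecedent if none):
            #     loss = - sum_i log( sum_{j in GOLD(i)} P(j | i) )
            # where P(. | i) is the masked softmax over coreference_scores for span i.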
            coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
Ejemplo n.º 9
0
    def forward(  # type: ignore
        self,
        premise: TextFieldTensors,
        hypothesis: TextFieldTensors,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        premise : TextFieldTensors
            From a `TextField`
        hypothesis : TextFieldTensors
            From a `TextField`
        label : torch.IntTensor, optional (default = None)
            From a `LabelField`
        metadata : `List[Dict[str, Any]]`, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        # Returns

        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape `(batch_size, num_labels)` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape `(batch_size, num_labels)` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise)
        hypothesis_mask = get_text_field_mask(hypothesis)

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis,
                                           hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise,
                                                   encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
            [
                encoded_premise,
                attended_hypothesis,
                encoded_premise - attended_hypothesis,
                encoded_premise * attended_hypothesis,
            ],
            dim=-1,
        )
        hypothesis_enhanced = torch.cat(
            [
                encoded_hypothesis,
                attended_premise,
                encoded_hypothesis - attended_premise,
                encoded_hypothesis * attended_premise,
            ],
            dim=-1,
        )
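        # As in the ESIM formulation, concatenating [x; x_attended; x - x_attended;
        # x * x_attended] enhances the local inference information: the difference and
        # element-wise product highlight where the two sentences agree or disagree.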

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(
            premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(
            hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(
                projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(
                projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise,
                                       premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis,
                                       hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max, _ = replace_masked_values(v_ai, premise_mask.unsqueeze(-1),
                                           -1e7).max(dim=1)
        v_b_max, _ = replace_masked_values(v_bi, hypothesis_mask.unsqueeze(-1),
                                           -1e7).max(dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1),
                            dim=1) / torch.sum(premise_mask, 1, keepdim=True)
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1),
                            dim=1) / torch.sum(
                                hypothesis_mask, 1, keepdim=True)

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            #passages_length: torch.LongTensor = None,
            #correct_passage: torch.LongTensor = None,
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata=None) -> Dict[str, torch.Tensor]:

        # shape: B x Tq x E
        embedded_question = self._embedder(question)
        embedded_passage = self._embedder(passage)

        batch_size = embedded_question.size(0)
        total_passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        # shape: B x T x 2H
        encoded_question = self._dropout(
            self._question_encoder(embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._passage_encoder(embedded_passage, passage_mask))
        passage_mask = passage_mask.float()
        question_mask = question_mask.float()

        encoding_dim = encoded_question.size(-1)

        # shape: B x 2H
        # Initial GRU hidden state, created on the same device as the encoded passage.
        gru_hidden = torch.zeros(batch_size, encoding_dim,
                                 device=encoded_passage.device)

        question_awared_passage = []
        for timestep in range(total_passage_length):
            # shape: B x Tq = attention(B x 2H, B x Tq x 2H)
            attn_weights = self._question_attention_for_passage(
                encoded_passage[:, timestep, :], encoded_question,
                question_mask)
            # shape: B x 2H = weighted_sum(B x Tq x 2H, B x Tq)
            attended_question = util.weighted_sum(encoded_question,
                                                  attn_weights)
            # shape: B x 4H
            passage_question_combined = torch.cat(
                [encoded_passage[:, timestep, :], attended_question], dim=-1)
            # shape: B x 4H
            gate = torch.sigmoid(self._gate(passage_question_combined))
            gru_input = gate * passage_question_combined
            # shape: B x 2H
            gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
            question_awared_passage.append(gru_hidden)

        # shape: B x T x 2H
        # question aware passage representation v_P
        question_awared_passage = torch.stack(question_awared_passage, dim=1)
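        # The loop above is (a sketch of) a gated attention-based recurrent network:
        # at each passage position t, an attention vector over the question is computed,
        # the concatenation [u_t^P; c_t] is scaled by a learned sigmoid gate, and the
        # result is fed to a GRU cell, yielding the question-aware representation v^P.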

        self_attended_passage = []
        for timestep in range(total_passage_length):
            attn_weights = self._passage_self_attention(
                question_awared_passage[:, timestep, :],
                question_awared_passage, passage_mask)
            attended_passage = util.weighted_sum(question_awared_passage,
                                                 attn_weights)
            input_combined = torch.cat(
                [question_awared_passage[:, timestep, :], attended_passage],
                dim=-1)
            gate = torch.sigmoid(self._self_gate(input_combined))
            gru_input = gate * input_combined
            gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
            self_attended_passage.append(gru_hidden)

        self_attended_passage = torch.stack(self_attended_passage, dim=1)

        # compute question vector r_Q
        # shape: B x T = attention(B x 2H, B x T x 2H)
        v_r_Q_tiled = self._v_r_Q.unsqueeze(0).expand(batch_size, encoding_dim)
        attn_weights = self._question_attention_for_question(
            v_r_Q_tiled, encoded_question, question_mask)
        # shape: B x 2H
        r_Q = util.weighted_sum(encoded_question, attn_weights)
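        # Assuming the standard R-Net style attention pooling, this computes
        # r_Q = sum_t softmax_t( score(v_r_Q, u_t^Q) ) * u_t^Q, i.e. an attention-weighted
        # summary of the question using the learned vector v_r_Q as the query.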
        # shape: B x T = attention(B x 2H, B x T x 2H)
        span_start_logits = self._passage_attention_for_answer(
            r_Q, self_attended_passage, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_start_log_probs = util.masked_log_softmax(span_start_logits,
                                                       passage_mask)
        # shape: B x 2H
        c_t = util.weighted_sum(self_attended_passage, span_start_probs)
        # shape: B x 2H
        h_1 = self._dropout(self._answer_net(c_t, r_Q))

        span_end_logits = self._passage_attention_for_answer(
            h_1, self_attended_passage, passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_end_log_probs = util.masked_log_softmax(span_end_logits,
                                                     passage_mask)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        #num_passages = passages_length.size(1)
        #acc = Variable(torch.zeros(batch_size, num_passages + 1)).cuda(cuda_device).long()

        #acc[:, 1:num_passages+1] = torch.cumsum(passages_length, dim=1)

        #g_batch = []
        #for b in range(batch_size):
        #    g = []
        #    for i in range(num_passages):
        #        if acc[b, i+1].data[0] > acc[b, i].data[0]:
        #            attn_weights = self._passage_attention_for_ranking(r_Q[b:b+1], question_awared_passage[b:b+1, acc[b, i].data[0]: acc[b, i+1].data[0], :], passage_mask[b:b+1, acc[b, i].data[0]: acc[b, i+1].data[0]])
        #            r_P = util.weighted_sum(question_awared_passage[b:b+1, acc[b, i].data[0]:acc[b, i+1].data[0], :], attn_weights)
        #            question_passage_combined = torch.cat([r_Q[b:b+1], r_P], dim=-1)
        #            gi = self._dropout(self._match_layer_2(F.tanh(self._dropout(self._match_layer_1(question_passage_combined)))))
        #            g.append(gi)
        #        else:
        #            g.append(Variable(torch.zeros(1, 1)).cuda(cuda_device))
        #    g = torch.cat(g, dim=1)
        #    g_batch.append(g)

        #t2 = time.time()
        #g = torch.cat(g_batch, dim=0)
        output_dict = {}
        if span_start is not None:
            AP_loss = F.nll_loss(span_start_log_probs, span_start.squeeze(-1)) +\
                F.nll_loss(span_end_log_probs, span_end.squeeze(-1))
            #PR_loss = F.nll_loss(passage_log_probs, correct_passage.squeeze(-1))
            #loss = self._r * AP_loss + self._r * PR_loss
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict['loss'] = AP_loss

        _, max_start = torch.max(span_start_probs, dim=1)
        _, max_end = torch.max(span_end_probs, dim=1)
        #t3 = time.time()
        output_dict['span_start_idx'] = max_start
        output_dict['span_end_idx'] = max_end
        #t4 = time.time()
        #global ITE
        #ITE += 1
        #if (ITE % 100 == 0):
        #    print(" gold %i:%i|predicted %i:%i" %(span_start.squeeze(-1)[0], span_end.squeeze(-1)[0], max_start.data[0], max_end.data[0]))
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens

        #t5 = time.time()
        #print("Total: %.5f" % (t5-t0))
        #print("Batch processing 1: %.5f" % (t2-t1))
        #print("Batch processing 2: %.5f" % (t4-t3))
        return output_dict
    def forward(
        self,  # type: ignore
        sentences: torch.LongTensor,
        labels: torch.IntTensor = None,
        confidences: torch.Tensor = None,
        additional_features: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentences : ``torch.LongTensor``
            The tokenized sentences of one batch, of shape
            (batch_size, num_sentences, sentence_length).
        labels : ``torch.IntTensor``, optional (default = None)
            One label per sentence; -1 (or 0.0 when labels are scores) marks padding.
        confidences : ``torch.Tensor``, optional (default = None)
            Per-sentence confidence weights used to scale the loss.
        additional_features : ``torch.Tensor``, optional (default = None)
            Extra per-sentence features concatenated to the sentence embeddings
            before the final classifier.

        Returns
        -------
        An output dictionary consisting of:
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # ===========================================================================================================
        # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
        # Input: sentences
        # Output: embedded_sentences
        sep_id = self.vocab.get_token_index('[SEP]', 'bert')

        # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
        embedded_sentences = self.text_field_embedder(sentences)
        mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
        batch_size, num_sentences, _, _ = embedded_sentences.size()

        if self.use_sep:
            # The following code collects vectors of the SEP tokens from all the examples in the batch,
            # and arranges them in one list. It does the same for the labels and confidences.
            # TODO: replace 103 with '[SEP]'
            sentences_mask = sentences[
                'bert'] == sep_id  # mask for all the SEP tokens in the batch
            embedded_sentences = embedded_sentences[
                sentences_mask]  # given batch_size x num_sentences_per_example x sent_len x vector_len
            # returns num_sentences_per_batch x vector_len
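            # A hedged illustration of the shapes involved (hypothetical numbers): with
            # batch_size = 2 and 3 + 2 surviving [SEP] tokens in the two examples, the
            # boolean indexing collapses (2, num_sentences, sent_len, dim) to (5, dim),
            # i.e. one vector per sentence across the whole batch.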
            assert embedded_sentences.dim() == 2
            num_sentences = embedded_sentences.shape[0]
            # for the rest of the code in this model to work, think of the data we have as one example
            # with so many sentences and a batch of size 1
            batch_size = 1
            embedded_sentences = embedded_sentences.unsqueeze(dim=0)
            embedded_sentences = self.dropout(embedded_sentences)

            if labels is not None:
                if self.labels_are_scores:
                    labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
                else:
                    labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

                labels = labels[
                    labels_mask]  # given batch_size x num_sentences_per_example return num_sentences_per_batch
                assert labels.dim() == 1
                if confidences is not None:
                    confidences = confidences[labels_mask]
                    assert confidences.dim() == 1
                if additional_features is not None:
                    additional_features = additional_features[labels_mask]
                    assert additional_features.dim() == 2

                num_labels = labels.shape[0]
                if num_labels != num_sentences:  # bert truncates long sentences, so some of the SEP tokens might be gone
                    assert num_labels > num_sentences  # in that case there are more labels than surviving sentences
                    logger.warning(
                        f'Found {num_labels} labels but {num_sentences} sentences'
                    )
                    labels = labels[:
                                    num_sentences]  # Ignore some labels. This is ok for training but bad for testing.
                    # We are ignoring this problem for now.
                    # TODO: fix, at least for testing

                # do the same for `confidences`
                if confidences is not None:
                    num_confidences = confidences.shape[0]
                    if num_confidences != num_sentences:
                        assert num_confidences > num_sentences
                        confidences = confidences[:num_sentences]

                # and for `additional_features`
                if additional_features is not None:
                    num_additional_features = additional_features.shape[0]
                    if num_additional_features != num_sentences:
                        assert num_additional_features > num_sentences
                        additional_features = additional_features[:
                                                                  num_sentences]

                # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
                labels = labels.unsqueeze(dim=0)
                if confidences is not None:
                    confidences = confidences.unsqueeze(dim=0)
                if additional_features is not None:
                    additional_features = additional_features.unsqueeze(dim=0)
        else:
            # ['CLS'] token
            embedded_sentences = embedded_sentences[:, :, 0, :]
            embedded_sentences = self.dropout(embedded_sentences)
            batch_size, num_sentences, _ = embedded_sentences.size()
            sent_mask = (mask.sum(dim=2) != 0)
            embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

        if additional_features is not None:
            embedded_sentences = torch.cat(
                (embedded_sentences, additional_features), dim=-1)

        label_logits = self.time_distributed_aggregate_feedforward(
            embedded_sentences)
        # label_logits: batch_size, num_sentences, num_labels

        if self.labels_are_scores:
            label_probs = label_logits
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {"action_probs": label_probs}

        # =====================================================================

        if self.with_crf:
            # Layer 4 = CRF layer across labels of sentences in an abstract
            mask_sentences = (labels != -1)
            best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
            #
            # # Just get the tags and ignore the score.
            predicted_labels = [x for x, y in best_paths]
            # print(f"len(predicted_labels):{len(predicted_labels)}, (predicted_labels):{predicted_labels}")

            label_loss = 0.0
        if labels is not None:
            # Compute cross entropy loss
            flattened_logits = label_logits.view((batch_size * num_sentences),
                                                 self.num_labels)
            flattened_gold = labels.contiguous().view(-1)

            if not self.with_crf:
                label_loss = self.loss(flattened_logits.squeeze(),
                                       flattened_gold)
                if confidences is not None:
                    label_loss = label_loss * confidences.type_as(
                        label_loss).view(-1)
                label_loss = label_loss.mean()
                flattened_probs = torch.softmax(flattened_logits, dim=-1)
            else:
                clamped_labels = torch.clamp(labels, min=0)
                log_likelihood = self.crf(label_logits, clamped_labels,
                                          mask_sentences)
                label_loss = -log_likelihood
                # compute categorical accuracy
                crf_label_probs = label_logits * 0.
                for i, instance_labels in enumerate(predicted_labels):
                    for j, label_id in enumerate(instance_labels):
                        crf_label_probs[i, j, label_id] = 1
                flattened_probs = crf_label_probs.view(
                    (batch_size * num_sentences), self.num_labels)

            if not self.labels_are_scores:
                evaluation_mask = (flattened_gold != -1)
                self.label_accuracy(flattened_probs.float().contiguous(),
                                    flattened_gold.squeeze(-1),
                                    mask=evaluation_mask)

                # compute F1 per label
                for label_index in range(self.num_labels):
                    label_name = self.vocab.get_token_from_index(
                        namespace='labels', index=label_index)
                    metric = self.label_f1_metrics[label_name]
                    metric(flattened_probs,
                           flattened_gold,
                           mask=evaluation_mask)

        if labels is not None:
            output_dict["loss"] = label_loss
        output_dict['action_logits'] = label_logits
        return output_dict
Ejemplo n.º 12
0
    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors,
            bias_tokens: TextFieldTensors = None,
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        if self._feedforward is not None:
            embedded_text = self._feedforward(embedded_text)

        sentence_pair_logits = self._classification_layer(embedded_text)

        # If we're training, also compute loss and accuracy for the bias-only model
        if not self.evaluation_mode and bias_tokens is not None:
            # Make predictions with hypothesis only
            embedded_text = self._text_field_embedder(bias_tokens)
            mask = get_text_field_mask(bias_tokens)

            if self._seq2seq_encoder:
                embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)
            embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

            if self._dropout:
                embedded_text = self._dropout(embedded_text)

            if self._feedforward_hyp_only is not None:
                embedded_text = self._feedforward_hyp_only(embedded_text)
            hyp_only_logits = self._classification_layer_hyp_only(
                embedded_text)

            log_probs_pair = torch.log_softmax(sentence_pair_logits, dim=1)
            log_probs_hyp = torch.log_softmax(hyp_only_logits, dim=1)

            # Combine with product of experts (normalized log space sum)
            # Do not require gradients from hyp-only classifier
            combined = log_probs_pair + log_probs_hyp.detach()
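            # i.e. p(y) is proportional to p_pair(y) * p_hyp(y); summing the
            # log-probabilities implements the product of experts, and detaching
            # log_probs_hyp keeps the combined loss from updating the hypothesis-only head.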

            # NLL loss over combined labels
            loss = self._nll_loss(combined, label.long().view(-1))
            hyp_loss = self._nll_loss(log_probs_hyp, label.long().view(-1))
            self._accuracy(combined, label)
            self._hyp_only_accuracy(hyp_only_logits, label)

            output_dict = {"loss": loss + self._beta * hyp_loss}

            return output_dict

        else:
            output_dict = {
                "logits": sentence_pair_logits,
                "probs": torch.softmax(sentence_pair_logits, dim=1)
            }
            # Only compute the loss and accuracy when gold labels are available,
            # since `label` is optional at prediction time.
            if label is not None:
                loss = self._cross_ent_loss(sentence_pair_logits, label.long().view(-1))
                self._accuracy(sentence_pair_logits, label)
                output_dict["loss"] = loss
            return output_dict
Ejemplo n.º 13
0
    def forward(
            self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, max_qa_count, max_q_len, _ = question[
            'token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(yesno_list, 0).view(total_qa_count)

        # GloVe and simple cnn char embedding, embedding dim = 100 + 100 = 200
        word_emb_ques = self.tokens_embedder(
            question,
            num_wrapping_dims=1).reshape(total_qa_count, max_q_len,
                                         self.tokens_embedder.get_output_dim())
        word_emb_pass = self.tokens_embedder(passage)

        # Elmo embedding, embedding dim = 1024
        elmo_ques = self.elmo_embedder(question, num_wrapping_dims=1).reshape(
            total_qa_count, max_q_len, self.elmo_embedder.get_output_dim())
        elmo_pass = self.elmo_embedder(passage)

        # Passage features embedding, embedding dim = 20 + 20 = 40
        pass_feat = self.features_embedder(passage)

        # GloVe + cnn + Elmo
        embedded_question = self._variational_dropout(
            torch.cat([word_emb_ques, elmo_ques], dim=2))
        embedded_passage = self._variational_dropout(
            torch.cat([word_emb_pass, elmo_pass], dim=2))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question,
                                                 num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(
            1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            total_qa_count, passage_length)

        # Concatenate Elmo after encoded passage
        encode_passage = self._phrase_layer(embedded_passage, passage_mask)
        projected_passage = self.relu(
            self.projected_layer(torch.cat([encode_passage, elmo_pass],
                                           dim=2)))

        # Concatenate Elmo after encoded question
        encode_question = self._phrase_layer(embedded_question, question_mask)
        projected_question = self.relu(
            self.projected_layer(torch.cat([encode_question, elmo_ques],
                                           dim=2)))

        encoded_passage = self._variational_dropout(projected_passage)
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(
            1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(
            total_qa_count, passage_length, self._encoding_dim)
        repeated_pass_feat = (pass_feat.unsqueeze(1).repeat(
            1, max_qa_count, 1, 1)).view(total_qa_count, passage_length, 40)
        encoded_question = self._variational_dropout(projected_question)

        # total_qa_count * max_q_len * passage_length
        # cnt * m * n
        s = torch.bmm(encoded_question,
                      repeated_encoded_passage.transpose(2, 1))
        alpha = util.masked_softmax(s,
                                    question_mask.unsqueeze(2).expand(
                                        s.size()),
                                    dim=1)
        # cnt * n * h
        aligned_p = torch.bmm(alpha.transpose(2, 1), encoded_question)

        # cnt * m * n
        beta = util.masked_softmax(s,
                                   repeated_passage_mask.unsqueeze(1).expand(
                                       s.size()),
                                   dim=2)
        # cnt * m * h
        aligned_q = torch.bmm(beta, repeated_encoded_passage)

        fused_p = self.fuse_p(repeated_encoded_passage, aligned_p)
        fused_q = self.fuse_q(encoded_question, aligned_q)

        # add manual features here
        q_aware_p = self._variational_dropout(
            self.projected_lstm(
                torch.cat([fused_p, repeated_pass_feat], dim=2),
                repeated_passage_mask))

        # cnt * n * n
        # self_p = torch.bmm(q_aware_p, q_aware_p.transpose(2, 1))
        # self_p = self.bilinear_self_align(q_aware_p)
        self_p = self._self_attention(q_aware_p, q_aware_p)
        mask = repeated_passage_mask.reshape(
            total_qa_count, passage_length, 1) * repeated_passage_mask.reshape(
                total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length,
                              passage_length,
                              device=self_p.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)
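        # The identity matrix removes the diagonal of the pairwise mask, so a passage
        # position is never allowed to attend to itself when the self-alignment weights
        # are normalised below.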

        lamb = util.masked_softmax(self_p, mask, dim=2)
        # lamb = util.masked_softmax(self_p, repeated_passage_mask, dim=2)
        # cnt * n * h
        self_aligned_p = torch.bmm(lamb, q_aware_p)

        # cnt * n * h
        fused_self_p = self.fuse_s(q_aware_p, self_aligned_p)
        contextual_p = self._variational_dropout(
            self.contextual_layer_p(fused_self_p, repeated_passage_mask))
        # contextual_p = self.contextual_layer_p(fused_self_p, repeated_passage_mask)

        contextual_q = self._variational_dropout(
            self.contextual_layer_q(fused_q, question_mask))
        # contextual_q = self.contextual_layer_q(fused_q, question_mask)
        # cnt * m
        gamma = util.masked_softmax(
            self.linear_self_align(contextual_q).squeeze(2),
            question_mask,
            dim=1)
        # cnt * h
        weighted_q = torch.bmm(gamma.unsqueeze(1), contextual_q).squeeze(1)

        span_start_logits = self.bilinear_layer_s(weighted_q, contextual_p)
        span_end_logits = self.bilinear_layer_e(weighted_q, contextual_p)

        # cnt * n * 1  cnt * 1 * h
        span_yesno_logits = self.yesno_predictor(
            torch.bmm(span_end_logits.unsqueeze(2), weighted_q.unsqueeze(1)))
        # span_yesno_logits = self.yesno_predictor(contextual_p)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits,
                                                       span_end_logits,
                                                       span_yesno_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss for training

        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits,
                                                    repeated_passage_mask),
                            span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits,
                                      span_start.view(-1),
                                      mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask),
                             span_end.view(-1),
                             ignore_index=-1)
            self._span_end_accuracy(span_end_logits,
                                    span_end.view(-1),
                                    mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end],
                                            -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
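            # span_yesno_logits has shape (total_qa_count, passage_length, 3), so after
            # view(-1) the flat index of logit k for question i at passage position p is
            # i * passage_length * 3 + p * 3 + k; the loops below build these flat indices
            # for the gold and predicted span-end positions (three yes/no/followup logits each).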
            gold_span_end_loc = []
            span_end = span_end.view(
                total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(
                    max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(
                    max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(torch.nn.functional.log_softmax(_yesno, dim=-1),
                             yesno_list.view(-1),
                             ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(
                0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno,
                                      yesno_list.view(-1),
                                      mask=qa_mask)

            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"],
                        metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count +
                                                     per_dialog_query_index])
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                yesno_pred = predicted_span[2]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_query_id_list.append(iid)
                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(
                                squad_eval.metric_max_over_ground_truths(
                                    squad_eval.f1_score, best_span_string,
                                    refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(
                            squad_eval.f1_score, best_span_string,
                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
        return output_dict
Ejemplo n.º 14
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.
        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
Ejemplo n.º 15
0
    def forward(  # type: ignore
            self,
            query: TextFieldTensors,  # batch * words
            documents: TextFieldTensors,  # batch * num_documents * words
            labels: torch.IntTensor = None,  # batch * num_documents,
            **kwargs) -> Dict[str, torch.Tensor]:
        embedded_text = self._text_field_embedder(query)
        mask = get_text_field_mask(query).long()

        embedded_documents = self._text_field_embedder(documents,
                                                       num_wrapping_dims=1)
        documents_mask = get_text_field_mask(documents).long()

        if self._dropout:
            embedded_text = self._dropout(embedded_text)
            embedded_documents = self._dropout(embedded_documents)
        """
        This isn't exactly a 'hack', but it's definitely not the most efficient way to do it.
        Our matcher expects a single (query, document) pair, but we have (query, [d_0, ..., d_n]).
        To get around this, we expand the query embeddings to create these pairs, and then
        flatten both into the 3D tensor [batch*num_documents, words, dim] expected by the matcher. 
        The expansion does this:

        [
            (q_0, [d_{0,0}, ..., d_{0,n}]), 
            (q_1, [d_{1,0}, ..., d_{1,n}])
        ]
        =>
        [
            [ (q_0, d_{0,0}), ..., (q_0, d_{0,n}) ],
            [ (q_1, d_{1,0}), ..., (q_1, d_{1,n}) ]
        ]

        Which we then flatten along the batch dimension. It would likely be more efficient
        to rewrite the matrix multiplications in the relevance matchers, but this is a more general solution.
        """

        embedded_text = embedded_text.unsqueeze(1).expand(
            -1, embedded_documents.size(1), -1,
            -1)  # [batch, num_documents, words, dim]
        mask = mask.unsqueeze(1).expand(-1, embedded_documents.size(1), -1)

        scores = self._relevance_matcher(embedded_text, embedded_documents,
                                         mask, documents_mask).squeeze(-1)
        probs = torch.sigmoid(scores)

        output_dict = {"logits": scores, "probs": probs}
        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            query)
        if labels is not None:
            label_mask = (labels != -1)

            self._mrr(probs, labels, label_mask)
            self._ndcg(probs, labels, label_mask)

            probs = probs.view(-1)
            labels = labels.view(-1)
            label_mask = label_mask.view(-1)

            self._auc(probs, labels.ge(0.5).long(), label_mask)

            loss = self._loss(probs, labels)
            output_dict["loss"] = loss.masked_fill(~label_mask,
                                                   0).sum() / label_mask.sum()

        output_dict.update(kwargs)
        return output_dict
Ejemplo n.º 16
0
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
                [encoded_premise, attended_hypothesis,
                 encoded_premise - attended_hypothesis,
                 encoded_premise * attended_hypothesis],
                dim=-1
        )
        hypothesis_enhanced = torch.cat(
                [encoded_hypothesis, attended_premise,
                 encoded_hypothesis - attended_premise,
                 encoded_hypothesis * attended_premise],
                dim=-1
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max, _ = replace_masked_values(
                v_ai, premise_mask.unsqueeze(-1), -1e7
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
                v_bi, hypothesis_mask.unsqueeze(-1), -1e7
        ).max(dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
                premise_mask, 1, keepdim=True
        )
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
                hypothesis_mask, 1, keepdim=True
        )

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
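
The pooling step above (masked max and average over ``v_ai``/``v_bi``) can be reproduced with plain tensor ops; a hedged sketch on toy inputs, with shapes assumed purely for illustration:

import torch

# Assumed toy shapes: batch=2, premise_length=5, model_dim=4.
v_ai = torch.randn(2, 5, 4)
premise_mask = torch.tensor([[1., 1., 1., 0., 0.],
                             [1., 1., 1., 1., 1.]])

# Max pooling: fill padded positions with a very negative value before taking the max.
v_a_max, _ = v_ai.masked_fill(premise_mask.unsqueeze(-1) == 0, -1e7).max(dim=1)

# Average pooling: zero out padding, then divide by the true sequence lengths.
v_a_avg = (v_ai * premise_mask.unsqueeze(-1)).sum(dim=1) / premise_mask.sum(1, keepdim=True)
print(v_a_max.shape, v_a_avg.shape)  # torch.Size([2, 4]) torch.Size([2, 4])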
Ejemplo n.º 17
0
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits,
                       "label_probs": label_probs,
                       "h2p_attention": h2p_attention,
                       "p2h_attention": p2h_attention}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["premise_tokens"] = [x["premise_tokens"] for x in metadata]
            output_dict["hypothesis_tokens"] = [x["hypothesis_tokens"] for x in metadata]

        return output_dict
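
The cross-attention in the example above relies on AllenNLP's ``masked_softmax`` and ``weighted_sum``. The sketch below re-implements the premise-to-hypothesis direction with plain PyTorch; the dot-product similarity and toy shapes are assumptions for illustration, and the attend feed-forward projection is skipped for brevity:

import torch

def toy_masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Softmax over the last dimension, ignoring padded (mask == 0) positions.
    return torch.softmax(scores.masked_fill(mask.unsqueeze(1) == 0, -1e32), dim=-1)

# Assumed toy shapes: batch=1, premise_length=3, hypothesis_length=4, dim=6.
embedded_premise = torch.randn(1, 3, 6)
embedded_hypothesis = torch.randn(1, 4, 6)
hypothesis_mask = torch.tensor([[1., 1., 1., 0.]])

# Shape: (batch, premise_length, hypothesis_length)
similarity_matrix = torch.bmm(embedded_premise, embedded_hypothesis.transpose(1, 2))
# Each premise token attends over the unpadded hypothesis tokens ...
p2h_attention = toy_masked_softmax(similarity_matrix, hypothesis_mask)
# ... and the weighted sum is just a batched matrix product with the attended values.
attended_hypothesis = torch.bmm(p2h_attention, embedded_hypothesis)  # (1, 3, 6)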
Ejemplo n.º 18
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_starts: torch.IntTensor = None,
                span_ends: torch.IntTensor = None,
                yesno_labels : torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, num_of_passage_tokens = passage['bert'].size()

        # Executing the BERT model on the word piece ids (input_ids)
        input_ids = passage['bert']
        token_type_ids = torch.zeros_like(input_ids)
        mask = (input_ids != 0).long()
        embedded_chunk, pooled_output = \
            self._text_field_embedder.token_embedder_bert.bert_model(input_ids=util.combine_initial_dims(input_ids),
                                                         token_type_ids=util.combine_initial_dims(token_type_ids),
                                                         attention_mask=util.combine_initial_dims(mask),
                                                         output_all_encoded_layers=False)

        # Just measuring some lengths and offsets to handle the conversion between tokens and word-pieces
        passage_length = embedded_chunk.size(1)
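        # torch.min over a 1/0 mask returns the index of the first 0, i.e. the padded
        # position marking each sequence's word-piece length; rows with no padding at all
        # (min value 1) are then fixed up to the full sequence width.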
        mask_min_values, wordpiece_passage_lens = torch.min(mask, dim=1)
        wordpiece_passage_lens[mask_min_values == 1] = mask.shape[1]
        offset_min_values, token_passage_lens = torch.min(passage['bert-offsets'], dim=1)
        token_passage_lens[offset_min_values != 0] = passage['bert-offsets'].shape[1]
        bert_offsets = passage['bert-offsets'].cpu().numpy()

        # BERT for QA is a fully connected linear layer on top of BERT producing 2 vectors of
        # start and end spans.
        logits = self.qa_outputs(embedded_chunk)
        start_logits, end_logits = logits.split(1, dim=-1)
        span_start_logits = start_logits.squeeze(-1)
        span_end_logits = end_logits.squeeze(-1)

        # All input is preprocessed before forward is run; the size of the yesno vocabulary
        # indicates whether yesno support is needed at all.
        if self.vocab.get_vocab_size("yesno_labels") > 1:
            yesno_logits = self.qa_yesno(torch.max(embedded_chunk, 1)[0])

        span_starts.clamp_(0, passage_length)
        span_ends.clamp_(0, passage_length)

        # Convert the start and end span indices from token indices to word-piece indices.
        span_starts_list = [bert_offsets[i, span_starts[i]] if span_starts[i] != 0 else 0 for i in range(batch_size)]
        span_ends_list = [bert_offsets[i, span_ends[i]] if span_ends[i] != 0 else 0 for i in range(batch_size)]
        span_starts = torch.cuda.LongTensor(span_starts_list, device=span_end_logits.device) \
            if torch.cuda.is_available() else torch.LongTensor(span_starts_list)
        span_ends = torch.cuda.LongTensor(span_ends_list, device=span_end_logits.device) \
            if torch.cuda.is_available() else torch.LongTensor(span_ends_list)

        loss_fct = CrossEntropyLoss(ignore_index=passage_length)
        start_loss = loss_fct(start_logits.squeeze(-1), span_starts)
        end_loss = loss_fct(end_logits.squeeze(-1), span_ends)

        if self.vocab.get_vocab_size("yesno_labels") > 1 and yesno_labels is not None:
            yesno_loss = loss_fct(yesno_logits, yesno_labels)
            loss = (start_loss + end_loss + yesno_loss) / 3
        else:
            loss = (start_loss + end_loss) / 2

        output_dict: Dict[str, Any] = {}
        if loss == 0:
            # For evaluation purposes only!
            output_dict["loss"] = torch.cuda.FloatTensor([0], device=span_end_logits.device) \
                if torch.cuda.is_available() else torch.FloatTensor([0])
        else:
            output_dict["loss"] = loss

        # Compute F1 and prepare the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['best_span_logit'] = []
        output_dict['cannot_answer_logit'] = []
        output_dict['yesno'] = []
        output_dict['yesno_logit'] = []
        output_dict['qid'] = []
        if span_starts is not None:
            output_dict['EM'] = []
            output_dict['f1'] = []

        # Get the best span prediction for each instance in the batch.
        best_span = self._get_example_predications(span_start_logits, span_end_logits, self._max_span_length)
        best_span_cpu = best_span.detach().cpu().numpy()

        for instance_ind, instance_metadata in zip(range(batch_size), metadata):
            best_span_logit = span_start_logits.data.cpu().numpy()[instance_ind, best_span_cpu[instance_ind][0]] + \
                              span_end_logits.data.cpu().numpy()[instance_ind, best_span_cpu[instance_ind][1]]
            cannot_answer_logit = span_start_logits.data.cpu().numpy()[instance_ind, 0] + \
                              span_end_logits.data.cpu().numpy()[instance_ind, 0]

            if self.vocab.get_vocab_size("yesno_labels") > 1:
                yesno_maxind = np.argmax(yesno_logits[instance_ind].data.cpu().numpy())
                yesno_logit = yesno_logits[instance_ind, yesno_maxind].data.cpu().numpy()
                yesno_pred = self.vocab.get_token_from_index(yesno_maxind, namespace="yesno_labels")
            else:
                yesno_pred = 'no_yesno'
                yesno_logit = -30.0

            passage_str = instance_metadata['original_passage']
            offsets = instance_metadata['token_offsets']

            predicted_span = best_span_cpu[instance_ind]
            # In this version, a yesno prediction other than 'no_yesno' is taken as the
            # final answer before the spans are considered.
            if yesno_pred != 'no_yesno':
                best_span_string = yesno_pred
            else:
                if cannot_answer_logit + 0.9 > best_span_logit:
                    best_span_string = 'cannot_answer'
                else:
                    wordpiece_offsets = self.bert_offsets_to_wordpiece_offsets(bert_offsets[instance_ind][0:len(offsets)])
                    start_offset = offsets[wordpiece_offsets[predicted_span[0] if predicted_span[0] < len(wordpiece_offsets) \
                        else len(wordpiece_offsets)-1]][0]
                    end_offset = offsets[wordpiece_offsets[predicted_span[1] if predicted_span[1] < len(wordpiece_offsets) \
                        else len(wordpiece_offsets)-1]][1]
                    best_span_string = passage_str[start_offset:end_offset]

            output_dict['best_span_str'].append(best_span_string)
            output_dict['cannot_answer_logit'].append(cannot_answer_logit)
            output_dict['best_span_logit'].append(best_span_logit)
            output_dict['yesno'].append(yesno_pred)
            output_dict['yesno_logit'].append(yesno_logit)
            output_dict['qid'].append(instance_metadata['question_id'])

            # In AllenNLP prediction mode we have no gold answers, so check before computing metrics.
            if span_starts is not None:
                yesno_label_ind = yesno_labels.data.cpu().numpy()[instance_ind]
                yesno_label = self.vocab.get_token_from_index(yesno_label_ind, namespace="yesno_labels")

                if yesno_label != 'no_yesno':
                    gold_answer_texts = [yesno_label]
                elif instance_metadata['cannot_answer']:
                    gold_answer_texts = ['cannot_answer']
                else:
                    gold_answer_texts = instance_metadata['answer_texts_list']

                f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score, best_span_string, gold_answer_texts)
                EM_score = squad_eval.metric_max_over_ground_truths(squad_eval.exact_match_score, best_span_string, gold_answer_texts)
                self._official_f1(100 * f1_score)
                self._official_EM(100 * EM_score)
                output_dict['EM'].append(100 * EM_score)
                output_dict['f1'].append(100 * f1_score)


        return output_dict
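
The span conversion above maps token-level start/end indices to word-piece indices via ``bert-offsets``. A small, hedged sketch of that lookup with made-up offsets (the exact offset convention depends on the dataset reader):

import numpy as np

# Hypothetical offsets for one instance: bert_offsets[i, t] is the word-piece index
# aligned with token t; zeros are padding / "no answer".
bert_offsets = np.array([[1, 2, 4, 5, 0, 0]])
span_starts = np.array([2])  # token-level start index for the single instance

wordpiece_starts = [int(bert_offsets[i, span_starts[i]]) if span_starts[i] != 0 else 0
                    for i in range(len(span_starts))]
print(wordpiece_starts)  # [4]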
Ejemplo n.º 19
0
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            premise_tags: torch.LongTensor,
            hypothesis: Dict[str, torch.LongTensor],
            hypothesis_tags: torch.LongTensor,
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        premise_tags : torch.LongTensor
            The POS tags of the premise.
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``.
        hypothesis_tags: torch.LongTensor
            The POS tags of the hypothesis.
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._attention(projected_premise,
                                            projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat(
            [embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat(
            [embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(
            hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(
            -1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        # running the parser
        encoded_p_parse, p_parse_mask = self._parser(premise, premise_tags)
        p_parse_encoder_final_state = get_final_encoder_states(
            encoded_p_parse, p_parse_mask)
        encoded_h_parse, h_parse_mask = self._parser(hypothesis,
                                                     hypothesis_tags)
        h_parse_encoder_final_state = get_final_encoder_states(
            encoded_h_parse, h_parse_mask)

        compared_premise = torch.cat(
            [compared_premise, p_parse_encoder_final_state], dim=-1)
        compared_hypothesis = torch.cat(
            [compared_hypothesis, h_parse_encoder_final_state], dim=-1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {'logits': label_logits, 'label_probs': label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict['loss'] = loss

        if metadata is not None:
            output_dict['premise_tokens'] = [
                x['premise_tokens'] for x in metadata
            ]
            output_dict['hypothesis_tokens'] = [
                x['hypothesis_tokens'] for x in metadata
            ]

        return output_dict
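
``get_final_encoder_states`` above pulls the last unmasked timestep out of the parser encodings; a simplified sketch of that gather for a unidirectional encoder, with toy shapes assumed:

import torch

# Assumed toy shapes: batch=2, sequence_length=4, encoding_dim=3.
encoded = torch.randn(2, 4, 3)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 1, 1]])

# The final unmasked position of every sequence, then an advanced-indexing gather.
last_index = mask.sum(dim=1) - 1                                    # (batch,)
final_states = encoded[torch.arange(encoded.size(0)), last_index]   # (batch, encoding_dim)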
Ejemplo n.º 20
0
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        choices: Dict[str, torch.LongTensor],
        evidence: Dict[str, torch.LongTensor],
        answer_index: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        qa_pairs : Dict[str, torch.LongTensor]
            From a ``ListField``.
        answer_index : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is what we are trying to predict.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, question and choices for each instance
            in the batch. The length of this list should be the batch size, and each dictionary
            should have the keys ``qid``, ``question``, ``choices``, ``question_tokens`` and
            ``choices_tokens``.

        Returns
        -------
        An output dictionary consisting of the following.

        qid : List[str]
            A list consisting of question ids.
        answer_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_options=5)`` representing unnormalised log
            probabilities of the choices.
        answer_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_options=5)`` representing probabilities of the
            choices.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        # batch, seq_len -> batch, seq_len, emb
        question_hidden = self._bert(question)
        batch_size, emb_size = question_hidden.size(0), question_hidden.size(2)
        question_hidden = question_hidden[..., 0, :]  # batch, emb
        # batch, 5, seq_len -> batch, 5, seq_len, emb
        choice_hidden = self._bert(choices, num_wrapping_dims=1)
        choice_hidden = choice_hidden[..., 0, :]  # batch, 5, emb
        # batch, 5, evi_num, seq_len -> batch, 5, evi_num, seq_len, emb
        evidence_hidden = self._bert(evidence, num_wrapping_dims=2)
        evi_num = evidence_hidden.size(2)
        evidence_hidden = evidence_hidden[..., 0, :]  # batch, 5, evi_num, emb

        if self.dropout:
            question_hidden = self.dropout(question_hidden)
            choice_hidden = self.dropout(choice_hidden)
            evidence_hidden = self.dropout(evidence_hidden)

        # batch, 5, evi_num, emb -> batch, 5 x evi_num, emb
        evidence_hidden = evidence_hidden.view(batch_size, -1, emb_size)

        scores = self.attention(question_hidden, evidence_hidden)
        scores = scores.view(batch_size, 5, evi_num)  # batch, 5, evi_num

        evidence_hidden = evidence_hidden.view(batch_size, 5, evi_num,
                                               emb_size)
        # evidence_hidden: batch, 5, evi_num, emb
        # scores: batch, 5, evi_num
        evidence_summary = weighted_sum(evidence_hidden,
                                        scores)  # batch, 5, emb

        question_hidden = question_hidden.unsqueeze(1).expand(
            batch_size, 5, emb_size)
        cls_hidden = torch.cat(
            [question_hidden, choice_hidden, evidence_summary], dim=-1)

        # the final MLP -- apply dropout to input, and MLP applies to hidden
        answer_logits = self._classifier(cls_hidden).squeeze(-1)
        answer_probs = torch.nn.functional.softmax(answer_logits, dim=-1)
        qids = [m['qid'] for m in metadata]
        output_dict = {
            "answer_logits": answer_logits,
            "answer_probs": answer_probs,
            "qid": qids
        }

        if answer_index is not None:
            answer_index = answer_index.squeeze(-1)
            loss = self._loss(answer_logits, answer_index)
            self._accuracy(answer_logits, answer_index)
            output_dict["loss"] = loss
        return output_dict
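
The ``[..., 0, :]`` slicing above keeps only the first ([CLS]) position of each encoded sequence, and the evidence summary is an attention-weighted sum per choice. A hedged sketch with toy shapes; the softmax normalisation of the scores below is an assumption made for illustration, since the example passes its attention scores straight to ``weighted_sum``:

import torch

batch, choices, evi_num, seq_len, emb = 2, 5, 3, 7, 16
question_hidden = torch.randn(batch, seq_len, emb)
evidence_hidden = torch.randn(batch, choices, evi_num, seq_len, emb)

# Pool each sequence by taking its first ([CLS]) position.
question_cls = question_hidden[..., 0, :]   # (batch, emb)
evidence_cls = evidence_hidden[..., 0, :]   # (batch, choices, evi_num, emb)

# Score every evidence passage against the question, normalise per choice,
# then form one evidence summary vector per choice.
flat_evidence = evidence_cls.view(batch, -1, emb)                           # (batch, choices*evi_num, emb)
scores = torch.bmm(flat_evidence, question_cls.unsqueeze(-1)).squeeze(-1)   # (batch, choices*evi_num)
weights = torch.softmax(scores.view(batch, choices, evi_num), dim=-1)       # (batch, choices, evi_num)
evidence_summary = (evidence_cls * weights.unsqueeze(-1)).sum(dim=2)        # (batch, choices, emb)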
Ejemplo n.º 21
0
    def forward(
            self,  # type: ignore
            bert_input: Dict[str, torch.LongTensor],
            sim_bert_input: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """

        if self.use_scenario_encoding:
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_bert_input_token_labels_wp = sim_bert_input[
                'scenario_gold_encoding']
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_bert_output_wp = self._sim_text_field_embedder(sim_bert_input)
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_input_mask_wp = (sim_bert_input['bert'] != 0).float()
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_passage_mask_wp = sim_input_mask_wp - sim_bert_input[
                'bert-type-ids'].float()  # works only with one [SEP]
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_passage_representation_wp = sim_bert_output_wp * sim_passage_mask_wp.unsqueeze(
                2)
            # Keep only the word-piece columns that belong to the passage, i.e. columns whose
            # passage mask is non-zero for at least one instance in the batch.
            # Shape: (batch_size, passage_len_wp, embedding_dim)
            sim_passage_representation_wp = sim_passage_representation_wp[
                :, sim_passage_mask_wp.sum(dim=0) > 0, :]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_token_labels_wp = sim_bert_input_token_labels_wp[
                :, sim_passage_mask_wp.sum(dim=0) > 0]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_mask_wp = sim_passage_mask_wp[
                :, sim_passage_mask_wp.sum(dim=0) > 0]

            # Shape: (batch_size, passage_len_wp, 4)
            sim_token_logits_wp = self._sim_token_label_predictor(
                sim_passage_representation_wp)

            if span_start is not None:  # during training and validation
                class_weights = torch.tensor(self.sim_class_weights,
                                             device=sim_token_logits_wp.device,
                                             dtype=torch.float)
                sim_loss = cross_entropy(sim_token_logits_wp.view(-1, 4),
                                         sim_passage_token_labels_wp.view(-1),
                                         ignore_index=0,
                                         weight=class_weights)
                self._sim_loss_metric(sim_loss.item())
                self._sim_yes_f1(sim_token_logits_wp,
                                 sim_passage_token_labels_wp,
                                 sim_passage_mask_wp)
                self._sim_no_f1(sim_token_logits_wp,
                                sim_passage_token_labels_wp,
                                sim_passage_mask_wp)
                if self.sim_pretraining:
                    return {'loss': sim_loss}

            if not self.sim_pretraining:
                # Shape: (batch_size, passage_len_wp)
                bert_input['scenario_encoding'] = (sim_token_logits_wp.argmax(
                    dim=2)) * sim_passage_mask_wp.long()
                # Word-piece length of the main BERT input.
                bert_input_wp_len = bert_input['history_encoding'].size(1)
                if bert_input['scenario_encoding'].size(1) > bert_input_wp_len:
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = bert_input[
                        'scenario_encoding'][:, :bert_input_wp_len]
                else:
                    batch_size = bert_input['scenario_encoding'].size(0)
                    difference = bert_input_wp_len - bert_input[
                        'scenario_encoding'].size(1)
                    zeros = torch.zeros(
                        batch_size,
                        difference,
                        dtype=bert_input['scenario_encoding'].dtype,
                        device=bert_input['scenario_encoding'].device)
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = torch.cat(
                        [bert_input['scenario_encoding'], zeros], dim=1)

        # Shape: (batch_size, bert_input_len + 1, embedding_dim)
        bert_output = self._text_field_embedder(bert_input)
        # Shape: (batch_size, embedding_dim)
        pooled_output = bert_output[:, 0]
        # Shape: (batch_size, bert_input_len, embedding_dim)
        bert_output = bert_output[:, 1:, :]
        # Shape: (batch_size, passage_len, embedding_dim), (batch_size, passage_len)
        passage_representation, passage_mask = self.get_passage_representation(
            bert_output, bert_input)

        # Shape: (batch_size, 4)
        action_logits = self._action_predictor(pooled_output)
        # Shape: (batch_size, passage_len, 2)
        span_logits = self._span_predictor(passage_representation)
        # Shape: (batch_size, passage_len, 1), (batch_size, passage_len, 1)
        span_start_logits, span_end_logits = span_logits.split(1, dim=2)
        # Shape: (batch_size, passage_len)
        span_start_logits = span_start_logits.squeeze(2)
        # Shape: (batch_size, passage_len)
        span_end_logits = span_end_logits.squeeze(2)

        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "pooled_output": pooled_output,
            "passage_representation": passage_representation,
            "action_logits": action_logits,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if self.use_scenario_encoding:
            output_dict["sim_token_logits"] = sim_token_logits_wp

        # Compute the loss for training (and for validation)
        if span_start is not None:
            # Shape: (batch_size,)
            span_loss = nll_loss(util.masked_log_softmax(
                span_start_logits, passage_mask),
                                 span_start.squeeze(1),
                                 reduction='none')
            # Shape: (batch_size,)
            span_loss += nll_loss(util.masked_log_softmax(
                span_end_logits, passage_mask),
                                  span_end.squeeze(1),
                                  reduction='none')
            # Shape: (batch_size,)
            more_mask = (label == self.vocab.get_token_index(
                'More', namespace="labels")).float()
            # Shape: (batch_size,)
            span_loss = (span_loss * more_mask).sum() / (more_mask.sum() +
                                                         1e-6)
            if more_mask.sum() > 1e-7:
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(1), more_mask)
                self._span_end_accuracy(span_end_logits, span_end.squeeze(1),
                                        more_mask)
                # Shape: (batch_size, 2)
                span_acc_mask = more_mask.unsqueeze(1).expand(-1, 2).long()
                self._span_accuracy(best_span,
                                    torch.cat([span_start, span_end], dim=1),
                                    span_acc_mask)

            action_loss = cross_entropy(action_logits, label)
            self._action_accuracy(action_logits, label)

            self._span_loss_metric(span_loss.item())
            self._action_loss_metric(action_loss.item())
            output_dict['loss'] = self.loss_weights[
                'span_loss'] * span_loss + self.loss_weights[
                    'action_loss'] * action_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if not self.training:  # true during validation and test
            output_dict['best_span_str'] = []
            batch_size = len(metadata)
            for i in range(batch_size):
                passage_text = metadata[i]['passage_text']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_str = passage_text[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_str)
                if 'gold_span' in metadata[i]:
                    if metadata[i]['action'] == 'More':
                        gold_span = metadata[i]['gold_span']
                        self._squad_metrics(best_span_str, [gold_span])
        return output_dict
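
The ``scenario_encoding`` alignment above either truncates or zero-pads the predicted encoding to the word-piece length of the main BERT input; a compact sketch of that pad-or-truncate step (the helper name is hypothetical):

import torch

def pad_or_truncate(encoding: torch.Tensor, target_len: int) -> torch.Tensor:
    # Force a (batch, length) id tensor to exactly target_len columns.
    if encoding.size(1) > target_len:
        return encoding[:, :target_len]
    zeros = torch.zeros(encoding.size(0), target_len - encoding.size(1),
                        dtype=encoding.dtype, device=encoding.device)
    return torch.cat([encoding, zeros], dim=1)

print(pad_or_truncate(torch.ones(2, 5, dtype=torch.long), 7).shape)  # torch.Size([2, 7])
print(pad_or_truncate(torch.ones(2, 9, dtype=torch.long), 7).shape)  # torch.Size([2, 7])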
Ejemplo n.º 22
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            yesno: torch.IntTensor = None,
            question_tf: torch.FloatTensor = None,
            passage_tf: torch.FloatTensor = None,
            q_em_cased: torch.IntTensor = None,
            p_em_cased: torch.IntTensor = None,
            q_em_uncased: torch.IntTensor = None,
            p_em_uncased: torch.IntTensor = None,
            q_in_lemma: torch.IntTensor = None,
            p_in_lemma: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        x1_c_emb = self._dropout(self._char_field_embedder(passage))
        x2_c_emb = self._dropout(self._char_field_embedder(question))

        # embedded_question = torch.cat([self._dropout(self._text_field_embedder(question)),
        #                                self._features_embedder(q_em_cased),
        #                                self._features_embedder(q_em_uncased),
        #                                self._features_embedder(q_in_lemma),
        #                                question_tf.unsqueeze(2)], dim=2)
        # embedded_passage = torch.cat([self._dropout(self._text_field_embedder(passage)),
        #                               self._features_embedder(p_em_cased),
        #                               self._features_embedder(p_em_uncased),
        #                               self._features_embedder(p_in_lemma),
        #                               passage_tf.unsqueeze(2)], dim=2)
        token_emb_q = self._dropout(self._text_field_embedder(question))
        token_emb_c = self._dropout(self._text_field_embedder(passage))
        token_emb_question, q_ner_and_pos = torch.split(token_emb_q, [300, 40],
                                                        dim=2)
        token_emb_passage, p_ner_and_pos = torch.split(token_emb_c, [300, 40],
                                                       dim=2)
        question_word_features = torch.cat([
            q_ner_and_pos,
            self._features_embedder(q_em_cased),
            self._features_embedder(q_em_uncased),
            self._features_embedder(q_in_lemma),
            question_tf.unsqueeze(2)
        ], dim=2)
        passage_word_features = torch.cat([
            p_ner_and_pos,
            self._features_embedder(p_em_cased),
            self._features_embedder(p_em_uncased),
            self._features_embedder(p_in_lemma),
            passage_tf.unsqueeze(2)
        ], dim=2)

        # embedded_question = self._highway_layer(embedded_q)
        # embedded_passage = self._highway_layer(embedded_q)

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        char_features_c = self._char_rnn(
            x1_c_emb.reshape((x1_c_emb.size(0) * x1_c_emb.size(1),
                              x1_c_emb.size(2), x1_c_emb.size(3))),
            passage_lstm_mask.unsqueeze(2).repeat(1, 1, x1_c_emb.size(2)).reshape(
                (x1_c_emb.size(0) * x1_c_emb.size(1), x1_c_emb.size(2)))
        ).reshape((x1_c_emb.size(0), x1_c_emb.size(1), x1_c_emb.size(2), -1))[:, :, -1, :]
        char_features_q = self._char_rnn(
            x2_c_emb.reshape((x2_c_emb.size(0) * x2_c_emb.size(1),
                              x2_c_emb.size(2), x2_c_emb.size(3))),
            question_lstm_mask.unsqueeze(2).repeat(1, 1, x2_c_emb.size(2)).reshape(
                (x2_c_emb.size(0) * x2_c_emb.size(1), x2_c_emb.size(2)))
        ).reshape((x2_c_emb.size(0), x2_c_emb.size(1), x2_c_emb.size(2), -1))[:, :, -1, :]

        # token_emb_q, char_emb_q, question_word_features = torch.split(embedded_question, [300, 300, 56], dim=2)
        # token_emb_c, char_emb_c, passage_word_features = torch.split(embedded_passage, [300, 300, 56], dim=2)

        # char_features_q = self._char_rnn(char_emb_q, question_lstm_mask)
        # char_features_c = self._char_rnn(char_emb_c, passage_lstm_mask)

        emb_question = torch.cat(
            [token_emb_question, char_features_q, question_word_features],
            dim=2)
        emb_passage = torch.cat(
            [token_emb_passage, char_features_c, passage_word_features], dim=2)

        encoded_question = self._dropout(
            self._phrase_layer(emb_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(emb_passage, passage_lstm_mask))

        batch_size = encoded_question.size(0)
        passage_length = encoded_passage.size(1)

        encoding_dim = encoded_question.size(-1)

        # c_check = self._stacked_brnn(encoded_passage, passage_lstm_mask)
        # q = self._stacked_brnn(encoded_question, question_lstm_mask)
        c_check = encoded_passage
        q = encoded_question
        for i in range(self.hops):
            q_tilde = self.interactive_aligners[i].forward(
                c_check, q, question_mask)
            c_bar = self.interactive_SFUs[i].forward(
                c_check,
                torch.cat([q_tilde, c_check * q_tilde, c_check - q_tilde], 2))
            c_tilde = self.self_aligners[i].forward(c_bar, passage_mask)
            c_hat = self.self_SFUs[i].forward(
                c_bar, torch.cat([c_tilde, c_bar * c_tilde, c_bar - c_tilde],
                                 2))
            c_check = self.aggregate_rnns[i].forward(c_hat, passage_mask)

        # Predict
        start_scores, end_scores, yesno_scores = self.mem_ans_ptr.forward(
            c_check, q, passage_mask, question_mask)

        best_span, yesno_predict, loc = self.get_best_span(
            start_scores, end_scores, yesno_scores)

        output_dict = {
            "span_start_logits": start_scores,
            "span_end_logits": end_scores,
            "best_span": best_span
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(start_scores, span_start.squeeze(-1))
            self._span_start_accuracy(start_scores, span_start.squeeze(-1))
            loss += nll_loss(end_scores, span_end.squeeze(-1))
            self._span_end_accuracy(end_scores, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))

            gold_span_end_loc = []
            span_end = span_end.view(batch_size).squeeze().data.cpu().numpy()
            for i in range(batch_size):
                gold_span_end_loc.append(
                    max(span_end[i] + i * passage_length, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(
                0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(_yesno, yesno.view(-1), ignore_index=-1)

            pred_span_end_loc = []
            for i in range(batch_size):
                pred_span_end_loc.append(max(loc[i], 0))
            predicted_end = span_start.new(pred_span_end_loc)
            _yesno = yesno_scores.view(-1, 3).index_select(0,
                                                           predicted_end).view(
                                                               -1, 3)
            self._span_yesno_accuracy(_yesno, yesno.squeeze(-1))

            output_dict['loss'] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['yesno'] = yesno_predict
        return output_dict
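
The ``span_end[i] + i * passage_length`` trick above turns per-instance positions into offsets in a flattened (batch * passage_length) view, so a single ``index_select`` can pick one yes/no score vector per instance; a toy sketch:

import torch

batch_size, passage_length = 2, 4
yesno_scores = torch.randn(batch_size, passage_length, 3)
span_end = torch.tensor([1, 3])  # gold end position for each instance

# Row i of the flattened view starts at offset i * passage_length.
flat_index = span_end + torch.arange(batch_size) * passage_length
picked = yesno_scores.view(-1, 3).index_select(0, flat_index)  # (batch_size, 3)
print(picked.shape)  # torch.Size([2, 3])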
Ejemplo n.º 23
0
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis0: Dict[str, torch.LongTensor],
            hypothesis1: Dict[str, torch.LongTensor],
            hypothesis2: Dict[str, torch.LongTensor] = None,
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        hyps = [
            h for h in [hypothesis0, hypothesis1, hypothesis2] if h is not None
        ]
        if isinstance(self._text_field_embedder, ElmoTokenEmbedder):
            self._text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states(
            )

        embedded_premise = self._text_field_embedder(premise)

        embedded_hypotheses = []
        for hypothesis in hyps:
            if isinstance(self._text_field_embedder, ElmoTokenEmbedder):
                self._text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states(
                )
            embedded_hypotheses.append(self._text_field_embedder(hypothesis))

        premise_mask = get_text_field_mask(premise).float()
        hypothesis_masks = [
            get_text_field_mask(hypothesis).float() for hypothesis in hyps
        ]
        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypotheses = [
                self.rnn_input_dropout(hyp) for hyp in embedded_hypotheses
            ]

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)

        label_logits = []
        for i, (embedded_hypothesis, hypothesis_mask) in enumerate(
                zip(embedded_hypotheses, hypothesis_masks)):
            encoded_hypothesis = self._encoder(embedded_hypothesis,
                                               hypothesis_mask)

            # Shape: (batch_size, premise_length, hypothesis_length)
            similarity_matrix = self._matrix_attention(encoded_premise,
                                                       encoded_hypothesis)

            # Shape: (batch_size, premise_length, hypothesis_length)
            p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
            # Shape: (batch_size, premise_length, embedding_dim)
            attended_hypothesis = weighted_sum(encoded_hypothesis,
                                               p2h_attention)

            # Shape: (batch_size, hypothesis_length, premise_length)
            h2p_attention = masked_softmax(
                similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
            # Shape: (batch_size, hypothesis_length, embedding_dim)
            attended_premise = weighted_sum(encoded_premise, h2p_attention)

            # the "enhancement" layer
            premise_enhanced = torch.cat([
                encoded_premise, attended_hypothesis,
                encoded_premise - attended_hypothesis,
                encoded_premise * attended_hypothesis
            ], dim=-1)
            hypothesis_enhanced = torch.cat([
                encoded_hypothesis, attended_premise,
                encoded_hypothesis - attended_premise,
                encoded_hypothesis * attended_premise
            ], dim=-1)

            # Pipeline: embedding -> LSTM (with dropout) -> enhanced attention -> projection dropout (ELMo only)
            # -> feed-forward projection -> inference LSTM (with dropout) -> dropout -> feed-forward (300) -> dropout -> output

            # add dropout here with ELMO

            # the projection layer down to the model dimension
            # no dropout in projection
            projected_enhanced_premise = self._projection_feedforward(
                premise_enhanced)
            projected_enhanced_hypothesis = self._projection_feedforward(
                hypothesis_enhanced)

            # Run the inference layer
            if self.rnn_input_dropout:
                projected_enhanced_premise = self.rnn_input_dropout(
                    projected_enhanced_premise)
                projected_enhanced_hypothesis = self.rnn_input_dropout(
                    projected_enhanced_hypothesis)
            v_ai = self._inference_encoder(projected_enhanced_premise,
                                           premise_mask)
            v_bi = self._inference_encoder(projected_enhanced_hypothesis,
                                           hypothesis_mask)

            # The pooling layer -- max and avg pooling.
            # (batch_size, model_dim)
            v_a_max, _ = replace_masked_values(v_ai,
                                               premise_mask.unsqueeze(-1),
                                               -1e7).max(dim=1)
            v_b_max, _ = replace_masked_values(v_bi,
                                               hypothesis_mask.unsqueeze(-1),
                                               -1e7).max(dim=1)

            v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1),
                                dim=1) / torch.sum(
                                    premise_mask, 1, keepdim=True)
            v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1),
                                dim=1) / torch.sum(
                                    hypothesis_mask, 1, keepdim=True)

            # Now concat
            # (batch_size, model_dim * 2 * 4)
            v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

            # the final MLP -- apply dropout to input, and MLP applies to output & hidden
            if self.dropout:
                v = self.dropout(v)

            output_hidden = self._output_feedforward(v)
            logit = self._output_logit(output_hidden)
            assert logit.size(-1) == 1
            label_logits.append(logit)

        label_logits = torch.cat(label_logits, -1)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
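
Each hypothesis above is scored independently and produces a single logit; the per-choice logits are then concatenated and softmaxed across choices. A tiny sketch of that final step (three hypotheses assumed):

import torch

per_hypothesis_logits = [torch.randn(2, 1) for _ in range(3)]  # one (batch, 1) logit per choice

label_logits = torch.cat(per_hypothesis_logits, dim=-1)             # (batch, num_choices)
label_probs = torch.nn.functional.softmax(label_logits, dim=-1)     # rows sum to 1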
Ejemplo n.º 24
0
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        choices_list: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``
        choices_list : Dict[str, torch.LongTensor]
            From a ``List[TextField]``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        # encoded_choices_aggregated = embedd_encode_and_aggregate_list_text_field(choices_list,
        #                                                                          self._text_field_embedder,
        #                                                                          self._embeddings_dropout,
        #                                                                          self._choice_encoder,
        #                                                                          self._choice_aggregate)  # # bs, choices, hs
        #
        # encoded_question_aggregated, _ = embedd_encode_and_aggregate_text_field(question, self._text_field_embedder,
        #                                                                         self._embeddings_dropout,
        #                                                                         self._question_encoder,
        #                                                                         self._question_aggregate,
        #                                                                         get_last_states=False)  # bs, hs
        #
        # q_to_choices_att = self._matrix_attention_question_to_choice(encoded_question_aggregated.unsqueeze(1),
        #                                                              encoded_choices_aggregated).squeeze()
        #
        # label_logits = q_to_choices_att
        # label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
        #
        # output_dict = {"label_logits": label_logits, "label_probs": label_probs}
        #
        # if label is not None:
        #     loss = self._loss(label_logits, label.long().view(-1))
        #     self._accuracy(label_logits, label.squeeze(-1))
        #     output_dict["loss"] = loss

        embedded_question = self._text_field_embedder(question)
        embedded_choices = self._text_field_embedder(choices_list)
        question_mask = get_text_field_mask(question).float()
        choices_mask_3d = get_text_field_mask(choices_list,
                                              num_wrapping_dims=1).float()

        # apply dropout for LSTM
        if self._embeddings_dropout:
            embedded_question = self._embeddings_dropout(embedded_question)
            embedded_choices = self._embeddings_dropout(embedded_choices)

        batch_size, choices_cnt, choices_tokens_cnt, emb_size = tuple(
            embedded_choices.shape)
        choices_mask_flattened = choices_mask_3d.view(
            [batch_size * choices_cnt, choices_tokens_cnt])

        # Shape: (batch_size * choices_cnt, choices_tokens_cnt, embedding_size)
        embedded_choices_flattened = embedded_choices.view(
            [batch_size * choices_cnt, choices_tokens_cnt, -1])

        # encode question and choices

        # Shape: (batch_size, question_tokens_cnt, encoder_out_size)
        encoded_question = self._question_encoder(embedded_question,
                                                  question_mask)
        question_tokens_cnt = encoded_question.shape[1]
        encoder_out_size = encoded_question.shape[2]

        # tile to choices tokens
        # Shape: (batch_size, choices_cnt, question_tokens_cnt, encoder_out_size)
        encoded_question = encoded_question.unsqueeze(1).expand(
            batch_size, choices_cnt, question_tokens_cnt,
            encoder_out_size).contiguous()

        # Shape: (batch_size * choices_cnt, question_tokens_cnt, encoder_out_size)
        encoded_question = encoded_question.view(
            [batch_size * choices_cnt, question_tokens_cnt,
             encoder_out_size]).contiguous()

        # tile to choices tokens
        # Shape: (batch_size, choices_cnt, question_length)
        question_mask = question_mask.unsqueeze(1).expand(
            batch_size, choices_cnt, question_tokens_cnt).contiguous()

        # Shape: (batch_size * choices_cnt, question_length)
        question_mask = question_mask.view(
            [batch_size * choices_cnt, question_tokens_cnt]).contiguous()

        # encode choices
        # Shape: (batch_size * choices_cnt, choices_tokens_cnt, encoder_out_size)
        encoded_choices = self._choice_encoder(embedded_choices_flattened,
                                               choices_mask_flattened)
        choices_mask = choices_mask_flattened

        # Shape: (batch_size * choices_cnt, question_length, choices_length)
        similarity_matrix = self._matrix_attention(encoded_question,
                                                   encoded_choices)

        # Shape: (batch_size * choices_cnt, question_length, choices_length)
        p2h_attention = last_dim_softmax(similarity_matrix, choices_mask)
        # Shape: (batch_size * choices_cnt, question_length, encoder_out_size)
        attended_choices = weighted_sum(encoded_choices, p2h_attention)

        # Shape: (batch_size * choices_cnt, choices_length, question_length)
        h2p_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), question_mask)
        # Shape: (batch_size * choices_cnt, choices_length, encoder_out_size)
        attended_question = weighted_sum(encoded_question, h2p_attention)

        # the "enhancement" layer
        question_enhanced = torch.cat([
            encoded_question, attended_choices, encoded_question -
            attended_choices, encoded_question * attended_choices
        ],
                                      dim=-1)
        choices_enhanced = torch.cat([
            encoded_choices, attended_question, encoded_choices -
            attended_question, encoded_choices * attended_question
        ],
                                     dim=-1)

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_question = self._projection_feedforward(
            question_enhanced)
        projected_enhanced_choices = self._projection_feedforward(
            choices_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_question = self.rnn_input_dropout(
                projected_enhanced_question)
            projected_enhanced_choices = self.rnn_input_dropout(
                projected_enhanced_choices)
        v_ai = self._inference_encoder(projected_enhanced_question,
                                       question_mask)
        v_bi = self._inference_encoder(projected_enhanced_choices,
                                       choices_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size * choices_cnt, model_dim)
        v_a_max, _ = replace_masked_values(v_ai, question_mask.unsqueeze(-1),
                                           -1e7).max(dim=1)
        v_b_max, _ = replace_masked_values(v_bi, choices_mask.unsqueeze(-1),
                                           -1e7).max(dim=1)

        v_a_avg = torch.sum(v_ai * question_mask.unsqueeze(-1),
                            dim=1) / torch.sum(question_mask, 1, keepdim=True)
        v_b_avg = torch.sum(v_bi * choices_mask.unsqueeze(-1),
                            dim=1) / torch.sum(choices_mask, 1, keepdim=True)

        # Now concatenate the pooled representations.
        # (batch_size * choices_cnt, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_logits = label_logits.view([batch_size, choices_cnt])
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
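The choice scores above are computed on a flattened (batch_size * choices_cnt) axis and reshaped back before the softmax; a small sketch of that pattern with hypothetical sizes:

import torch

batch_size, choices_cnt = 2, 4                                # hypothetical sizes
flat_logits = torch.randn(batch_size * choices_cnt, 1)        # one score per (question, choice) pair

label_logits = flat_logits.view(batch_size, choices_cnt)
label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

label = torch.tensor([1, 3])                                  # gold choice index per question
loss = torch.nn.functional.cross_entropy(label_logits, label)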
Example No. 25
    def forward(  # type: ignore
        self,
        premise: TextFieldTensors,
        hypothesis: TextFieldTensors,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        premise : TextFieldTensors
            From a `TextField`
        hypothesis : TextFieldTensors
            From a `TextField`
        label : torch.IntTensor, optional, (default = None)
            From a `LabelField`
        metadata : `List[Dict[str, Any]]`, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        # Returns

        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape `(batch_size, num_labels)` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape `(batch_size, num_labels)` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise,
                                                   projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat(
            [embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat(
            [embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(
            hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(
            -1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "h2p_attention": h2p_attention,
            "p2h_attention": p2h_attention,
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["premise_tokens"] = [
                x["premise_tokens"] for x in metadata
            ]
            output_dict["hypothesis_tokens"] = [
                x["hypothesis_tokens"] for x in metadata
            ]

        return output_dict
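A minimal plain-PyTorch sketch of the two-way soft attention used above; masked_softmax and weighted_sum are AllenNLP utilities, and the masking trick below is just one common way to emulate them on toy tensors:

import torch

premise = torch.randn(2, 6, 4)                       # (batch, premise_len, dim), toy sizes
hypothesis = torch.randn(2, 5, 4)                    # (batch, hypothesis_len, dim)
hypothesis_mask = torch.tensor([[1., 1., 1., 0., 0.],
                                [1., 1., 1., 1., 1.]])

# (batch, premise_len, hypothesis_len)
similarity = torch.bmm(premise, hypothesis.transpose(1, 2))

# Softmax over hypothesis positions, ignoring padding.
scores = similarity.masked_fill(hypothesis_mask.unsqueeze(1) == 0, -1e9)
p2h_attention = torch.softmax(scores, dim=-1)

# Each premise token gets a weighted sum of hypothesis tokens.
attended_hypothesis = torch.bmm(p2h_attention, hypothesis)   # (batch, premise_len, dim)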
Example No. 26
    def forward(
        self,  # type: ignore
        tokens: Dict[str, torch.LongTensor],
        token_type_ids: torch.LongTensor,
        links_tokens: Dict[str, torch.LongTensor],
        links_token_type_ids: torch.LongTensor,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        debug = False
        # batch_size, L
        input_ids = tokens['tokens']
        # batch_size, L
        input_mask = (input_ids != 0).long()

        # shape: batch_size*num_choices, max_len
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
        flat_attention_mask = input_mask.view(-1, input_mask.size(-1))

        # shape: batch_size*num_choices, hidden_dim
        _, pooled_ph = self.bert_model(input_ids=flat_input_ids,
                                       token_type_ids=flat_token_type_ids,
                                       attention_mask=flat_attention_mask)

        # batch_size,num_of_choices, N_P, N_P-1,L1
        links = links_tokens['tokens']
        # batch_size,num_of_choices, N_P, N_P-1,L1
        link_type_ids = links_token_type_ids
        batch_size, num_of_choices, max_premise, c, link_len = links.size()
        links = links.view(-1, link_len)
        link_type_ids = link_type_ids.view(-1, link_len)
        link_mask = (links != 0).long()
        # batch_size*N_P* N_P-1*choices,  CLS embedding
        _, pooled_links = self.link_model(input_ids=links,
                                          token_type_ids=link_type_ids,
                                          attention_mask=link_mask)

        projected_links = self._projection_layer(pooled_links)
        projected_links = projected_links.view(batch_size * num_of_choices,
                                               max_premise, max_premise - 1,
                                               -1)

        reduced_link_mask, _ = torch.max(link_mask.view(
            batch_size * num_of_choices, max_premise, max_premise - 1, -1),
                                         dim=-1,
                                         keepdim=False)

        # batch_size, N_p, CLS embedding
        link_maxpooled, _ = replace_masked_values(
            projected_links, reduced_link_mask.unsqueeze(-1), -1e7).max(dim=-2)
        if debug:
            print(link_maxpooled.size())
        # batch_size, N_P
        average_link_mask, _ = torch.max(reduced_link_mask,
                                         dim=-1,
                                         keepdim=False)
        cuda_device = self._get_prediction_device()
        average_link_mask = average_link_mask.float().cuda(cuda_device)
        # Avoid division by zero for premises without any links.
        average_link_mask[average_link_mask == 0] = 0.001
        if debug:
            print(average_link_mask.size())
        link_max_summary, _ = replace_masked_values(
            link_maxpooled, average_link_mask.unsqueeze(-1), -1e7).max(dim=-2)
        link_avg_summary = torch.sum(
            link_maxpooled * average_link_mask.unsqueeze(-1), -2) / torch.sum(
                average_link_mask, dim=1, keepdim=True)

        # compute enhanced key
        if debug:
            print(pooled_ph.size())
            print(link_max_summary.size())
            print(link_avg_summary.size())
        pooled = torch.cat((pooled_ph, link_max_summary, link_avg_summary),
                           dim=-1)

        pooled = self._dropout(pooled)

        # apply classification layer
        logits = self._classification_layer(pooled)

        # shape: batch_size,num_choices
        reshaped_logits = logits.view(-1, num_of_choices)
        if debug:
            print(f"reshaped_logits = {reshaped_logits}")

        probs = torch.nn.functional.softmax(reshaped_logits, dim=-1)

        output_dict = {"logits": reshaped_logits, "probs": probs}

        if label is not None:
            loss = self._loss(reshaped_logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(reshaped_logits, label)

        return output_dict
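The flattening above gives every (question, choice) pair its own row before the BERT call; a small sketch of that reshaping pattern with hypothetical sizes (the encoder itself is omitted):

import torch

batch_size, num_choices, max_len = 2, 4, 7            # hypothetical sizes
input_ids = torch.randint(1, 100, (batch_size, num_choices, max_len))
input_ids[:, :, 5:] = 0                               # pretend the trailing positions are padding

attention_mask = (input_ids != 0).long()              # 1 for real tokens, 0 for padding
flat_input_ids = input_ids.view(-1, max_len)          # (batch_size * num_choices, max_len)
flat_attention_mask = attention_mask.view(-1, max_len)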
Example No. 27
def pack_obs(state: Tensor, time: IntTensor) -> Tensor:
    """Reverses the `unpack_obs` transformation."""
    return torch.cat((state, time.float()), dim="R")
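`unpack_obs` itself is not shown here; a minimal sketch of what an inverse could look like, assuming the time index occupies the last column of the packed observation (the original uses a named dimension "R", which this sketch omits):

import torch

def unpack_obs_sketch(obs: torch.Tensor):
    # Split the trailing feature dimension into state features and the time scalar.
    state, time = obs[..., :-1], obs[..., -1:]
    return state, time.long()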
Example No. 28
    def forward(
            self,  # type: ignore
            premise_hypothesis: Dict[str, torch.Tensor],
            dataset: List[str] = None,
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise_hypothesis : Dict[str, torch.LongTensor]
            The premise and hypothesis combined in a single ``TextField`` for BERT-style encoding.
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``
        dataset : List[str], optional (default = None)
            Task indicator of the form ``"<task>-<lang>"``, e.g. ``"nli-en"``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        taskname = "nli-en"
        if dataset is not None:  # TODO: hardcoded; used when not multitask reader was used
            taskname = dataset[0]
        else:
            taskname = "nli-en"

        # xlm case, lang should be bs_seql tensor of lang ids
        mask = get_text_field_mask(premise_hypothesis).float()
        lang = taskname.split("-")[-1]
        lang_id = self._input_embedder._token_embedders[
            "bert"].transformer_model.config.lang2id[lang]
        bs = mask.size()[0]
        seq_len = mask.size()[1]
        lang_ids = mask.new_full((bs, seq_len), lang_id).long()
        if self._feed_lang_ids:
            embedded_combined = self._input_embedder(premise_hypothesis,
                                                     lang=lang_ids)
        else:
            embedded_combined = self._input_embedder(premise_hypothesis)

        if not self._avg:
            pooled_combined = embedded_combined[:, 0, :]
            #pooled_combined = self._pooler(embedded_combined, mask=mask)
        else:
            pooled_combined = embedded_combined.mean(dim=1)

        pooled_combined = self._dropout(pooled_combined)

        logits = self._nli_projection_layer(pooled_combined)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}
        output_dict["cls_emb"] = pooled_combined

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss

            self._nli_per_lang_acc[taskname](logits, label)

        if metadata is not None:
            output_dict["premise_tokens"] = [
                x["premise_tokens"] for x in metadata
            ]
            output_dict["hypothesis_tokens"] = [
                x["hypothesis_tokens"] for x in metadata
            ]

        return output_dict
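A small sketch of the two pooling options toggled by `_avg` above (first-token, CLS-style pooling vs. a mean over all positions), with hypothetical sizes:

import torch

embedded_combined = torch.randn(2, 7, 8)       # (batch, seq_len, hidden), toy sizes

cls_pooled = embedded_combined[:, 0, :]        # take the first ([CLS]-style) token
mean_pooled = embedded_combined.mean(dim=1)    # unmasked average over all positions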
Example No. 29
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))

        # # v5:
        # # remember to set token embeddings in the CONFIG JSON
        # encoded_question = self._dropout(embedded_question)

        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length) -- SIMILARITY MATRIX
        similarity_matrix = self._matrix_attention(encoded_passage,
                                                   encoded_question)

        # Shape: (batch_size, passage_length, question_length) -- CONTEXT2QUERY
        passage_question_attention = util.last_dim_softmax(
            similarity_matrix, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Our custom query2context
        q2c_attention = util.masked_softmax(similarity_matrix,
                                            question_mask,
                                            dim=1).transpose(-1, -2)
        q2c_vecs = util.weighted_sum(encoded_passage, q2c_attention)

        # Now we try the various variants
        # v1:
        # tiled_question_passage_vector = util.weighted_sum(q2c_vecs, passage_question_attention)

        # v2:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], encoded_passage.shape[1]))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).transpose(-1, -2)

        # v3:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], 1))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).squeeze().unsqueeze(1).expand(batch_size, passage_length, encoding_dim)

        # v4:
        # Re-application of query2context attention
        new_similarity_matrix = self._matrix_attention(encoded_passage,
                                                       q2c_vecs)
        masked_similarity = util.replace_masked_values(
            new_similarity_matrix, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # ------- Original variant
        # # We replace masked values with something really negative here, so they don't affect the
        # # max below.
        # masked_similarity = util.replace_masked_values(similarity_matrix,
        #                                                question_mask.unsqueeze(1),
        #                                                -1e7)
        # # Shape: (batch_size, passage_length)
        # question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # # Shape: (batch_size, passage_length)
        # question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # # Shape: (batch_size, encoding_dim)
        # question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # # Shape: (batch_size, passage_length, encoding_dim)
        # tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
        #                                                                             passage_length,
        #                                                                             encoding_dim)

        # ------- END

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        # original beta combination function
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        # # v6:
        # final_merged_passage = torch.cat([tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v7:
        # final_merged_passage = torch.cat([passage_question_vectors],
        #                                  dim=-1)
        #
        # # v8:
        # final_merged_passage = torch.cat([passage_question_vectors,
        #                                   tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v9:
        # final_merged_passage = torch.cat([encoded_passage,
        #                                   passage_question_vectors,
        #                                   encoded_passage * passage_question_vectors],
        #                                  dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
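`get_best_span` is not included in this snippet; a minimal sketch of one common way to pick the highest-scoring (start, end) pair subject to start <= end, not necessarily the exact implementation used here:

import torch

def best_span_sketch(span_start_logits: torch.Tensor,
                     span_end_logits: torch.Tensor) -> torch.Tensor:
    # Both inputs: (batch_size, passage_length). Returns (batch_size, 2) token indices.
    batch_size, length = span_start_logits.size()
    best_spans = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)
    for b in range(batch_size):
        best_score = float("-inf")
        argmax_start = 0
        for end in range(length):
            # Track the best start position seen so far among indices <= end.
            if span_start_logits[b, end] > span_start_logits[b, argmax_start]:
                argmax_start = end
            score = span_start_logits[b, argmax_start] + span_end_logits[b, end]
            if score > best_score:
                best_score = score
                best_spans[b, 0], best_spans[b, 1] = argmax_start, end
    return best_spans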
Example No. 30
    def forward(
        self,  # type: ignore
        node_tokens: Dict[str, torch.LongTensor],
        node_token_type_ids: torch.LongTensor,
        links_tokens: Dict[str, torch.LongTensor],
        links_token_type_ids: torch.LongTensor,
        coverage: torch.FloatTensor,
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:

        debug = False

        # batch_size, number_of_choices, number_of_premises, L
        input_ids = node_tokens['tokens']
        # batch_size, number_of_choices, number_of_premises, L
        input_mask = (input_ids != 0).long()

        # shape: batch_size*num_choices*number_of_premises, max_len
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = node_token_type_ids.view(
            -1, node_token_type_ids.size(-1))
        flat_attention_mask = input_mask.view(-1, input_mask.size(-1))

        # shape: batch_size*num_choices*number_of_premise, hidden_dim
        _, pooled_ph = self.bert_model(input_ids=flat_input_ids,
                                       token_type_ids=flat_token_type_ids,
                                       attention_mask=flat_attention_mask)
        reshaped_pooled_ph = pooled_ph.view(-1, input_ids.size(-2),
                                            self.bert_model.config.hidden_size)
        max_pooled_ph, _ = torch.max(reshaped_pooled_ph, dim=1)
        if debug:
            print(f"reshaped_pooled_ph.size() = {reshaped_pooled_ph.size()}")
            print(f"max_pooled_ph.size() = {max_pooled_ph.size()}")
        # batch_size,num_of_choices, N_P, N_P-1,L1
        links = links_tokens['tokens']
        # batch_size,num_of_choices, N_P, N_P-1,L1
        link_type_ids = links_token_type_ids
        batch_size, num_of_choices, max_premise, c, link_len = links.size()
        links = links.view(-1, link_len)
        link_type_ids = link_type_ids.view(-1, link_len)
        link_mask = (links != 0).long()
        # batch_size*N_P* N_P-1,  CLS embedding
        _, pooled_links = self.link_model(input_ids=links,
                                          token_type_ids=link_type_ids,
                                          attention_mask=link_mask)

        # project the links
        projected_links = self._projection_layer(pooled_links)
        #projected_links = pooled_links
        # compute summary vector
        projected_links = projected_links.view(batch_size * num_of_choices,
                                               max_premise, max_premise - 1,
                                               -1)

        reduced_link_mask, _ = torch.max(link_mask.view(
            batch_size * num_of_choices, max_premise, max_premise - 1, -1),
                                         dim=-1,
                                         keepdim=False)

        # batch_size * num_of_choices, N_P, projection_dim
        # Zero out masked link slots so padding does not distort the sum.
        link_pooled = torch.sum(
            replace_masked_values(projected_links,
                                  reduced_link_mask.unsqueeze(-1), 0.0), -2)

        # from batch_size, num_choices, max_premise, max_hypolen to ...
        coverage = coverage.view(batch_size * num_of_choices, max_premise, -1)
        if self._normalize_coverage:
            coverage_normalizer, _ = coverage.max(dim=1)
            coverage_normalizer[coverage_normalizer == 0] = 0.0001
            coverage = coverage / coverage_normalizer.unsqueeze(1)

        # batch_size*num_choice, hyp_len, projection_dim
        summary = torch.bmm(torch.transpose(coverage, 2, 1), link_pooled)

        if debug:
            print(f"coverage.size()={coverage.size()}")
            print(f"link_pooled.size()={link_pooled.size()}")
            print(f"summary.size()={summary.size()}")
        # batch_size, N_P
        cuda_device = self._get_prediction_device()
        coverage_mask = (coverage != 0).double().float().cuda(cuda_device)
        coverage_mask[coverage_mask == 0] = 0.0001
        coverage_mask, _ = torch.max(torch.sum(coverage_mask, -1),
                                     dim=-1,
                                     keepdim=False)

        link_avg_summary = torch.sum(summary, -2) / coverage_mask.unsqueeze(-1)

        # compute enhanced key
        pooled = torch.cat((max_pooled_ph, link_avg_summary), dim=-1)

        pooled = self._dropout(pooled)

        # apply classification layer
        logits = self._classification_layer(pooled)

        # shape: batch_size,num_choices
        reshaped_logits = logits.view(-1, num_of_choices)
        if debug:
            print(f"reshaped_logits = {reshaped_logits}")

        probs = torch.nn.functional.softmax(reshaped_logits, dim=-1)

        output_dict = {"logits": reshaped_logits, "probs": probs}

        if label is not None:
            loss = self._loss(reshaped_logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(reshaped_logits, label)

        return output_dict
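The coverage-weighted summary above boils down to a single batched matrix product; a toy sketch of the shapes involved (hypothetical sizes):

import torch

rows, max_premise, hyp_len, dim = 2, 3, 5, 4          # rows = batch_size * num_of_choices
coverage = torch.rand(rows, max_premise, hyp_len)     # how much each premise covers each hypothesis token
link_pooled = torch.rand(rows, max_premise, dim)      # one pooled link vector per premise

# (rows, hyp_len, max_premise) x (rows, max_premise, dim) -> (rows, hyp_len, dim)
summary = torch.bmm(coverage.transpose(2, 1), link_pooled)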
Example No. 31
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens belonging to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two dialog turns back in the passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three dialog turns back in the passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following.
        Each entry is a nested list: the outer list iterates over dialogs and the inner list over the questions in each dialog.

        qid : List[List[str]]
            A list of lists, consisting of question ids.
        followup : List[List[int]]
            A list of lists, consisting of continuation marker prediction indices.
            (y: yes, m: maybe follow up, n: don't follow up)
        yesno : List[List[int]]
            A list of lists, consisting of affirmation marker prediction indices.
            (y: yes, x: not a yes/no question, n: no)
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 against each leave-one-out set of N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict
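
The leave-one-out F1 used above (holding one reference out, taking the max F1 over the rest, then averaging) can be isolated into a small helper. A minimal sketch, where `f1_score` stands in for a token-overlap scorer in the style of the official SQuAD evaluator; the helper name is illustrative, not part of the model above.

def leave_one_out_f1(prediction, answer_texts, f1_score):
    """Average the max-over-references F1, holding one reference out each time."""
    if len(answer_texts) == 1:
        return f1_score(prediction, answer_texts[0])
    scores = []
    for held_out in range(len(answer_texts)):
        refs = [ref for j, ref in enumerate(answer_texts) if j != held_out]
        scores.append(max(f1_score(prediction, ref) for ref in refs))
    return sum(scores) / len(scores)
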
Ejemplo n.º 32
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            sp_mask: torch.IntTensor = None,
            coref_mask: torch.FloatTensor = None) -> Dict[str, torch.Tensor]:

        embedded_question = self._text_field_embedder(question)
        embedded_passage = self._text_field_embedder(passage)
        # Split the passage into per-sentence spans so each sentence is encoded
        # independently, then stitch the encoded spans back into one sequence.
        decoupled_passage, spans_mask = convert_sequence_to_spans(
            embedded_passage, sentence_spans)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()
        encoded_decoupled_passage = \
            self._phrase_layer_sp(
                decoupled_passage, spans_mask.view(batch_size * num_spans, -1))
        context_output_sp = convert_span_to_sequence(
            embedded_passage, encoded_decoupled_passage, spans_mask)

        ques_mask = util.get_text_field_mask(question).float()
        context_mask = util.get_text_field_mask(passage).float()

        ques_output_sp = self._phrase_layer_sp(embedded_question, ques_mask)

        modeled_passage_sp = self.qc_att_sp(context_output_sp, ques_output_sp,
                                            ques_mask)
        modeled_passage_sp = self.linear_2(modeled_passage_sp)
        modeled_passage_sp = self._modeling_layer_sp(modeled_passage_sp,
                                                     context_mask)
        # Shape(spans_rep): (batch_size * num_spans, max_batch_span_width, embedding_dim)
        # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
        spans_rep_sp, spans_mask = convert_sequence_to_spans(
            modeled_passage_sp, sentence_spans)
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        gate_logit = self._span_gate(spans_rep_sp, spans_mask)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()
        sent_mask = (sent_labels >= 0).long()
        sent_labels = sent_labels * sent_mask
        # Strong supervision: maximize the probability mass the gate assigns to
        # the gold supporting sentences.
        strong_sup_loss = torch.mean(-torch.log(
            torch.sum(F.softmax(gate_logit, dim=-1) *
                      sent_labels.float().view(batch_size, num_spans),
                      dim=-1) + 1e-10))

        # strong_sup_loss = F.nll_loss(F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
        #                              sent_labels.long().view(batch_size * num_spans), ignore_index=-1)

        gate = torch.argmax(gate_logit.view(batch_size, num_spans), -1)
        # gate = (gate >= 0.5).long().view(batch_size, num_spans)
        output_dict = {"gate": gate}

        loss = strong_sup_loss

        output_dict["loss"] = loss

        if metadata is not None:
            question_tokens = []
            passage_tokens = []
            sent_labels_list = []
            ids = []

            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])

            self._sent_metrics(gate, sent_labels)
            output_dict['predict'] = [
                self.get_prediction(gate, sent_labels).data
            ]
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['sent_labels'] = sent_labels_list
            output_dict['_id'] = ids

        return output_dict
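
The strong supervision term above rewards the gate for concentrating probability mass on the gold supporting sentences. A minimal sketch of the same quantity on toy tensors, assuming the gate logits are already shaped `(batch_size, num_spans)`; the shape handling in the original depends on `self._span_gate` and may differ.

import torch
import torch.nn.functional as F

def strong_sup_loss(gate_logits: torch.Tensor, sent_labels: torch.Tensor) -> torch.Tensor:
    """Negative log of the total probability assigned to gold supporting sentences.

    gate_logits: (batch_size, num_spans) unnormalized sentence scores.
    sent_labels: (batch_size, num_spans) binary gold supporting-sentence labels.
    """
    probs = F.softmax(gate_logits, dim=-1)
    gold_mass = torch.sum(probs * sent_labels.float(), dim=-1)
    return torch.mean(-torch.log(gold_mass + 1e-10))

# Toy check: the loss shrinks as the gate concentrates on the gold sentence.
logits = torch.tensor([[3.0, 0.0, 0.0]])
labels = torch.tensor([[1, 0, 0]])
print(strong_sup_loss(logits, labels))  # small, since most mass is on the gold span
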
Ejemplo n.º 33
0
    def forward_ensemble(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information=False) -> Dict[str, torch.Tensor]:
        """
        Sample 10 times and add them together
        """
        self.set_posterior_mean(True)
        most_likely_output = self.forward(question,passage,span_start,span_end,metadata,get_sample_level_information)
        self.set_posterior_mean(False)
       
        subresults = [most_likely_output]
        for _ in range(10):
            subresults.append(self.forward(question, passage, span_start, span_end,
                                           metadata, get_sample_level_information))

        batch_size = len(subresults[0]["best_span"])

        best_span = bidut.merge_span_probs(subresults)
        
        output = {
                "best_span": best_span,
                "best_span_str": [],
                "models_output": subresults
        }
        if get_sample_level_information:
            output["em_samples"] = []
            output["f1_samples"] = []
                
        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string, answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)
                        
        if get_sample_level_information:
            # Add information about the individual samples for future analysis
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            for i in range(batch_size):
                
                span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults)
                span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults)
                span_start_loss = nll_loss(span_start_probs[[i],:], span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(span_end_probs[[i],:], span_end.squeeze(-1)[[i]])
                
                output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        return output
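
`bidut.merge_span_probs` is not shown here; one plausible reading, consistent with how the averaged distributions are reused for the per-sample losses above, is to average the start/end probabilities over all forward passes and decode a single span from the averages. A hedged sketch, with the constrained decoder (`get_best_span`) passed in rather than assumed:

def merge_span_probs_sketch(subresults, get_best_span):
    """Average span distributions over forward passes and decode one span per instance.

    Assumes each subresult dict carries 'span_start_probs' and 'span_end_probs' of
    shape (batch_size, passage_length); get_best_span is whatever constrained span
    decoder the model already uses on log-space scores.
    """
    n = len(subresults)
    span_start_probs = sum(r["span_start_probs"] for r in subresults) / n
    span_end_probs = sum(r["span_end_probs"] for r in subresults) / n
    # Decoding on log-probabilities keeps the additive decoder equivalent to
    # maximizing the product of the averaged start and end probabilities.
    return get_best_span(span_start_probs.log(), span_end_probs.log())
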
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            context: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # The `context` is the concatenation of `question` and `passage`, so we just use `context`.
        batch_size, num_of_passage_tokens = context['tokens'].size()

        # BERT for QA adds a fully connected linear layer on top of BERT that produces
        # two logit vectors, one for span starts and one for span ends.
        embedded_passage = self._text_field_embedder(context)
        passage_length = embedded_passage.size(1)
        logits = self.qa_outputs(embedded_passage)
        start_logits, end_logits = logits.split(1, dim=-1)
        span_start_logits = start_logits.squeeze(-1)
        span_end_logits = end_logits.squeeze(-1)

        # Adding some masks with numerically stable values
        passage_mask = util.get_text_field_mask(passage).float()
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, 1, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            batch_size, passage_length)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_start_probs = util.masked_softmax(span_start_logits,
                                               repeated_passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)
        span_end_probs = util.masked_softmax(span_end_logits,
                                             repeated_passage_mask)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict: Dict[str, Any] = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end],
                                                     -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on span qa and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])

                passage_words = metadata[i]["paragraph_words"]
                answer_offset = metadata[i]["answer_offset"]
                tok_to_word_index = metadata[i]["tok_to_word_index"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_position = tok_to_word_index[predicted_span[0] -
                                                   answer_offset]
                end_position = tok_to_word_index[predicted_span[1] -
                                                 answer_offset]
                best_span_str = " ".join(
                    passage_words[start_position:end_position + 1])
                output_dict["best_span_str"].append(best_span_str)
                answer_text = metadata[i].get("answer_text", [])
                if answer_text:
                    answer_text = [answer_text]
                    self._span_qa_metrics(best_span_str, answer_text)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens

        return output_dict
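
The `qa_outputs` head referenced above is, on this reading, a single linear layer that maps each contextual BERT vector to two scores, which are then split into start and end logits. A minimal sketch of that head in isolation; the hidden size and tensor shapes are illustrative.

import torch
from torch import nn

hidden_dim = 768  # illustrative BERT hidden size
qa_outputs = nn.Linear(hidden_dim, 2)

# (batch_size, passage_length, hidden_dim) -> (batch_size, passage_length, 2)
embedded_passage = torch.randn(4, 50, hidden_dim)
logits = qa_outputs(embedded_passage)
start_logits, end_logits = logits.split(1, dim=-1)
span_start_logits = start_logits.squeeze(-1)  # (batch_size, passage_length)
span_end_logits = end_logits.squeeze(-1)      # (batch_size, passage_length)
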
Ejemplo n.º 35
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
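
The passage representation built above concatenates four pieces: the encoded passage, the question-aware passage vectors, and two elementwise products that gate the passage by both attention directions. A compact sketch of just that merge, assuming the same `allennlp.nn.util` helpers used above are importable; a plain dot-product similarity stands in for the trained matrix attention, and all shapes are toy sizes.

import torch
from allennlp.nn import util  # assumed available, as in the model above

batch_size, passage_length, question_length, dim = 2, 7, 5, 4
encoded_passage = torch.randn(batch_size, passage_length, dim)
encoded_question = torch.randn(batch_size, question_length, dim)
question_mask = torch.ones(batch_size, question_length)
passage_mask = torch.ones(batch_size, passage_length)

# Stand-in similarity matrix: (batch_size, passage_length, question_length).
similarity = torch.bmm(encoded_passage, encoded_question.transpose(1, 2))

# Passage-to-question: each passage token attends over the question tokens.
p2q_attention = util.masked_softmax(similarity, question_mask)
passage_question_vectors = util.weighted_sum(encoded_question, p2q_attention)

# Question-to-passage: which passage tokens matter most to any question token.
q2p_attention = util.masked_softmax(similarity.max(dim=-1)[0], passage_mask)
question_passage_vector = util.weighted_sum(encoded_passage, q2p_attention)
tiled = question_passage_vector.unsqueeze(1).expand(-1, passage_length, -1)

# (batch_size, passage_length, dim * 4)
merged = torch.cat([encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * tiled], dim=-1)
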
Ejemplo n.º 36
0
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            token_type_ids: torch.LongTensor,
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField`` (that has a bert-pretrained token indexer)
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        debug = False
        # shape: batch_size, num_choices, max_len
        input_ids = tokens["tokens"]

        input_mask = (input_ids != 0).long()
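        # Non-zero token ids form the attention mask; id 0 is assumed to be padding.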

        # shape: batch_size*num_choices, max_len
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
        flat_attention_mask = input_mask.view(-1, input_mask.size(-1))

        if debug:
            print(f"flat_input_ids = {flat_input_ids}")
            print(f"flat_token_type_ids = {flat_token_type_ids}")
            print(f"flat_attention_mask = {flat_attention_mask}")

        # shape: batch_size*num_choices, hidden_dim
        _, pooled = self.bert_model(input_ids=flat_input_ids,
                                    token_type_ids=flat_token_type_ids,
                                    attention_mask=flat_attention_mask)
        if debug:
            print(f"pooled = {pooled}")

        pooled = self._dropout(pooled)

        if debug:
            print(f"pooled = {pooled}")

        # apply classification layer
        # shape: batch_size*num_choices, 1
        logits = self._classification_layer(pooled)
        if debug:
            print(f"logits = {logits}")

        # shape: batch_size,num_choices
        reshaped_logits = logits.view(-1, input_ids.size(1))

        if debug:
            print(f"reshaped_logits = {reshaped_logits}")

        probs = torch.nn.functional.softmax(reshaped_logits, dim=-1)

        output_dict = {"logits": reshaped_logits, "probs": probs}

        if label is not None:
            loss = self._loss(reshaped_logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(reshaped_logits, label)

        return output_dict
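
The per-choice scoring above relies on a reshape trick: every (question, choice) pair is scored independently with a single output unit, and the scores are folded back to `(batch_size, num_choices)` so an ordinary cross-entropy over choices applies. A minimal sketch of that trick, with a stand-in pooled vector replacing the BERT encoder and `CrossEntropyLoss` standing in for the model's loss.

import torch
from torch import nn

batch_size, num_choices, hidden_dim = 2, 4, 16
classification_layer = nn.Linear(hidden_dim, 1)
loss_fn = nn.CrossEntropyLoss()

# Stand-in for the pooled [CLS] vectors of each flattened (question, choice) pair.
pooled = torch.randn(batch_size * num_choices, hidden_dim)

logits = classification_layer(pooled)              # (batch_size * num_choices, 1)
reshaped_logits = logits.view(batch_size, num_choices)
probs = torch.softmax(reshaped_logits, dim=-1)     # distribution over choices

label = torch.tensor([1, 3])                       # gold choice index per instance
loss = loss_fn(reshaped_logits, label)
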
    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask)
        )
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask)
        )

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2), passage_mask, memory_efficient=True
        )
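        # Composing the two attention directions (passage->question followed by
        # question->passage) yields a passage-to-passage attention, so each passage
        # token can also pool information from other passage tokens that align with
        # the same question words.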
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat(
                [
                    encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * passage_passage_vectors,
                ],
                dim=-1,
            )
        )

        modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]

        for _ in range(3):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask)
            )
            modeled_passage_list.append(modeled_passage)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)
            )
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)
            )
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                passage_str = metadata[i]["original_passage"]
                offsets = metadata[i]["token_offsets"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)
                answer_texts = metadata[i].get("answer_texts", [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict
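
The span predictors above read from three successive passes of the modeling layer over the same projected input: passes one and two feed the start predictor, passes one and three feed the end predictor. The indexing into `modeled_passage_list` is the only subtle part, so here is the same wiring on stand-in modules; the stand-in layer and dimensions are illustrative, not the trained components.

import torch
from torch import nn

class StandInModelingLayer(nn.Module):
    """Placeholder for the recurrent/convolutional modeling layer used above."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        return torch.relu(self.linear(x))

dim, batch_size, passage_length = 8, 2, 11
modeling_layer = StandInModelingLayer(dim)          # shared across all three passes
span_start_predictor = nn.Linear(2 * dim, 1)
span_end_predictor = nn.Linear(2 * dim, 1)

merged = torch.randn(batch_size, passage_length, dim)
modeled_passage_list = [merged]
for _ in range(3):
    modeled_passage_list.append(modeling_layer(modeled_passage_list[-1]))

# List is [projection, pass1, pass2, pass3]: start uses pass1+pass2, end uses pass1+pass3.
span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
span_start_logits = span_start_predictor(span_start_input).squeeze(-1)
span_end_logits = span_end_predictor(span_end_input).squeeze(-1)
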
Ejemplo n.º 38
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information=True) -> Dict[str, torch.Tensor]:
        """
        WE LOAD THE MODELS ONE INTO GPU  ONE AT A TIME !!!
        """
        
        subresults = []
        for submodel in self.submodels:
            submodel.to(device=submodel.cf_a.device)
            subres = submodel(question, passage, span_start, span_end, metadata,
                              get_sample_level_information)
            submodel.to(device=torch.device("cpu"))
            subresults.append(subres)

        batch_size = len(subresults[0]["best_span"])

        best_span = merge_span_probs(subresults)
        output = {
                "best_span": best_span,
                "best_span_str": [],
                "models_output": subresults
        }
        if get_sample_level_information:
            output["em_samples"] = []
            output["f1_samples"] = []
                
        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string, answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)
        if get_sample_level_information:
            # Add information about the individual samples for future analysis
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            for i in range(batch_size):
                
                span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults)
                span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults)
                span_start_loss = nll_loss(span_start_probs[[i],:], span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(span_end_probs[[i],:], span_end.squeeze(-1)[[i]])
                
                output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        return output