Example #1
    def _compute_span_representations(
            self, text_embeddings: torch.FloatTensor,
            text_mask: torch.FloatTensor, span_starts: torch.IntTensor,
            span_ends: torch.IntTensor) -> torch.FloatTensor:
        """
        Computes an embedded representation of every candidate span. This is a concatenation
        of the contextualized endpoints of the span, an embedded representation of the width of
        the span and a representation of the span's predicted head.

        Parameters
        ----------
        text_embeddings : ``torch.FloatTensor``, required.
            The embedded document of shape (batch_size, document_length, embedding_dim)
            over which we are computing a weighted sum.
        text_mask : ``torch.FloatTensor``, required.
            A mask of shape (batch_size, document_length) representing non-padding entries of
            ``text_embeddings``.
        span_starts : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans) representing the start of each span candidate.
        span_ends : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans) representing the end of each span candidate.
        Returns
        -------
        span_embeddings : ``torch.FloatTensor``
            An embedded representation of every candidate span with shape:
            (batch_size, num_spans, context_layer.get_output_dim() * 2 + embedding_size + feature_size)
        """
        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)

        # Shape: (batch_size, num_spans, encoding_dim)
        start_embeddings = util.batched_index_select(contextualized_embeddings,
                                                     span_starts.squeeze(-1))
        end_embeddings = util.batched_index_select(contextualized_embeddings,
                                                   span_ends.squeeze(-1))

        # Compute and embed the span_widths (strictly speaking the span_widths - 1)
        # Shape: (batch_size, num_spans, 1)
        span_widths = span_ends - span_starts
        # Shape: (batch_size, num_spans, encoding_dim)
        span_width_embeddings = self._span_width_embedding(
            span_widths.squeeze(-1))

        # Shape: (batch_size, document_length, 1)
        head_scores = self._head_scorer(contextualized_embeddings)

        # Shape: (batch_size, num_spans, embedding_dim)
        # Note that we used the original text embeddings, not the contextual ones here.
        attended_text_embeddings = self._create_attended_span_representations(
            head_scores, text_embeddings, span_ends, span_widths)
        # (batch_size, num_spans, context_layer.get_output_dim() * 2 + embedding_dim + feature_dim)
        span_embeddings = torch.cat([
            start_embeddings, end_embeddings, span_width_embeddings,
            attended_text_embeddings
        ], -1)
        return span_embeddings
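
The endpoint lookups above rely on util.batched_index_select to gather one embedding per span index. A minimal plain-PyTorch sketch of that gather (illustrative only; the AllenNLP helper additionally accepts pre-flattened indices, as used later in these examples):

import torch

def batched_index_select_sketch(target: torch.Tensor,
                                indices: torch.LongTensor) -> torch.Tensor:
    # target:  (batch_size, sequence_length, embedding_dim)
    # indices: (batch_size, num_spans) of positions along sequence_length
    batch_size, _, embedding_dim = target.size()
    expanded = indices.unsqueeze(-1).expand(batch_size, indices.size(1), embedding_dim)
    # output[b, s, :] == target[b, indices[b, s], :]
    return torch.gather(target, 1, expanded)

# e.g. start_embeddings = batched_index_select_sketch(contextualized_embeddings, span_starts)
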
Example #2
    def forward(
            self,
            text: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        text : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField`` or ``MultiLabelField``, a tensor of shape ``(batch_size, num_labels)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:
        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log probabilities of the label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text = self.text_field_embedder(text)
        mask = util.get_text_field_mask(text)
        encoded_text = self.text_encoder(embedded_text, mask)
        pooled = self.pool(encoded_text, mask)

        hidden = self.classifier_feedforward(
            pooled)  # batch size x hidden size
        logits = self.prediction_layer(hidden)  # batch size x num labels

        # Reference: https://pytorch.org/docs/master/nn.html#sigmoid
        probabilities = torch.sigmoid(logits)  # batch size x num labels

        output_dict = {"logits": logits, "class_probs": probabilities}
        if label is not None:
            predictions = (logits.data > 0.0).long()
            label_data = label.squeeze(-1).data.long()
            self.micro_f1(predictions, label_data)
            output_dict["loss"] = self.loss(logits.squeeze(),
                                            label.squeeze(-1).float())
            for i, k in enumerate(self.label_f1.keys()):
                label_f1 = self.label_f1[k]
                label_f1(predictions[:, i], label[:, i].long())
        return output_dict
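
Because sigmoid(0) = 0.5, thresholding the raw logits at 0 (as done for the F1 metrics above) gives exactly the same multi-label predictions as thresholding the sigmoid probabilities at 0.5. A small self-contained check, assuming a BCE-with-logits style loss:

import torch

logits = torch.tensor([[1.2, -0.3, 0.0], [-2.0, 0.7, 3.1]])
targets = torch.tensor([[1., 0., 0.], [0., 1., 1.]])

probs = torch.sigmoid(logits)
assert torch.equal((logits > 0.0).long(), (probs > 0.5).long())

# BCEWithLogitsLoss folds the sigmoid into the loss computation, which is the
# numerically stable counterpart of sigmoid followed by BCELoss.
loss = torch.nn.BCEWithLogitsLoss()(logits, targets)
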
Example #3
    def forward(
            self,
            Orgquestion: Dict[str, torch.LongTensor],
            Relquestion: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        embedded_premise = self._seq_dropout(self._text_field_embedder(Orgquestion))
        premise_mask = util.get_text_field_mask(Orgquestion)
        encoded_premise = self._encoder(embedded_premise, premise_mask)

        embedded_hypo = self._seq_dropout(self._text_field_embedder(Relquestion))
        hypo_mask = util.get_text_field_mask(Relquestion)
        encoded_hypo = self._encoder(embedded_hypo, hypo_mask)

        # Extract interaction features with a Neural Tensor Network
        similarity_matrix = self._matching_layer(encoded_premise, encoded_hypo)
        pool_out = self._pool_layer(similarity_matrix)  # k-max pooling to a fixed size
        label_out = self._output_feedforward(pool_out)
        label_logits = torch.sigmoid(label_out)
        output_dict = {"label_logits": label_logits}

        if label is not None:
            loss = self._loss(label_logits, label)
            for metric in self.metrics.values():
                metric(label_logits.squeeze(1), label.squeeze(1))
            output_dict["loss"] = loss

        return output_dict
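
The _matching_layer above is described as a Neural Tensor Network; one slice of such an interaction is a bilinear form between every premise and hypothesis position. A minimal einsum sketch of a single bilinear slice (illustrative shapes, not the model's configured layer):

import torch

batch, len_p, len_h, hidden = 2, 5, 7, 16
encoded_premise = torch.randn(batch, len_p, hidden)
encoded_hypo = torch.randn(batch, len_h, hidden)

# similarity[b, i, j] = encoded_premise[b, i] @ W @ encoded_hypo[b, j]
W = torch.nn.Parameter(torch.randn(hidden, hidden))
similarity_matrix = torch.einsum('bph,hk,bqk->bpq', encoded_premise, W, encoded_hypo)
# A full NTN stacks several such slices and adds a feed-forward term before the
# non-linearity; k-max pooling then reduces the (len_p, len_h) grid to a fixed size.
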
Example #4
    def forward(self,
                features: torch.Tensor,
                metadata: List[Dict[str, Any]] = None,
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        features: torch.Tensor,
            From a ``FloatField`` over the overlap features computed by the SimpleOverlapReader
        metadata: List[Dict[str, Any]]
            Metadata information
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        label_logits = self.linear_mlp(features)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
        output_dict = {"label_logits": label_logits, "label_probs": label_probs}
        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss
        return output_dict
Example #5
    def forward(
            self,
            Orgquestion: Dict[str, torch.LongTensor],
            Relquestion: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        embedded_premise = self._seq_dropout(
            self._text_field_embedder(Orgquestion))
        embedded_hypo = self._seq_dropout(
            self._text_field_embedder(Relquestion))

        # Compute dot-product similarity between the word embeddings
        similarity_matrix = torch.unsqueeze(self._matching_layer(
            embedded_premise, embedded_hypo),
                                            dim=1)
        conv_out = self._inference_encoder(
            similarity_matrix)  # process the interaction with stacked convolution layers
        pool_out = self._pool_layer(conv_out)  # dynamic pooling to a fixed size

        label_logits = self._output_feedforward(pool_out)
        label_logits = torch.sigmoid(label_logits)
        output_dict = {"label_logits": label_logits}

        if label is not None:
            loss = self._loss(label_logits, label)
            for metric in self.metrics.values():
                metric(label_logits.squeeze(1), label.squeeze(1))
            output_dict["loss"] = loss

        return output_dict
Example #6
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        choices_list: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``
        choices_list : Dict[str, torch.LongTensor]
            From a ``List[TextField]``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of each choice being the correct answer.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of each choice being the correct answer.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        encoded_choices_aggregated = embed_encode_and_aggregate_list_text_field(
            choices_list, self._text_field_embedder, self._embeddings_dropout,
            self._choice_encoder, self._choice_aggregate)  # bs, choices, hs

        encoded_question_aggregated = embed_encode_and_aggregate_text_field(
            question, self._text_field_embedder, self._embeddings_dropout,
            self._question_encoder, self._question_aggregate)  # bs, hs

        q_to_choices_att = self._matrix_attention_question_to_choice(
            encoded_question_aggregated.unsqueeze(1),
            encoded_choices_aggregated).squeeze()

        label_logits = q_to_choices_att
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
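
The question-to-choice attention above reduces to scoring one aggregated question vector against each aggregated choice vector. A minimal dot-product version of that scoring (a sketch; the configured matrix attention may be parameterised differently):

import torch

batch, num_choices, hidden = 2, 4, 32
encoded_question_aggregated = torch.randn(batch, hidden)               # bs, hs
encoded_choices_aggregated = torch.randn(batch, num_choices, hidden)   # bs, choices, hs

# Dot product of the question vector with every choice vector.
label_logits = torch.bmm(encoded_choices_aggregated,
                         encoded_question_aggregated.unsqueeze(-1)).squeeze(-1)  # bs, choices
label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
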
Example #7
    def forward(
        self,  # type: ignore
        hypothesis0: Dict[str, torch.LongTensor],
        hypothesis1: Dict[str, torch.LongTensor],
        hypothesis2: Dict[str, torch.LongTensor],
        hypothesis3: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        logits = []
        for tokens in [hypothesis0, hypothesis1, hypothesis2, hypothesis3]:
            if isinstance(self.text_field_embedder, ElmoTokenEmbedder):
                self.text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states()

            embedded_text_input = self.embedding_dropout(
                self.text_field_embedder(tokens))
            mask = get_text_field_mask(tokens)

            batch_size, sequence_length, _ = embedded_text_input.size()

            encoded_text = self.encoder(embedded_text_input, mask)

            logits.append(self.output_prediction(encoded_text.max(1)[0]))

        logits = torch.cat(logits, -1)
        class_probabilities = F.softmax(logits, dim=-1).view([batch_size, 4])
        output_dict = {
            "label_logits": logits,
            "label_probs": class_probabilities
        }

        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            self._accuracy(logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
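
Each hypothesis above is scored independently (encode, max-pool over time, project to a single logit) and the four logits are concatenated before the softmax. A small sketch of the per-hypothesis scoring step, with illustrative shapes:

import torch

batch, seq_len, hidden = 2, 9, 16
encoded_text = torch.randn(batch, seq_len, hidden)   # encoder output for one hypothesis
output_prediction = torch.nn.Linear(hidden, 1)

pooled = encoded_text.max(dim=1)[0]                  # (batch, hidden), max over time
choice_logit = output_prediction(pooled)             # (batch, 1)
# Repeating this for all four hypotheses and concatenating on the last dimension
# gives the (batch, 4) logits that are softmaxed above.
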
Example #8
    def forward(
        self,  # type: ignore
        qa_pairs: Dict[str, torch.LongTensor],
        answer_index: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        qa_pairs : Dict[str, torch.LongTensor]
            From a ``TextField`` (that has a bert-pretrained token indexer)
        answer_index : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            unnormalized log probabilities of the label.
        probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing
            probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        nsp_logits = self.bert(qa_pairs, num_wrapping_dims=1)
        # nsp_logits: batch, 5, 2
        nsp_probs = torch.nn.functional.softmax(nsp_logits, dim=-1)
        # nsp_probs: batch, 5, 2
        nsp_pos_probs = nsp_probs[..., 0]
        # nsp_pos_probs = self._linear(nsp_probs).squeeze(-1)

        label_probs = torch.nn.functional.softmax(nsp_pos_probs, dim=-1)
        # label_score: batch, 5

        output_dict = {
            "nsp_logits": nsp_logits,
            "nsp_probs": nsp_probs,
            "label_probs": label_probs
        }

        if answer_index is not None:
            answer_index = answer_index.squeeze(-1)
            loss = self._loss(nsp_pos_probs, answer_index)
            self._accuracy(nsp_pos_probs, answer_index)
            output_dict["loss"] = loss
        return output_dict
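
The next-sentence-prediction head yields two logits per question/answer pair; the model keeps the probability of the "is next" class for each of the five pairs and renormalises those across choices. A small sketch of that slicing and renormalisation, treating the positive NSP probabilities as per-choice scores (illustrative shapes):

import torch

batch, num_choices = 2, 5
nsp_logits = torch.randn(batch, num_choices, 2)                     # per-pair NSP logits

nsp_probs = torch.nn.functional.softmax(nsp_logits, dim=-1)         # (batch, 5, 2)
nsp_pos_probs = nsp_probs[..., 0]                                   # P("is next") per choice
label_probs = torch.nn.functional.softmax(nsp_pos_probs, dim=-1)    # distribution over choices
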
Example #9
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """ Parameters
        ----------
        tokens: Dict[str, torch.LongTensor]
        From a ``TextField``.
        label: ``torch.IntTensor``, optional
        From an ``LabelField``.  This is what we are trying to predict.
        If this is given, we will compute a loss that gets included in the output dictionary.
        Returns
        -------
        An output dictionary consisting of the followings.
        logits: torch.FloatTensor
        A tensor of shape ``(batch_size, 2)`` representing unnormalised log
        probabilities of the evidence selection confidence.
        probs: torch.FloatTensor
        A tensor of shape ``(batch_size, 2)`` representing probabilities of the
        evidence selection confidence.
        loss: torch.FloatTensor, optional
        A scalar loss to be optimised. 
        """

        # Shape: (batch, seq_len, hidden)
        cls_hidden = self._bert(tokens)
        # Shape: (batch, hidden) -- the [CLS] position
        cls_hidden = cls_hidden[:, 0, :]
        cls_hidden = self._pooler(cls_hidden)
        # Shape: (batch, hidden)

        if self.dropout:
            cls_hidden = self.dropout(cls_hidden)

        # the final MLP -- apply dropout to input, and MLP applies to hidden
        logits = self._classifier(cls_hidden)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        output_dict = {"logits": logits, "probs": probs}

        if label is not None:
            label = label.squeeze(-1)
            loss = self._loss(logits, label)
            self._accuracy(logits, label)
            output_dict["loss"] = loss
        return output_dict
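
The pooling above takes BERT's first ([CLS]) position and runs it through a pooler before classification. A minimal sketch of that head with illustrative dimensions (not the actual model configuration):

import torch

batch, seq_len, hidden, num_labels = 2, 12, 768, 2
sequence_output = torch.randn(batch, seq_len, hidden)    # per-token BERT output

cls_hidden = sequence_output[:, 0, :]                    # (batch, hidden), the [CLS] vector
pooler = torch.nn.Sequential(torch.nn.Linear(hidden, hidden), torch.nn.Tanh())
classifier = torch.nn.Linear(hidden, num_labels)

logits = classifier(pooler(cls_hidden))                  # (batch, num_labels)
probs = torch.nn.functional.softmax(logits, dim=-1)
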
Example #10
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()

        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

        encoded_question = self._dropout(self._phrase_layer(projected_embedded_question, question_mask))
        encoded_passage = self._dropout(self._phrase_layer(projected_embedded_passage, passage_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
                passage_question_similarity,
                question_mask,
                memory_efficient=True)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
                passage_question_similarity.transpose(1, 2),
                passage_mask,
                memory_efficient=True)
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
                torch.cat([encoded_passage, passage_question_vectors,
                           encoded_passage * passage_question_vectors,
                           encoded_passage * passage_passage_vectors],
                          dim=-1)
                )

        modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]

        for _ in range(3):
            modeled_passage = self._dropout(self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)

        # Shape: (batch_size, passage_length, modeling_dim * 2))
        span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
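
get_best_span performs a constrained argmax over the start and end logits, keeping only spans whose end is at or after their start. A small dense sketch of that inference (the library version is vectorised differently; this assumes passages short enough for an O(n^2) score matrix):

import torch

def best_span_sketch(span_start_logits: torch.Tensor,
                     span_end_logits: torch.Tensor) -> torch.Tensor:
    # Both inputs: (batch_size, passage_length). Returns (batch_size, 2) inclusive indices.
    batch_size, length = span_start_logits.size()
    # score[b, i, j] = start_logit[b, i] + end_logit[b, j]
    scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Disallow spans that end before they start.
    valid = torch.triu(torch.ones(length, length)).bool()
    scores = scores.masked_fill(~valid, float('-inf'))
    flat_best = scores.view(batch_size, -1).argmax(dim=-1)
    starts = torch.div(flat_best, length, rounding_mode='floor')
    ends = flat_best % length
    return torch.stack([starts, ends], dim=-1)
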
Example #11
File: sciie.py  Project: almoslmi/yarx
    def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        doc_span_offsets: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        doc_truth_spans: torch.IntTensor = None,
        doc_spans_in_truth: torch.IntTensor = None,
        doc_relation_labels: torch.Tensor = None,
        truth_spans: List[Set[Tuple[int, int]]] = None,
        doc_relations=None,
        doc_ner_labels: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:  # add matrix from datareader
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : ``List[Dict[str, Any]]``, required.
            A list of metadata dictionaries, one per document, containing at least the
            ``original_text`` of the document and its tokenized sentences under ``doc_tokens``.
        doc_ner_labels : ``torch.IntTensor``.
            A tensor of shape # TODO,
            ...
        doc_span_offsets : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
            ...
        doc_truth_spans : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_truth_spans, 1),
            ...
        doc_spans_in_truth : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
            ...
        doc_relation_labels : ``torch.Tensor``.
            A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans),
            ...

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        batch_size = len(spans)
        document_length = text_embeddings.size(1)
        max_sentence_length = max(
            len(sentence) for document in metadata
            for sentence in document['doc_tokens'])
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # TODO features dropout
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_relex_spans_to_keep = int(
            math.floor(self._relex_spans_per_word * max_sentence_length))

        # Shapes:
        # (batch_size, num_spans_to_keep, span_dim),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # Shape: (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = dict()

        output_dict["top_spans"] = top_spans
        output_dict["antecedent_indices"] = valid_antecedent_indices
        output_dict["predicted_antecedents"] = predicted_antecedents

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]

        # Shape: (,)
        loss = 0

        # Shape: (batch_size, max_sentences, max_spans)
        doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
        # Shape: (batch_size, max_sentences, num_spans, span_dim)
        doc_span_embeddings = util.batched_index_select(
            span_embeddings,
            doc_span_offsets.squeeze(-1).long().clamp(min=0))

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        pruned = self._relex_mention_pruner(
            doc_span_embeddings,
            doc_span_mask,
            num_items_to_keep=num_relex_spans_to_keep,
            pass_through=['num_items_to_keep'])
        (top_relex_span_embeddings, top_relex_span_mask,
         top_relex_span_indices, top_relex_span_mention_scores) = pruned

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

        # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)  # TODO do we need for a mask?
        doc_spans = util.batched_index_select(
            spans,
            doc_span_offsets.clamp(0).squeeze(-1))

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
        top_relex_spans = nd_batched_index_select(doc_spans,
                                                  top_relex_span_indices)

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
        (relex_span_pair_embeddings,
         relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
             top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
        relex_scores = self._compute_relex_scores(
            relex_span_pair_embeddings, top_relex_span_mention_scores)
        output_dict['relex_scores'] = relex_scores
        output_dict['top_relex_spans'] = top_relex_spans

        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)
            antecedent_labels_ = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs)
            negative_marginal_log_likelihood *= top_span_mask.squeeze(-1).float()
            negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            coref_loss = negative_marginal_log_likelihood
            output_dict['coref_loss'] = coref_loss
            loss += self._loss_coref_weight * coref_loss

        if doc_relations is not None:

            # The adjacency matrix for relation extraction is very sparse.
            # As it is not just sparse, but row/column sparse (only few
            # rows and columns are non-zero and in that case these rows/columns
            # are not sparse), we implemented our own matrix for the case.
            # Here we have indices of truth spans and mapping, using which
            # we map prediction matrix on truth matrix.
            # TODO Add teacher forcing support.

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep),
            relative_indices = top_relex_span_indices
            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1),
            compressed_indices = nd_batched_padded_index_select(
                doc_spans_in_truth, relative_indices)

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
            gold_pruned_rows = nd_batched_padded_index_select(
                doc_relation_labels,
                compressed_indices.squeeze(-1),
                padding_value=0)
            gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3,
                                                        2).contiguous()

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
            gold_pruned_matrices = nd_batched_padded_index_select(
                gold_pruned_rows,
                compressed_indices.squeeze(-1),
                padding_value=0)  # pad with epsilon
            gold_pruned_matrices = gold_pruned_matrices.permute(
                0, 1, 3, 2).contiguous()

            # TODO log_mask relex score before passing
            relex_loss = nd_cross_entropy_with_logits(relex_scores,
                                                      gold_pruned_matrices,
                                                      relex_span_pair_mask)
            output_dict['relex_loss'] = relex_loss

            self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2),
                                       truth_spans)
            self._compute_relex_metrics(output_dict, doc_relations)

            loss += self._loss_relex_weight * relex_loss

        if doc_ner_labels is not None:
            # Shape: (batch_size, max_sentences, num_spans, num_ner_classes)
            ner_scores = self._ner_scorer(doc_span_embeddings)
            output_dict['ner_scores'] = ner_scores

            ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels,
                                                    doc_span_mask)
            output_dict['ner_loss'] = ner_loss
            loss += self._loss_ner_weight * ner_loss

        if not isinstance(loss, int):  # If loss is not yet modified
            output_dict["loss"] = loss

        return output_dict
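
The coreference loss above marginalises over every antecedent that shares a gold cluster with each kept span. A toy-sized sketch of that marginal log-likelihood, using the same masking trick (the log of a 0/1 label matrix contributes -inf for disallowed antecedents):

import torch

# (batch, kept_spans, 1 + max_antecedents): index 0 is the "no antecedent" slot.
coreference_scores = torch.randn(1, 3, 3)
# 1 where the slot is consistent with the gold clustering, else 0.
gold_antecedent_labels = torch.tensor([[[1., 0., 0.],
                                        [0., 1., 0.],
                                        [0., 1., 1.]]])

coreference_log_probs = torch.nn.functional.log_softmax(coreference_scores, dim=-1)
correct_log_probs = coreference_log_probs + gold_antecedent_labels.log()  # -inf if disallowed
negative_marginal_log_likelihood = -torch.logsumexp(correct_log_probs, dim=-1).sum()
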
Example #12
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
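
The question-to-passage attention above masks padded question positions with a large negative value before taking the max, so padding can never be the best match. A minimal sketch of that masked max in plain PyTorch, mirroring what replace_masked_values does here:

import torch

batch, passage_len, question_len = 2, 6, 4
passage_question_similarity = torch.randn(batch, passage_len, question_len)
question_mask = torch.tensor([[1., 1., 1., 0.],
                              [1., 1., 0., 0.]])          # (batch, question_len)

# Push masked question positions to a very negative value so the max ignores them.
masked_similarity = passage_question_similarity.masked_fill(
    question_mask.unsqueeze(1) == 0, -1e7)
# For each passage token, the score of its best-matching (unmasked) question token.
question_passage_similarity = masked_similarity.max(dim=-1)[0]   # (batch, passage_len)
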
Example #13
    def forward(
            self,  # type: ignore
            metadata: Dict,
            tokens: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField`` (that has a bert-pretrained token indexer)
        span_start : torch.IntTensor, optional (default = None)
            A tensor of shape (batch_size, 1) which contains the start_position of the answer
            in the passage, or 0 if impossible. This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : torch.IntTensor, optional (default = None)
            A tensor of shape (batch_size, 1) which contains the end_position of the answer
            in the passage, or 0 if impossible. This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, seq_len)`` representing
            unnormalized log probabilities of the span positions.
        start_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, seq_len)`` representing
            probabilities of the span start position.
        end_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, seq_len)`` representing
            probabilities of the span end position.
        best_span:
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        input_ids = tokens[self._index]
        token_type_ids = tokens[f"{self._index}-type-ids"]
        input_mask = (input_ids != 0).long()

        # 1. Build model here
        # shape bert_embeddings: (batch_size, seq_len, 768)
        # shape pooled_output: (batch_size, 768)
        bert_embeddings, pooled_output = self.bert_model(
            input_ids, token_type_ids, attention_mask=input_mask)
        bert_embeddings = self.dropout(bert_embeddings)
        start_scores = self.start_linear(bert_embeddings)
        end_scores = self.end_linear(bert_embeddings)
        # shape: (batch_size, seq_len)
        # squeeze only the last dimension so a batch of size 1 keeps its batch axis
        start_logits = start_scores.squeeze(-1)
        end_logits = end_scores.squeeze(-1)

        # Mask the scores so that only context positions are considered.
        # In token_type_ids the context tokens are 1 and the question tokens 0;
        # adding a large negative value to the question positions makes their
        # probability vanish after the softmax.
        question_mask = (token_type_ids.float() - 1) * 1e7
        start_logits = start_logits + question_mask
        end_logits = end_logits + question_mask

        start_probs = softmax(start_logits, dim=-1)
        end_probs = softmax(end_logits, dim=-1)

        output_dict = {"start_probs": start_probs, "end_probs": end_probs}

        if span_start is not None:
            # Compute the loss on the unnormalised logits (assuming the loss
            # function is a cross-entropy-style loss that applies its own softmax).
            start_loss = self.loss_function(start_logits, span_start.squeeze(-1))
            end_loss = self.loss_function(end_logits, span_end.squeeze(-1))

            self._span_start_accuracy(start_logits, span_start.squeeze(-1))
            self._span_end_accuracy(end_logits, span_end.squeeze(-1))

            # 2. Compute the best span from the start and end logits, e.g. with
            #    allennlp.models.reading_comprehension.util.get_best_span()
            #    (left unimplemented in this example; see the sketch after it).
            loss = (start_loss + end_loss) / 2
            # 4. Record the loss and accuracies: span_start accuracy, span_end
            #    accuracy and (once best_span is computed) full span accuracy.
            output_dict["loss"] = loss
            output_dict["_span_start_accuracy"] = self._span_start_accuracy
            output_dict["_span_end_accuracy"] = self._span_end_accuracy

            # 5. Optionally compute the official SQuAD metrics (exact match, F1):
            #    instantiate allennlp.training.metrics.SquadEmAndF1() in __init__,
            #    then call it with the predicted span string (see decode() below)
            #    and the gold answers in metadata[i]['answer_texts'].

        return output_dict
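The numbered comments above leave the best-span computation and ``decode()`` unimplemented. Below is a hedged sketch of what they might look like, assuming the ``get_best_span`` utility named in the comments and the metadata keys used by the other examples in this file; the exact field names are illustrative.

import torch
from allennlp.models.reading_comprehension.util import get_best_span

# Inside forward(), after the masked start/end logits are computed, one could add:
#     output_dict["best_span"] = get_best_span(start_logits, end_logits)


def decode(self, output_dict):
    """Illustrative decode(): turn best_span indices into answer strings.

    Assumes forward() also copied 'original_passage' and 'token_offsets'
    from metadata into output_dict (not shown in the example above).
    """
    best_spans = output_dict["best_span"].detach().cpu().numpy()
    best_span_strings = []
    for span, passage_str, offsets in zip(best_spans,
                                          output_dict["original_passage"],
                                          output_dict["token_offsets"]):
        start_offset = offsets[int(span[0])][0]
        end_offset = offsets[int(span[1])][1]
        best_span_strings.append(passage_str[start_offset:end_offset])
    output_dict["best_span_str"] = best_span_strings
    return output_dict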
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise,
                                                   projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat(
            [embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat(
            [embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(
            hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(
            -1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "h2p_attention": h2p_attention,
            "p2h_attention": p2h_attention
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
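For readers unfamiliar with ``last_dim_softmax`` and ``weighted_sum``, here is a toy-shaped sketch of the cross-attention step above using plain PyTorch in place of the helpers used by the model (shapes and names are illustrative):

import torch

batch_size, premise_len, hypothesis_len, dim = 2, 4, 5, 8
similarity = torch.randn(batch_size, premise_len, hypothesis_len)
embedded_hypothesis = torch.randn(batch_size, hypothesis_len, dim)
hypothesis_mask = torch.ones(batch_size, hypothesis_len)

# last_dim_softmax: a softmax over the hypothesis axis that ignores padded positions.
masked_similarity = similarity.masked_fill(hypothesis_mask.unsqueeze(1) == 0, -1e7)
p2h_attention = torch.softmax(masked_similarity, dim=-1)             # (batch, premise_len, hypothesis_len)

# weighted_sum: each premise token receives a mixture of hypothesis embeddings.
attended_hypothesis = torch.bmm(p2h_attention, embedded_hypothesis)  # (batch, premise_len, dim)
assert attended_hypothesis.shape == (batch_size, premise_len, dim)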
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            sp_mask: torch.IntTensor = None,
            coref_mask: torch.FloatTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        embedded_question = self._text_field_embedder(question)
        embedded_passage = self._text_field_embedder(passage)
        ques_mask = util.get_text_field_mask(question).float()
        context_mask = util.get_text_field_mask(passage).float()

        ques_output = self._dropout(
            self._phrase_layer(embedded_question, ques_mask))
        context_output = self._dropout(
            self._phrase_layer(embedded_passage, context_mask))

        modeled_passage, qc_score = self.qc_att(context_output, ques_output,
                                                ques_mask)
        modeled_passage = self.linear_1(modeled_passage)
        modeled_passage = self._modeling_layer(modeled_passage, context_mask)

        batch_size = modeled_passage.size()[0]
        output_start = self._span_start_encoder(modeled_passage, context_mask)
        span_start_logits = self.linear_start(output_start).squeeze(
            2) - 1e30 * (1 - context_mask)
        output_end = torch.cat([modeled_passage, output_start], dim=2)
        output_end = self._span_end_encoder(output_end, context_mask)
        span_end_logits = self.linear_end(output_end).squeeze(
            2) - 1e30 * (1 - context_mask)

        output_type = torch.cat([modeled_passage, output_end, output_start],
                                dim=2)
        output_type = torch.max(output_type, 1)[0]
        predict_type = self.linear_type(output_type)
        type_predicts = torch.argmax(predict_type, 1)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span": best_span,
            "qc_score": qc_score
        }

        # Compute the loss for training.
        if span_start is not None:
            try:
                start_loss = nll_loss(
                    util.masked_log_softmax(span_start_logits, None),
                    span_start.squeeze(-1))
                end_loss = nll_loss(
                    util.masked_log_softmax(span_end_logits, None),
                    span_end.squeeze(-1))
                type_loss = nll_loss(
                    util.masked_log_softmax(predict_type, None), q_type)
                loss = start_loss + end_loss + type_loss
                self._loss_trackers['loss'](loss)
                self._loss_trackers['start_loss'](start_loss)
                self._loss_trackers['end_loss'](end_loss)
                self._loss_trackers['type_loss'](type_loss)
                output_dict["loss"] = loss

            except RuntimeError:
                print('\n meta_data:', metadata)
                print(span_start_logits.shape)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            token_spans_sp = []
            token_spans_sent = []
            sent_labels_list = []
            coref_clusters = []
            ids = []
            count_yes = 0
            count_no = 0
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                token_spans_sp.append(metadata[i]['token_spans_sp'])
                token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                coref_clusters.append(metadata[i]['coref_clusters'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                if type_predicts[i] == 1:
                    best_span_string = 'yes'
                    count_yes += 1
                elif type_predicts[i] == 2:
                    best_span_string = 'no'
                    count_no += 1
                else:
                    predicted_span = tuple(best_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    best_span_string = passage_str[start_offset:end_offset]

                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['token_spans_sp'] = token_spans_sp
            output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['coref_clusters'] = coref_clusters
            output_dict['_id'] = ids

        return output_dict
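Several of these examples call a ``get_best_span`` helper. The sketch below shows the constrained argmax such a helper typically performs, maximising ``start_logit + end_logit`` subject to ``start <= end``; it is an illustration, not the exact implementation used above.

import torch


def get_best_span_sketch(span_start_logits: torch.Tensor,
                         span_end_logits: torch.Tensor) -> torch.Tensor:
    """Return (batch_size, 2) spans maximising start + end scores with start <= end."""
    batch_size, passage_length = span_start_logits.size()
    # Pairwise score of every (start, end) combination.
    span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Disallow spans whose end precedes their start.
    valid = torch.ones(passage_length, passage_length).triu().bool()
    span_scores = span_scores.masked_fill(~valid, float("-inf"))
    best_flat = span_scores.view(batch_size, -1).argmax(dim=-1)
    starts = best_flat // passage_length
    ends = best_flat % passage_length
    return torch.stack([starts, ends], dim=-1)


# Toy usage: with logits favouring position 1 as start and 3 as end,
# the returned span is [[1, 3]].
print(get_best_span_sketch(torch.tensor([[0.0, 5.0, 0.0, 0.0]]),
                           torch.tensor([[0.0, 0.0, 0.0, 5.0]])))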
Example #16
    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                encoded_passage,
                passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1,
        )

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat(
            [
                final_merged_passage,
                modeled_passage,
                tiled_start_representation,
                modeled_passage * tiled_start_representation,
            ],
            dim=-1,
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end],
                                                     -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            token_offsets = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                token_offsets.append(metadata[i]["token_offsets"])
                passage_str = metadata[i]["original_passage"]
                offsets = metadata[i]["token_offsets"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)
                answer_texts = metadata[i].get("answer_texts", [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
            output_dict["token_offsets"] = token_offsets
        return output_dict
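The loss block above combines ``masked_log_softmax`` with ``nll_loss``. A self-contained sketch of that pattern in plain PyTorch, to make the masking explicit (toy shapes, illustrative values):

import torch
import torch.nn.functional as F

batch_size, passage_length = 2, 6
span_start_logits = torch.randn(batch_size, passage_length)
passage_mask = torch.tensor([[1., 1., 1., 1., 0., 0.],
                             [1., 1., 1., 1., 1., 1.]])
span_start = torch.tensor([2, 4])  # one gold start index per instance

# masked_log_softmax: adding log(mask) sends padded positions towards -inf,
# so they receive (near-)zero probability after the softmax.
masked_logits = span_start_logits + (passage_mask + 1e-45).log()
log_probs = F.log_softmax(masked_logits, dim=-1)

# nll_loss then simply picks out the gold index for each instance.
start_loss = F.nll_loss(log_probs, span_start)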
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
Example #18
    def forward(self,
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata=None) -> Dict[str, torch.Tensor]:

        # shape: B x Tq x E
        embedded_question = self._embedder(question)
        embedded_passage = self._embedder(passage)

        batch_size = embedded_question.size(0)
        total_passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        # shape: B x T x 2H
        encoded_question = self._dropout(
            self._question_encoder(embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._passage_encoder(embedded_passage, passage_mask))
        passage_mask = passage_mask.float()
        question_mask = question_mask.float()

        encoding_dim = encoded_question.size(-1)

        # shape: B x 2H
        if encoded_passage.is_cuda:
            cuda_device = encoded_passage.get_device()
            gru_hidden = Variable(
                torch.zeros(batch_size, encoding_dim).cuda(cuda_device))
        else:
            gru_hidden = Variable(torch.zeros(batch_size, encoding_dim))

        question_awared_passage = []
        for timestep in range(total_passage_length):
            # shape: B x Tq = attention(B x 2H, B x Tq x 2H)
            attn_weights = self._question_attention_for_passage(
                encoded_passage[:, timestep, :], encoded_question,
                question_mask)
            # shape: B x 2H = weighted_sum(B x Tq x 2H, B x Tq)
            attended_question = util.weighted_sum(encoded_question,
                                                  attn_weights)
            # shape: B x 4H
            passage_question_combined = torch.cat(
                [encoded_passage[:, timestep, :], attended_question], dim=-1)
            # shape: B x 4H
            gate = F.sigmoid(self._gate(passage_question_combined))
            gru_input = gate * passage_question_combined
            # shape: B x 2H
            gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
            question_awared_passage.append(gru_hidden)

        # shape: B x T x 2H
        # question aware passage representation v_P
        question_awared_passage = torch.stack(question_awared_passage, dim=1)

        self_attended_passage = []
        for timestep in range(total_passage_length):
            attn_weights = self._passage_self_attention(
                question_awared_passage[:, timestep, :],
                question_awared_passage, passage_mask)
            attended_passage = util.weighted_sum(question_awared_passage,
                                                 attn_weights)
            input_combined = torch.cat(
                [question_awared_passage[:, timestep, :], attended_passage],
                dim=-1)
            gate = F.sigmoid(self._self_gate(input_combined))
            gru_input = gate * input_combined
            gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
            self_attended_passage.append(gru_hidden)

        self_attended_passage = torch.stack(self_attended_passage, dim=1)

        # compute question vector r_Q
        # shape: B x T = attention(B x 2H, B x T x 2H)
        v_r_Q_tiled = self._v_r_Q.unsqueeze(0).expand(batch_size, encoding_dim)
        attn_weights = self._question_attention_for_question(
            v_r_Q_tiled, encoded_question, question_mask)
        # shape: B x 2H
        r_Q = util.weighted_sum(encoded_question, attn_weights)
        # shape: B x T = attention(B x 2H, B x T x 2H)
        span_start_logits = self._passage_attention_for_answer(
            r_Q, self_attended_passage, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_start_log_probs = util.masked_log_softmax(span_start_logits,
                                                       passage_mask)
        # shape: B x 2H
        c_t = util.weighted_sum(self_attended_passage, span_start_probs)
        # shape: B x 2H
        h_1 = self._dropout(self._answer_net(c_t, r_Q))

        span_end_logits = self._passage_attention_for_answer(
            h_1, self_attended_passage, passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_end_log_probs = util.masked_log_softmax(span_end_logits,
                                                     passage_mask)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        #num_passages = passages_length.size(1)
        #acc = Variable(torch.zeros(batch_size, num_passages + 1)).cuda(cuda_device).long()

        #acc[:, 1:num_passages+1] = torch.cumsum(passages_length, dim=1)

        #g_batch = []
        #for b in range(batch_size):
        #    g = []
        #    for i in range(num_passages):
        #        if acc[b, i+1].data[0] > acc[b, i].data[0]:
        #            attn_weights = self._passage_attention_for_ranking(r_Q[b:b+1], question_awared_passage[b:b+1, acc[b, i].data[0]: acc[b, i+1].data[0], :], passage_mask[b:b+1, acc[b, i].data[0]: acc[b, i+1].data[0]])
        #            r_P = util.weighted_sum(question_awared_passage[b:b+1, acc[b, i].data[0]:acc[b, i+1].data[0], :], attn_weights)
        #            question_passage_combined = torch.cat([r_Q[b:b+1], r_P], dim=-1)
        #            gi = self._dropout(self._match_layer_2(F.tanh(self._dropout(self._match_layer_1(question_passage_combined)))))
        #            g.append(gi)
        #        else:
        #            g.append(Variable(torch.zeros(1, 1)).cuda(cuda_device))
        #    g = torch.cat(g, dim=1)
        #    g_batch.append(g)

        #t2 = time.time()
        #g = torch.cat(g_batch, dim=0)
        output_dict = {}
        if span_start is not None:
            AP_loss = F.nll_loss(span_start_log_probs, span_start.squeeze(-1)) +\
                F.nll_loss(span_end_log_probs, span_end.squeeze(-1))
            #PR_loss = F.nll_loss(passage_log_probs, correct_passage.squeeze(-1))
            #loss = self._r * AP_loss + self._r * PR_loss
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict['loss'] = AP_loss

        _, max_start = torch.max(span_start_probs, dim=1)
        _, max_end = torch.max(span_end_probs, dim=1)
        #t3 = time.time()
        output_dict['span_start_idx'] = max_start
        output_dict['span_end_idx'] = max_end
        #t4 = time.time()
        #global ITE
        #ITE += 1
        #if (ITE % 100 == 0):
        #    print(" gold %i:%i|predicted %i:%i" %(span_start.squeeze(-1)[0], span_end.squeeze(-1)[0], max_start.data[0], max_end.data[0]))
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens

        #t5 = time.time()
        #print("Total: %.5f" % (t5-t0))
        #print("Batch processing 1: %.5f" % (t2-t1))
        #print("Batch processing 2: %.5f" % (t4-t3))
        return output_dict
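The answer decoding above follows an R-NET-style pointer network: an attention over the self-attended passage produces the start distribution, its weighted sum ``c_t`` feeds a GRU cell together with the question vector ``r_Q``, and the resulting state ``h_1`` produces the end distribution. A toy sketch of one such step, with a stand-in additive attention rather than the model's own attention modules:

import torch
import torch.nn as nn

batch_size, passage_len, hidden_dim = 2, 5, 8
passage = torch.randn(batch_size, passage_len, hidden_dim)   # self-attended passage
r_q = torch.randn(batch_size, hidden_dim)                    # question vector r_Q
attention = nn.Linear(2 * hidden_dim, 1)                     # stand-in attention scorer
gru_cell = nn.GRUCell(hidden_dim, hidden_dim)


def pointer_step(query):
    # Score every passage position against the current query vector.
    expanded_query = query.unsqueeze(1).expand_as(passage)
    scores = attention(torch.cat([passage, expanded_query], dim=-1)).squeeze(-1)
    probs = torch.softmax(scores, dim=-1)                            # (batch, passage_len)
    context = torch.bmm(probs.unsqueeze(1), passage).squeeze(1)      # weighted sum c_t
    return scores, context


span_start_logits, c_t = pointer_step(r_q)   # start distribution from r_Q
h_1 = gru_cell(c_t, r_q)                     # evolve the query with the GRU cell
span_end_logits, _ = pointer_step(h_1)       # end distribution from the new state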
Example #19
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            get_sample_level_information=True) -> Dict[str, torch.Tensor]:
        """
        We load the submodels onto the GPU one at a time, run each, and move it back to the CPU.
        """

        subresults = []
        for submodel in self.submodels:
            submodel.to(device=submodel.cf_a.device)
            subres = submodel(question, passage, span_start, span_end,
                              metadata, get_sample_level_information)
            submodel.to(device=torch.device("cpu"))
            subresults.append(subres)

        batch_size = len(subresults[0]["best_span"])

        best_span = merge_span_probs(subresults)
        output = {
            "best_span": best_span,
            "best_span_str": [],
            "models_output": subresults
        }
        if (get_sample_level_information):
            output["em_samples"] = []
            output["f1_samples"] = []

        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if (get_sample_level_information):
                        em_sample, f1_sample = bidut.get_em_f1_metrics(
                            best_span_string, answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)
        if (get_sample_level_information):
            # Add information about the individual samples for future analysis
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            for i in range(batch_size):

                span_start_probs = sum(
                    subresult['span_start_probs']
                    for subresult in subresults) / len(subresults)
                span_end_probs = sum(
                    subresult['span_end_probs']
                    for subresult in subresults) / len(subresults)
                span_start_loss = nll_loss(span_start_probs[[i], :],
                                           span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(span_end_probs[[i], :],
                                         span_end.squeeze(-1)[[i]])

                output["span_start_sample_loss"].append(
                    float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(
                    float(span_end_loss.detach().cpu().numpy()))
        return output
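``merge_span_probs`` is a project-specific helper that is not shown here, so the following is only a plausible sketch of what merging the submodel outputs could look like: average the per-model start/end distributions and run the usual constrained span search (``get_best_span`` here stands for any such search, e.g. the sketch given earlier):

import torch


def merge_span_probs_sketch(subresults, get_best_span):
    """Average per-model span distributions and pick the best joint span.

    subresults: list of output dicts, each holding 'span_start_probs' and
    'span_end_probs' of shape (batch_size, passage_length).
    """
    span_start_probs = sum(r["span_start_probs"] for r in subresults) / len(subresults)
    span_end_probs = sum(r["span_end_probs"] for r in subresults) / len(subresults)
    # Work in log space so the span search adds scores instead of multiplying them.
    return get_best_span(span_start_probs.log(), span_end_probs.log())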
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_premise = self._embeddings_dropout(embedded_premise)

        embedded_hypothesis = self._text_field_embedder(hypothesis)
        embedded_hypothesis = self._embeddings_dropout(embedded_hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_mask)

        embedded_premise = seq2vec_seq_aggregate(
            embedded_premise, premise_mask, self._premise_aggregate,
            self._premise_encoder.is_bidirectional(), 1)

        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_mask)

        embedded_hypothesis = seq2vec_seq_aggregate(
            embedded_hypothesis, hypothesis_mask, self._hypothesis_aggregate,
            self._hypothesis_encoder.is_bidirectional(), 1)

        # Standard sentence-pair features: [u; v; |u - v|; u * v]
        aggregate_input = torch.cat([
            embedded_premise, embedded_hypothesis,
            torch.abs(embedded_hypothesis - embedded_premise),
            embedded_hypothesis * embedded_premise
        ],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            labels = label.long().view(-1)
            loss = self._loss(label_logits, labels)
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
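``seq2vec_seq_aggregate`` is likewise project-specific; from its call sites it presumably pools a ``(batch, time, dim)`` sequence into a ``(batch, dim)`` vector under a mask, with the aggregation mode given as a string. A hedged sketch of a masked max/mean pooling with that shape contract (the extra arguments in the calls above, the bidirectionality flag and the trailing ``1``, are ignored in this simplified version):

import torch


def seq2vec_seq_aggregate_sketch(seq_tensor, mask, aggregate="max"):
    """Pool (batch, time, dim) -> (batch, dim) while ignoring padded positions."""
    mask = mask.unsqueeze(-1).float()                      # (batch, time, 1)
    if aggregate == "max":
        # Push padded positions far below any real value before taking the max.
        return (seq_tensor + (1 - mask) * -1e7).max(dim=1)[0]
    if aggregate == "mean":
        summed = (seq_tensor * mask).sum(dim=1)
        return summed / mask.sum(dim=1).clamp(min=1e-7)
    raise ValueError(f"unknown aggregation: {aggregate}")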
Example #21
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)


        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        # answer_len for masking
        answer_len = [len(elem['answer_texts']) for elem in metadata] if metadata is not None else []
        if answer_len:
            mask = torch.zeros((batch_size, max(answer_len), 2)).long()
            for index, length in enumerate(answer_len):
                mask[index, :length] = 1
        else:
            mask = None

        best_span = self.get_best_span(span_start_logits, span_end_logits, answer_len)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            span_start = span_start.squeeze(-1) #batch X max_answer_L
            span_end = span_end.squeeze(-1) #batch X max_answer_L

            # TODO answer padding needs to be ignored
            step = 0
            span_start_1D = span_start[ : , step:step + 1] #batch X 1 
            span_end_1D = span_end[ : , step:step + 1] #batch X 1 
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start_1D.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start_1D.squeeze(-1)) #TODO
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end_1D.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end_1D.squeeze(-1)) #TODO
            # self._span_accuracy(best_span, torch.stack([span_start_1D, span_end_1D], -1))#TODO

            for step in range(1, span_start.size(1)):
                span_start_1D = span_start[ : , step:step + 1] #batch X 1 
                span_end_1D = span_end[ : , step:step + 1] #batch X 1 
                loss += nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start_1D.squeeze(-1), ignore_index=-1)
                self._span_start_accuracy(span_start_logits, span_start_1D.squeeze(-1)) #TODO
                loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end_1D.squeeze(-1), ignore_index=-1)
                self._span_end_accuracy(span_end_logits, span_end_1D.squeeze(-1)) #TODO
                # self._span_accuracy(best_span, torch.stack([span_start_1D, span_end_1D], -1))#TODO
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1), mask)
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                best_span_strings = []
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_spans = tuple(best_span[i].data.cpu().numpy())
                for predicted_span in predicted_spans:
                    if predicted_span[0] == -1:
                        break
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    best_span_string = passage_str[start_offset:end_offset]
                    best_span_strings.append(best_span_string)
                output_dict['best_span_str'].append(best_span_strings)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_strings, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
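The multi-answer loss above relies on ``nll_loss(..., ignore_index=-1)`` so that padded answer slots add nothing to the loss. A tiny sketch of that behaviour in isolation (toy values):

import torch
import torch.nn.functional as F

log_probs = F.log_softmax(torch.randn(2, 6), dim=-1)   # (batch_size, passage_length)
# The second instance has no gold answer at this step, so its target is padded with -1.
gold_starts = torch.tensor([3, -1])
loss = F.nll_loss(log_probs, gold_starts, ignore_index=-1)
# Only the first instance contributes to the mean; the padded -1 target is skipped.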
Example #22
    def forward_ensemble(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information = False) -> Dict[str, torch.Tensor]:
        """
        Run one forward pass at the posterior mean, then draw 10 additional stochastic samples and merge their span probabilities.
        """
        self.set_posterior_mean(True)
        most_likely_output = self.forward(question, passage, span_start, span_end,
                                          metadata, get_sample_level_information)
        self.set_posterior_mean(False)

        subresults = [most_likely_output]
        for i in range(10):
            subresults.append(self.forward(question, passage, span_start, span_end,
                                           metadata, get_sample_level_information))

        batch_size = len(subresults[0]["best_span"])

        best_span = bidut.merge_span_probs(subresults)
        
        output = {
                "best_span": best_span,
                "best_span_str": [],
                "models_output": subresults
        }
        if (get_sample_level_information):
            output["em_samples"] = []
            output["f1_samples"] = []
                
        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if (get_sample_level_information):
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string,answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)
                        
        if get_sample_level_information:
            # Add information about the individual samples for future analysis.
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            # Average the start/end distributions over all sampled forward passes once,
            # and take the log so that ``nll_loss`` receives log-probabilities.
            span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults)
            span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults)
            span_start_log_probs = span_start_probs.log()
            span_end_log_probs = span_end_probs.log()
            for i in range(batch_size):
                span_start_loss = nll_loss(span_start_log_probs[[i], :], span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(span_end_log_probs[[i], :], span_end.squeeze(-1)[[i]])
                output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        return output
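``bidut.merge_span_probs`` is external to this listing. A minimal sketch of one way it could behave, assuming it simply averages the sampled start/end distributions and then picks the highest-scoring pair with start <= end (the real helper may differ):

import torch

def merge_span_probs(subresults):
    # Average the per-sample start/end probability distributions over all forward passes.
    start_probs = torch.stack([r["span_start_probs"] for r in subresults]).mean(dim=0)
    end_probs = torch.stack([r["span_end_probs"] for r in subresults]).mean(dim=0)
    # Joint score for every (start, end) pair; triu() zeroes out pairs with start > end.
    joint = start_probs.unsqueeze(2) * end_probs.unsqueeze(1)   # (batch, length, length)
    joint = joint.triu()
    length = joint.size(1)
    flat_index = joint.view(joint.size(0), -1).argmax(dim=1)    # (batch,)
    return torch.stack([flat_index // length, flat_index % length], dim=1)  # (batch, 2)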
示例#23
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.  The ending position is `exclusive`, so our
            :class:`~allennlp.data.dataset_readers.SquadReader` adds a special ending token to the
            end of the passage, to allow for the last token to be included in the answer span.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `exclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span end position (exclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self._get_best_span(span_start_logits, span_end_logits)

        output_dict = {"span_start_logits": span_start_logits,
                       "span_start_probs": span_start_probs,
                       "span_end_logits": span_end_logits,
                       "span_end_probs": span_end_probs,
                       "best_span": best_span}
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
        if metadata is not None and self._official_eval_dataset:
            output_dict['best_span_str'] = []
            for i in range(batch_size):
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                best_span_string = self._compute_official_metrics(metadata[i], predicted_span)  # type: ignore
                output_dict['best_span_str'].append(best_span_string)
        return output_dict
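The constrained decoding that the docstring refers to (``self._get_best_span``) is not reproduced here; a rough standalone sketch of the usual linear-time search over pairs with start <= end:

import torch

def get_best_span(span_start_logits, span_end_logits):
    # For every batch element, maximise start_logit[start] + end_logit[end] over start <= end
    # in a single left-to-right pass, carrying the best start seen so far.
    batch_size, passage_length = span_start_logits.size()
    best_spans = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)
    for b in range(batch_size):
        best_start = 0
        best_score = float("-inf")
        for end in range(passage_length):
            if span_start_logits[b, end] > span_start_logits[b, best_start]:
                best_start = end
            score = span_start_logits[b, best_start] + span_end_logits[b, end]
            if score > best_score:
                best_score = score
                best_spans[b, 0], best_spans[b, 1] = best_start, end
    return best_spans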
示例#24
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            for_training: bool = False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))

        ###################################### selection ######################################
        # Per-token selection probabilities over the passage.
        pbx = sigmoid(self.linear(embedded_passage).squeeze(2))

        # pby = self.we_selector(embedded_y1)
        # assert pbx.size() == passage['tokens'].size()

        # Torch LongTensor of size (batch x len): 1 where a token was sampled as selected.
        selection_x = pbx.bernoulli().long()  # alternatively: (pbx >= threshold).long()
        # selection_y = pby.bernoulli().long()  # (pby >= threshold).long()

        # Word ids that are selected; zeros where a token is not selected
        # (the selected ids can be recovered with selected_x[selected_x != 0]).
        result_x = passage['tokens'].mul(selection_x)
        char_result_x = passage['token_characters'] * selection_x.unsqueeze(
            2).repeat(1, 1, passage['token_characters'].size()[2])
        # result_y = sentence2.mul(selection_y)

        selected_x, char_selected_x = helper.get_selected_tensor(
            result_x, char_result_x, pbx, passage['tokens'],
            passage['token_characters'], self.cuda_device)
        # selected_y, sentence2_len = helper.get_selected_tensor(result_y, pby, sentence2, sentence2_len_old, self.config.cuda) #sentence2_len is a numpy array

        logpz = zsum = zdiff = -1.0
        if for_training:
            mask1 = (
                passage['tokens'] !=
                self._vocab.get_token_index(DEFAULT_PADDING_TOKEN)).long()
            #     mask2 = (sentence2!=0).long()

            masked_selection_x = selection_x.mul(mask1)
            #     masked_selection_y =  selection_y.mul(mask2)

            # logpz (batch x len): log-probability of the sampled selection under pbx.
            # ``reduce=False`` is unavailable in this torch version, so the helper re-implements it.
            logpx = -helper.binary_cross_entropy(
                pbx, selection_x.float().detach(), reduce=False)
            # logpy = -helper.binary_cross_entropy(pby, selection_y.float().detach(), reduce=False)
            assert logpx.size() == passage['tokens'].size()

            # Sum over the sequence dimension (masked positions excluded) -> one value per example.
            logpx = logpx.mul(mask1.float()).sum(1)
            # logpy = logpy.mul(mask2.float()).sum(1)
            logpz = logpx  # (logpx + logpy)

            # zdiff: number of 0/1 transitions in the selection, i.e. sum(|z[t] - z[t-1]|).
            zdiff1 = (masked_selection_x[:, 1:] -
                      masked_selection_x[:, :-1]).abs().sum(1)
            # zdiff2 = (masked_selection_y[:, 1:] - masked_selection_y[:, :-1]).abs().sum(1)

            assert zdiff1.size()[0] == passage['tokens'].size()[0]
            #     assert logpz.size()[0] == sentence1.size()[0]

            zdiff = zdiff1  #+zdiff2

            xsum = masked_selection_x.sum(1)
            #     ysum = masked_selection_y.sum(1)
            zsum = xsum  #+ysum

            assert zsum.size()[0] == passage['tokens'].size()[0]

            assert logpz.dim() == zsum.dim()
            assert logpz.dim() == zdiff.dim()
        #     return selected_x, sentence1_len, selected_y, sentence2_len, logpz, zsum.float(), zdiff.float()

        passage['tokens'] = selected_x
        passage['token_characters'] = char_selected_x

        # print(' passage[tokens]: ', passage['tokens'], ' dim: ', passage['tokens'].dim())
        # print("selected_x: ", selected_x, ' dim: ', selected_x.dim())
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))

        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
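The selector quantities ``logpz``, ``zsum`` and ``zdiff`` computed above are returned in local variables, but their use is not shown in this listing. In rationale-extraction models of this kind (Lei et al., 2016) they typically enter the objective roughly as below; the weights and the helper are hypothetical, not taken from this model:

# Hypothetical regularisation weights, purely for illustration.
SPARSITY_WEIGHT = 1e-3
COHERENCE_WEIGHT = 2e-4

def selector_loss(task_loss_per_example, logpz, zsum, zdiff):
    # Cost of each sampled selection: task loss plus penalties for selecting many tokens
    # (zsum) and for fragmented, non-contiguous selections (zdiff).
    cost = task_loss_per_example + SPARSITY_WEIGHT * zsum + COHERENCE_WEIGHT * zdiff
    # REINFORCE-style surrogate: because ``cost`` is detached, its gradient w.r.t. the
    # selector parameters is E[cost * d log p(z) / d theta].
    selector_term = (cost.detach() * logpz).mean()
    # The answer module itself is still trained on the ordinary task loss.
    return task_loss_per_example.mean() + selector_term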
示例#25
0
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis0: Dict[str, torch.LongTensor],
            hypothesis1: Dict[str, torch.LongTensor],
            hypothesis2: Dict[str, torch.LongTensor] = None,
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis0, hypothesis1, hypothesis2 : Dict[str, torch.LongTensor]
            From ``TextField`` s, one per candidate hypothesis; ``hypothesis2`` may be omitted.
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        hyps = [
            h for h in [hypothesis0, hypothesis1, hypothesis2] if h is not None
        ]
        if isinstance(self._text_field_embedder, ElmoTokenEmbedder):
            self._text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states(
            )

        embedded_premise = self._text_field_embedder(premise)

        embedded_hypotheses = []
        for hypothesis in hyps:
            if isinstance(self._text_field_embedder, ElmoTokenEmbedder):
                self._text_field_embedder._elmo._elmo_lstm._elmo_lstm.reset_states(
                )
            embedded_hypotheses.append(self._text_field_embedder(hypothesis))

        premise_mask = get_text_field_mask(premise).float()
        hypothesis_masks = [
            get_text_field_mask(hypothesis).float() for hypothesis in hyps
        ]
        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypotheses = [
                self.rnn_input_dropout(hyp) for hyp in embedded_hypotheses
            ]

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)

        label_logits = []
        for embedded_hypothesis, hypothesis_mask in zip(embedded_hypotheses,
                                                        hypothesis_masks):
            encoded_hypothesis = self._encoder(embedded_hypothesis,
                                               hypothesis_mask)

            # Shape: (batch_size, premise_length, hypothesis_length)
            similarity_matrix = self._matrix_attention(encoded_premise,
                                                       encoded_hypothesis)

            # Shape: (batch_size, premise_length, hypothesis_length)
            p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
            # Shape: (batch_size, premise_length, embedding_dim)
            attended_hypothesis = weighted_sum(encoded_hypothesis,
                                               p2h_attention)

            # Shape: (batch_size, hypothesis_length, premise_length)
            h2p_attention = masked_softmax(
                similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
            # Shape: (batch_size, hypothesis_length, embedding_dim)
            attended_premise = weighted_sum(encoded_premise, h2p_attention)

            # the "enhancement" layer
            premise_enhanced = torch.cat([
                encoded_premise, attended_hypothesis, encoded_premise -
                attended_hypothesis, encoded_premise * attended_hypothesis
            ],
                                         dim=-1)
            hypothesis_enhanced = torch.cat([
                encoded_hypothesis, attended_premise, encoded_hypothesis -
                attended_premise, encoded_hypothesis * attended_premise
            ],
                                            dim=-1)

            # Pipeline: embedding -> LSTM (with dropout) -> enhanced attention -> projection dropout
            # (only with ELMo) -> feed-forward projection -> LSTM (with dropout) -> dropout
            # -> feed-forward (300) -> dropout -> output

            # add dropout here with ELMO

            # the projection layer down to the model dimension
            # no dropout in projection
            projected_enhanced_premise = self._projection_feedforward(
                premise_enhanced)
            projected_enhanced_hypothesis = self._projection_feedforward(
                hypothesis_enhanced)

            # Run the inference layer
            if self.rnn_input_dropout:
                projected_enhanced_premise = self.rnn_input_dropout(
                    projected_enhanced_premise)
                projected_enhanced_hypothesis = self.rnn_input_dropout(
                    projected_enhanced_hypothesis)
            v_ai = self._inference_encoder(projected_enhanced_premise,
                                           premise_mask)
            v_bi = self._inference_encoder(projected_enhanced_hypothesis,
                                           hypothesis_mask)

            # The pooling layer -- max and avg pooling.
            # (batch_size, model_dim)
            v_a_max, _ = replace_masked_values(v_ai,
                                               premise_mask.unsqueeze(-1),
                                               -1e7).max(dim=1)
            v_b_max, _ = replace_masked_values(v_bi,
                                               hypothesis_mask.unsqueeze(-1),
                                               -1e7).max(dim=1)

            v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1),
                                dim=1) / torch.sum(
                                    premise_mask, 1, keepdim=True)
            v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1),
                                dim=1) / torch.sum(
                                    hypothesis_mask, 1, keepdim=True)

            # Now concat
            # (batch_size, model_dim * 2 * 4)
            v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

            # the final MLP -- apply dropout to input, and MLP applies to output & hidden
            if self.dropout:
                v = self.dropout(v)

            output_hidden = self._output_feedforward(v)
            logit = self._output_logit(output_hidden)
            assert logit.size(-1) == 1
            label_logits.append(logit)

        label_logits = torch.cat(label_logits, -1)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
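A tiny, self-contained illustration (made-up numbers) of the masked max/avg pooling used just before the final MLP above: padded positions are pushed to a large negative value before the max, and the average divides only by the number of real tokens.

import torch

v = torch.tensor([[[1.0, 4.0], [3.0, 2.0], [9.0, 9.0]]])   # (batch=1, length=3, dim=2)
mask = torch.tensor([[1.0, 1.0, 0.0]])                      # the last position is padding

# Rough equivalent of replace_masked_values(v, mask.unsqueeze(-1), -1e7) followed by max.
masked_v = v * mask.unsqueeze(-1) + (1 - mask.unsqueeze(-1)) * -1e7
v_max, _ = masked_v.max(dim=1)                               # tensor([[3., 4.]])

# Masked average: sum over real tokens divided by their count.
v_avg = (v * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)  # tensor([[2., 3.]])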
示例#26
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            sentence_spans: torch.IntTensor = None,
            sent_labels: torch.IntTensor = None,
            evd_chain_labels: torch.IntTensor = None,
            q_type: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        if self._sent_labels_src == 'chain':
            batch_size, num_spans = sent_labels.size()
            sent_labels_mask = (sent_labels >= 0).float()
            print("chain:", evd_chain_labels)
            # we use the chain as the label to supervise the gate
            # In this model, we only take the first chain in ``evd_chain_labels`` for supervision,
            # right now the number of chains should only be one too.
            evd_chain_labels = evd_chain_labels[:, 0].long()
            # build the gate labels. The dim is set to 1 + num_spans to account for the end embedding
            # shape: (batch_size, 1+num_spans)
            sent_labels = sent_labels.new_zeros((batch_size, 1 + num_spans))
            sent_labels.scatter_(1, evd_chain_labels, 1.)
            # remove the column for end embedding
            # shape: (batch_size, num_spans)
            sent_labels = sent_labels[:, 1:].float()
            # make the padding be -1
            sent_labels = sent_labels * sent_labels_mask + -1. * (
                1 - sent_labels_mask)

        # word + char embedding
        embedded_question = self._text_field_embedder(question)
        embedded_passage = self._text_field_embedder(passage)
        # mask
        ques_mask = util.get_text_field_mask(question).float()
        context_mask = util.get_text_field_mask(passage).float()

        # BiDAF for answer prediction
        ques_output = self._dropout(
            self._phrase_layer(embedded_question, ques_mask))
        context_output = self._dropout(
            self._phrase_layer(embedded_passage, context_mask))

        modeled_passage, _, qc_score = self.qc_att(context_output, ques_output,
                                                   ques_mask)

        modeled_passage = self._modeling_layer(modeled_passage, context_mask)

        # BiDAF for gate prediction
        ques_output_sp = self._dropout(
            self._phrase_layer_sp(embedded_question, ques_mask))
        context_output_sp = self._dropout(
            self._phrase_layer_sp(embedded_passage, context_mask))

        modeled_passage_sp, _, qc_score_sp = self.qc_att_sp(
            context_output_sp, ques_output_sp, ques_mask)

        modeled_passage_sp = self._modeling_layer_sp(modeled_passage_sp,
                                                     context_mask)

        # gate prediction
        # Shape(spans_rep): (batch_size * num_spans, max_batch_span_width, embedding_dim)
        # Shape(spans_mask): (batch_size, num_spans, max_batch_span_width)
        spans_rep_sp, spans_mask = convert_sequence_to_spans(
            modeled_passage_sp, sentence_spans)
        spans_rep, _ = convert_sequence_to_spans(modeled_passage,
                                                 sentence_spans)
        # Shape(gate_logit): (batch_size * num_spans, 2)
        # Shape(gate): (batch_size * num_spans, 1)
        # Shape(pred_sent_probs): (batch_size * num_spans, 2)
        # Shape(gate_mask): (batch_size, num_spans)
        #gate_logit, gate, pred_sent_probs = self._span_gate(spans_rep_sp, spans_mask)
        gate_logit, gate, pred_sent_probs, gate_mask, g_att_score = self._span_gate(
            spans_rep_sp, spans_mask, self._gate_self_attention_layer,
            self._gate_sent_encoder)
        batch_size, num_spans, max_batch_span_width = spans_mask.size()

        strong_sup_loss = F.nll_loss(
            F.log_softmax(gate_logit, dim=-1).view(batch_size * num_spans, -1),
            sent_labels.long().view(batch_size * num_spans),
            ignore_index=-1)

        gate = (gate >= 0.3).long()
        spans_rep = spans_rep * gate.unsqueeze(-1).float()
        attended_sent_embeddings = convert_span_to_sequence(
            modeled_passage_sp, spans_rep, spans_mask)

        modeled_passage = attended_sent_embeddings + modeled_passage

        self_att_passage = self._self_attention_layer(modeled_passage,
                                                      mask=context_mask)
        modeled_passage = modeled_passage + self_att_passage[0]
        self_att_score = self_att_passage[2]

        output_start = self._span_start_encoder(modeled_passage, context_mask)
        span_start_logits = self.linear_start(output_start).squeeze(
            2) - 1e30 * (1 - context_mask)
        output_end = torch.cat([modeled_passage, output_start], dim=2)
        output_end = self._span_end_encoder(output_end, context_mask)
        span_end_logits = self.linear_end(output_end).squeeze(
            2) - 1e30 * (1 - context_mask)

        output_type = torch.cat([modeled_passage, output_end, output_start],
                                dim=2)
        output_type = torch.max(output_type, 1)[0]
        # output_type = torch.max(self.rnn_type(output_type, context_mask), 1)[0]
        predict_type = self.linear_type(output_type)
        type_predicts = torch.argmax(predict_type, 1)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span": best_span,
            "pred_sent_labels": gate.view(batch_size,
                                          num_spans),  #[B, num_span]
            "gate_probs":
            pred_sent_probs[:, 1].view(batch_size, num_spans),  #[B, num_span]
        }
        if self._output_att_scores:
            if qc_score is not None:
                output_dict['qc_score'] = qc_score
            if qc_score_sp is not None:
                output_dict['qc_score_sp'] = qc_score_sp
            if self_att_score is not None:
                output_dict['self_attention_score'] = self_att_score
            if g_att_score is not None:
                output_dict['evd_self_attention_score'] = g_att_score

        print("sent label:")
        for b_label in np.array(sent_labels.cpu()):
            b_label = b_label == 1
            indices = np.arange(len(b_label))
            print(indices[b_label] + 1)
        # Compute the loss for training.
        if span_start is not None:
            try:
                start_loss = nll_loss(
                    util.masked_log_softmax(span_start_logits, None),
                    span_start.squeeze(-1))
                end_loss = nll_loss(
                    util.masked_log_softmax(span_end_logits, None),
                    span_end.squeeze(-1))
                type_loss = nll_loss(
                    util.masked_log_softmax(predict_type, None), q_type)
                loss = start_loss + end_loss + type_loss + strong_sup_loss
                self._loss_trackers['loss'](loss)
                self._loss_trackers['start_loss'](start_loss)
                self._loss_trackers['end_loss'](end_loss)
                self._loss_trackers['type_loss'](type_loss)
                self._loss_trackers['strong_sup_loss'](strong_sup_loss)
                output_dict["loss"] = loss
            except RuntimeError:
                print('\n meta_data:', metadata)
                print(span_start_logits.shape)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['answer_texts'] = []
            question_tokens = []
            passage_tokens = []
            token_spans_sp = []
            token_spans_sent = []
            sent_labels_list = []
            evd_possible_chains = []
            ans_sent_idxs = []
            ids = []
            count_yes = 0
            count_no = 0
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                token_spans_sp.append(metadata[i]['token_spans_sp'])
                token_spans_sent.append(metadata[i]['token_spans_sent'])
                sent_labels_list.append(metadata[i]['sent_labels'])
                ids.append(metadata[i]['_id'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                if type_predicts[i] == 1:
                    best_span_string = 'yes'
                    count_yes += 1
                elif type_predicts[i] == 2:
                    best_span_string = 'no'
                    count_no += 1
                else:
                    predicted_span = tuple(best_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    best_span_string = passage_str[start_offset:end_offset]

                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                output_dict['answer_texts'].append(answer_texts)

                if answer_texts:
                    self._squad_metrics(best_span_string.lower(), answer_texts)

                # shift sentence indices back
                evd_possible_chains.append([
                    s_idx - 1
                    for s_idx in metadata[i]['evd_possible_chains'][0]
                    if s_idx > 0
                ])
                ans_sent_idxs.append(
                    [s_idx - 1 for s_idx in metadata[i]['ans_sent_idxs']])
            self._f1_metrics(pred_sent_probs, sent_labels.view(-1),
                             gate_mask.view(-1))
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
            output_dict['token_spans_sp'] = token_spans_sp
            output_dict['token_spans_sent'] = token_spans_sent
            output_dict['sent_labels'] = sent_labels_list
            output_dict['evd_possible_chains'] = evd_possible_chains
            output_dict['ans_sent_idxs'] = ans_sent_idxs
            output_dict['_id'] = ids

        return output_dict
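A small concrete walk-through (made-up tensors) of the gate-label construction at the top of this example: the first evidence chain holds 1-based sentence indices (slot 0 stands for the end embedding), ``scatter_`` turns it into a multi-hot vector, the end column is dropped, and padded sentence positions are reset to -1.

import torch

sent_labels = torch.tensor([[0, 1, 0, -1]])        # (batch=1, num_spans=4); -1 marks padding
evd_chain_labels = torch.tensor([[2, 3, 0, 0]])    # first chain: sentences 2 and 3 (1-based)

sent_labels_mask = (sent_labels >= 0).float()
gate_labels = sent_labels.new_zeros((1, 1 + 4))    # extra slot 0 for the end embedding
gate_labels.scatter_(1, evd_chain_labels, 1)
gate_labels = gate_labels[:, 1:].float()           # drop the end column -> [[0., 1., 1., 0.]]
gate_labels = gate_labels * sent_labels_mask + -1. * (1 - sent_labels_mask)
print(gate_labels)                                  # tensor([[ 0.,  1.,  1., -1.]])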
示例#27
0
    def forward(self,
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                passages_length: torch.LongTensor = None,
                correct_passage: torch.LongTensor = None,
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None) -> Dict[str, torch.Tensor]:

        # shape: B x Tq x E
        embedded_question = self._embedder(question)
        embedded_passage = self._embedder(passage)

        batch_size = embedded_question.size(0)
        total_passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        # shape: B x T x 2H
        encoded_question = self._dropout(
            self._question_encoder(embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._passage_encoder(embedded_passage, passage_mask))
        passage_mask = passage_mask.float()
        question_mask = question_mask.float()

        encoding_dim = encoded_question.size(-1)

        # shape: B x 2H
        if encoded_passage.is_cuda:
            cuda_device = encoded_passage.get_device()
            gru_hidden = Variable(
                torch.zeros(batch_size, encoding_dim).cuda(cuda_device))
        else:
            gru_hidden = Variable(torch.zeros(batch_size, encoding_dim))

        question_awared_passage = []
        for timestep in range(total_passage_length):
            # shape: B x Tq = attention(B x 2H, B x Tq x 2H)
            attn_weights = self._question_attention_for_passage(
                encoded_passage[:, timestep, :], encoded_question,
                question_mask)
            # shape: B x 2H = weighted_sum(B x Tq x 2H, B x Tq)
            attended_question = util.weighted_sum(encoded_question,
                                                  attn_weights)
            # shape: B x 4H
            passage_question_combined = torch.cat(
                [encoded_passage[:, timestep, :], attended_question], dim=-1)
            # shape: B x 4H
            gate = F.sigmoid(self._gate(passage_question_combined))
            gru_input = gate * passage_question_combined
            # shape: B x 2H
            gru_hidden = self._dropout(self._gru_cell(gru_input, gru_hidden))
            question_awared_passage.append(gru_hidden)

        # shape: B x T x 2H
        # question aware passage representation v_P
        question_awared_passage = torch.stack(question_awared_passage, dim=1)

        # compute question vector r_Q
        # shape: B x T = attention(B x 2H, B x T x 2H)
        v_r_Q_tiled = self._v_r_Q.unsqueeze(0).expand(batch_size, encoding_dim)
        attn_weights = self._question_attention_for_question(
            v_r_Q_tiled, encoded_question, question_mask)
        # shape: B x 2H
        r_Q = util.weighted_sum(encoded_question, attn_weights)
        # shape: B x T = attention(B x 2H, B x T x 2H)
        span_start_logits = self._passage_attention_for_answer(
            r_Q, question_awared_passage, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_start_log_probs = util.masked_log_softmax(span_start_logits,
                                                       passage_mask)
        # shape: B x 2H
        c_t = util.weighted_sum(question_awared_passage, span_start_probs)
        # shape: B x 2H
        h_1 = self._dropout(self._answer_net(c_t, r_Q))

        span_end_logits = self._passage_attention_for_answer(
            h_1, question_awared_passage, passage_mask)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_end_log_probs = util.masked_log_softmax(span_end_logits,
                                                     passage_mask)

        num_passages = passages_length.size(1)
        # NOTE: ``cuda_device`` is only defined in the CUDA branch above, so the ranking
        # code below assumes the model is running on a GPU.
        acc = Variable(torch.zeros(batch_size,
                                   num_passages + 1)).cuda(cuda_device).long()

        acc[:, 1:num_passages + 1] = torch.cumsum(passages_length, dim=1)

        g_batch = []
        for b in range(batch_size):
            g = []
            for i in range(num_passages):
                if acc[b, i + 1].data[0] > acc[b, i].data[0]:
                    attn_weights = self._passage_attention_for_ranking(
                        r_Q[b:b + 1], question_awared_passage[b:b + 1, acc[
                            b, i].data[0]:acc[b, i + 1].data[0], :],
                        passage_mask[b:b + 1,
                                     acc[b, i].data[0]:acc[b, i + 1].data[0]])
                    r_P = util.weighted_sum(
                        question_awared_passage[b:b + 1, acc[b, i].data[0]:acc[
                            b, i + 1].data[0], :], attn_weights)
                    question_passage_combined = torch.cat([r_Q[b:b + 1], r_P],
                                                          dim=-1)
                    gi = self._dropout(
                        self._match_layer_2(
                            F.tanh(
                                self._dropout(
                                    self._match_layer_1(
                                        question_passage_combined)))))
                    g.append(gi)
                else:
                    g.append(Variable(torch.zeros(1, 1)).cuda(cuda_device))
            g = torch.cat(g, dim=1)
            g_batch.append(g)

        #t2 = time.time()
        g = torch.cat(g_batch, dim=0)
        passage_log_probs = F.log_softmax(g, dim=-1)

        output_dict = {}
        if span_start is not None:
            AP_loss = F.nll_loss(span_start_log_probs, span_start.squeeze(-1)) +\
                F.nll_loss(span_end_log_probs, span_end.squeeze(-1))
            PR_loss = F.nll_loss(passage_log_probs,
                                 correct_passage.squeeze(-1))
            loss = self._r * AP_loss + self._r * PR_loss
            output_dict['loss'] = loss

        _, max_start = torch.max(span_start_probs, dim=1)
        _, max_end = torch.max(span_end_probs, dim=1)
        #t3 = time.time()
        output_dict['span_start_idx'] = max_start
        output_dict['span_end_idx'] = max_end
        #t4 = time.time()
        global ITE
        ITE += 1
        if (ITE % 100 == 0):
            print(" gold %i:%i|predicted %i:%i" %
                  (span_start.squeeze(-1)[0], span_end.squeeze(-1)[0],
                   max_start.data[0], max_end.data[0]))

        #t5 = time.time()
        #print("Total: %.5f" % (t5-t0))
        #print("Batch processing 1: %.5f" % (t2-t1))
        #print("Batch processing 2: %.5f" % (t4-t3))
        return output_dict
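A concrete illustration (made-up lengths) of the cumulative-offset bookkeeping used above for passage ranking: ``acc`` holds, for each batch element, the token boundaries of every candidate passage inside the concatenated sequence, and empty passages are skipped.

import torch

passages_length = torch.tensor([[5, 0, 3]])         # (batch=1, num_passages=3)
num_passages = passages_length.size(1)
acc = torch.zeros(1, num_passages + 1, dtype=torch.long)
acc[:, 1:] = torch.cumsum(passages_length, dim=1)   # -> tensor([[0, 5, 5, 8]])
for i in range(num_passages):
    start, end = int(acc[0, i]), int(acc[0, i + 1])
    if end > start:                                  # skip empty passages (the second one here)
        print("passage %d covers tokens [%d, %d)" % (i, start, end))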
示例#28
0
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        choices: Dict[str, torch.LongTensor],
        evidence: Dict[str, torch.LongTensor],
        answer_index: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        choices : Dict[str, torch.LongTensor]
            From a ``ListField`` of ``TextField`` s, one entry per answer option.
        evidence : Dict[str, torch.LongTensor]
            From a nested ``ListField`` holding the retrieved evidence for each option.
        answer_index : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is what we are trying to predict.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, question and choices for each instance
            in the batch. The length of this list should be the batch size, and each dictionary
            should have the keys ``qid``, ``question``, ``choices``, ``question_tokens`` and
            ``choices_tokens``.

        Returns
        -------
        An output dictionary consisting of the followings.

        qid : List[str]
            A list consisting of question ids.
        answer_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_options=5)`` representing unnormalised log
            probabilities of the choices.
        answer_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_options=5)`` representing probabilities of the
            choices.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        # batch, seq_len -> batch, seq_len, emb
        question_hidden = self._bert(question)
        batch_size, emb_size = question_hidden.size(0), question_hidden.size(2)
        question_hidden = question_hidden[..., 0, :]  # batch, emb

        # batch, 5, seq_len -> batch, 5, seq_len, emb
        choice_hidden = self._bert(choices, num_wrapping_dims=1)
        choice_hidden = choice_hidden[..., 0, :]  # batch, 5, emb

        if 'first' in self.model_type:
            # batch, 5, evi_num, seq_len -> batch, 5, evi_num, seq_len, emb
            evidence_hidden = self._bert(evidence, num_wrapping_dims=2)
            # evi_num = evidence_hidden.size(2)
            # batch, 5, evi_num, emb
            evidence_hidden = evidence_hidden[..., 0, :]

        if self.dropout:
            question_hidden = self.dropout(question_hidden)
            choice_hidden = self.dropout(choice_hidden)
            if 'first' in self.model_type:
                evidence_hidden = self.dropout(evidence_hidden)

        if 'first' in self.model_type:
            evidence_summary = evidence_hidden[..., 0, :]

        question_hidden = question_hidden.unsqueeze(1).expand(
            batch_size, 5, emb_size)

        cls_hidden = torch.cat([question_hidden, choice_hidden], dim=-1)
        if 'first' in self.model_type:
            cls_hidden = torch.cat([cls_hidden, evidence_summary], dim=-1)

        # the final MLP -- apply dropout to input, and MLP applies to hidden
        answer_logits = self._classifier(cls_hidden).squeeze(-1)
        answer_probs = torch.nn.functional.softmax(answer_logits, dim=-1)
        qids = [m['qid'] for m in metadata]
        output_dict = {
            "answer_logits": answer_logits,
            "answer_probs": answer_probs,
            "qid": qids
        }

        if answer_index is not None:
            answer_index = answer_index.squeeze(-1)
            loss = self._loss(answer_logits, answer_index)
            self._accuracy(answer_logits, answer_index)
            output_dict["loss"] = loss
        return output_dict
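A quick illustration (random tensors) of the ``[..., 0, :]`` indexing used above: whatever the number of wrapping dimensions in the BERT output, it keeps the hidden state of the first ([CLS]) token of each sequence.

import torch

question_hidden = torch.randn(2, 7, 768)               # (batch, seq_len, emb)
choice_hidden = torch.randn(2, 5, 9, 768)               # (batch, 5 choices, seq_len, emb)
print(question_hidden[..., 0, :].shape)                 # torch.Size([2, 768])
print(choice_hidden[..., 0, :].shape)                   # torch.Size([2, 5, 768])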
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
		Parameters
		----------
		question : Dict[str, torch.LongTensor]
			From a ``TextField``.
		passage : Dict[str, torch.LongTensor]
			From a ``TextField``.  The model assumes that this passage contains the answer to the
			question, and predicts the beginning and ending positions of the answer within the
			passage.
		span_start : ``torch.IntTensor``, optional
			From an ``IndexField``.  This is one of the things we are trying to predict - the
			beginning position of the answer with the passage.  This is an `inclusive` token index.
			If this is given, we will compute a loss that gets included in the output dictionary.
		span_end : ``torch.IntTensor``, optional
			From an ``IndexField``.  This is one of the things we are trying to predict - the
			ending position of the answer with the passage.  This is an `inclusive` token index.
			If this is given, we will compute a loss that gets included in the output dictionary.
		metadata : ``List[Dict[str, Any]]``, optional
			If present, this should contain the question ID, original passage text, and token
			offsets into the passage for each instance in the batch.  We use this for computing
			official metrics using the official SQuAD evaluation script.  The length of this list
			should be the batch size, and each dictionary should have the keys ``id``,
			``original_passage``, and ``token_offsets``.  If you only want the best span string and
			don't care about official metrics, you can omit the ``id`` key.

		Returns
		-------
		An output dictionary consisting of:
		span_start_logits : torch.FloatTensor
			A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
			probabilities of the span start position.
		span_start_probs : torch.FloatTensor
			The result of ``softmax(span_start_logits)``.
		span_end_logits : torch.FloatTensor
			A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
			probabilities of the span end position (inclusive).
		span_end_probs : torch.FloatTensor
			The result of ``softmax(span_end_logits)``.
		best_span : torch.IntTensor
			The result of a constrained inference over ``span_start_logits`` and
			``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
			and each offset is a token index.
		loss : torch.FloatTensor, optional
			A scalar loss to be optimised.
		best_span_str : List[str]
			If sufficient metadata was provided for the instances in the batch, we also return the
			string from the original passage that the model thinks is the best answer to the
			question.
		"""
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask, dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_similarity = torch.transpose(
            passage_question_similarity, 1, 2)
        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask, dim=-1)
        # Shape: (batch_size, question_length, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)

        passage_gate = torch.unsqueeze(
            self._passage_similarity_function(encoded_passage,
                                              passage_question_vectors), -1)
        passage_fusion = self._passage_fusion_function(
            encoded_passage, passage_question_vectors)
        gated_passage = passage_gate * passage_fusion + (
            1 - passage_gate) * encoded_passage

        question_gate = torch.unsqueeze(
            self._question_similarity_function(encoded_question,
                                               question_passage_vector), -1)
        question_fusion = self._question_fusion_function(
            encoded_question, question_passage_vector)
        gated_question = question_gate * question_fusion + (
            1 - question_gate) * encoded_question

        passage_passage_similarity = self._self_matrix_attention(
            gated_passage, gated_passage)
        passage_passage_attention = util.masked_softmax(
            passage_passage_similarity, passage_mask, dim=-1)
        passage_passage_vector = util.weighted_sum(gated_passage,
                                                   passage_passage_attention)
        final_passage = self._fusion_function(gated_passage,
                                              passage_passage_vector)

        modeled_passage = self._dropout(
            self._passage_modeling_layer(final_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        span_logits = self._span_predictor(modeled_passage)

        modeled_question = self._question_modeling_layer(
            gated_question, question_lstm_mask)
        question_vector = self._question_encoding_layer(
            modeled_question, question_lstm_mask).unsqueeze(-1)
        span_start_logits = torch.bmm(self._span_start_weight(modeled_passage),
                                      question_vector).squeeze(-1)
        span_end_logits = torch.bmm(self._span_end_weight(modeled_passage),
                                    question_vector).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            device_id = util.get_device_of(span_start)
            weight = self._span_weight.cuda(
                device_id) if device_id >= 0 else self._span_weight
            arange_mask = util.get_range_vector(passage_length,
                                                util.get_device_of(span_start))
            span_mask = (arange_mask >= span_start) & (arange_mask <= span_end)
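            # span_mask marks tokens inside the gold answer span; span_loss below
            # apparently supervises a per-token inside/outside classifier
            # (span_logits) in addition to the boundary losses.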
            span_loss = nll_loss(self._masked_log_softmax(
                span_logits, passage_mask).transpose(1, 2),
                                 span_mask.long(),
                                 weight=weight)
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss + span_loss / 2

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
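    # --- Hedged sketch (not part of the original snippet) ----------------------
    # The docstring above describes ``best_span`` as a constrained inference over
    # ``span_start_logits`` and ``span_end_logits``.  A minimal sketch of that
    # decoding, assuming the usual constraint that the end index must not precede
    # the start index; the helper name and implementation below are illustrative,
    # not the library's own ``get_best_span``.
    @staticmethod
    def get_best_span_sketch(span_start_logits: torch.Tensor,
                             span_end_logits: torch.Tensor) -> torch.Tensor:
        batch_size, passage_length = span_start_logits.size()
        # Joint score for every (start, end) pair: (batch_size, passage_length, passage_length).
        span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
        # Keep only pairs with end >= start by masking out the lower triangle.
        valid = torch.triu(torch.ones(passage_length, passage_length,
                                      device=span_scores.device)).bool()
        span_scores = span_scores.masked_fill(~valid, float('-inf'))
        # Argmax over the flattened (start, end) grid, then recover the two indices.
        best_flat = span_scores.view(batch_size, -1).argmax(dim=-1)
        return torch.stack([best_flat // passage_length,
                            best_flat % passage_length], dim=-1)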
Example #30
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """

        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        question_mask_temp = question_mask
        question_mask_temp = question_mask_temp.unsqueeze_(1)
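        # Note: ``unsqueeze_`` is in-place, so ``question_mask`` itself now has shape
        # (batch_size, 1, question_length); the broadcasting in the masking calls
        # below relies on this.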

        # New question self-attention (SA) encoding
        device0 = torch.device('cuda:0')
        device1 = torch.device('cuda:1')
        device2 = torch.device('cuda:2')
        device3 = torch.device('cuda:3')
        sa_question_sim = self._dropout(
            self._sa_matrix_attention(device1, embedded_question, None, True))
        sa_question_att = util.masked_softmax(sa_question_sim.to(device1),
                                              question_mask_temp.to(device1))
        sa_encoded_question = util.weighted_sum(encoded_question.to(device1),
                                                sa_question_att.to(device1))

        sa_encoded_question = sa_encoded_question.to(device0)

        sa_passage_sim = self._dropout(
            self._sa_matrix_attention(device2, embedded_passage, None, True))
        sa_passage_att = util.masked_softmax(
            sa_passage_sim.to(device1),
            passage_mask.clone().unsqueeze_(1).to(device1))
        sa_encoded_passage = util.weighted_sum(encoded_passage.to(device1),
                                               sa_passage_att.to(device1))
        sa_encoded_passage = sa_encoded_passage.to(device0)
        #sa_encoded_passage = encoded_passage

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        sa_passage_question_similarity = self._l_matrix_attention(
            device1, sa_encoded_passage, sa_encoded_question, False)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask_temp)
        sa_passage_question_attention = util.masked_softmax(
            sa_passage_question_similarity.to(device0), question_mask_temp)

        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)
        sa_passage_question_vectors = util.weighted_sum(
            sa_encoded_question.to(device0), sa_passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask, -1e7)
        sa_masked_similarity = util.replace_masked_values(
            sa_passage_question_similarity.to(device0), question_mask, -1e7)

        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        sa_question_passage_similarity = sa_masked_similarity.max(
            dim=-1)[0].squeeze(-1)

        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        sa_question_passage_attention = util.masked_softmax(
            sa_question_passage_similarity, passage_mask)

        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        sa_question_passage_vector = util.weighted_sum(
            sa_encoded_passage, sa_question_passage_attention)

        # Shape: (batch_size, passage_length, encoding_dim)
        #print("Shape:",question_passage_vector.size(),question_passage_vector.unsqueeze(1).size())
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)
        #print("Shape:",sa_question_passage_vector.size(),sa_question_passage_vector.unsqueeze(1).size())
        sa_tiled_question_passage_vector = sa_question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        #print("Shape of SA Encoded:",sa_encoded_question.size(),sa_encoded_passage.size())
        #print("Required Shape of Encoded Passage:",encoded_passage.size(),passage_question_vectors.size())

        #sa_passage_question_vectors = passage_question_vectors
        #sa_tiled_question_passage_vector = tiled_question_passage_vector

        # Shape: (batch_size, passage_length, encoding_dim * 4 + 4*sa_dim )
        final_merged_passage = torch.cat([
            encoded_passage, sa_encoded_passage, passage_question_vectors,
            sa_passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector,
            sa_encoded_passage * sa_passage_question_vectors,
            sa_encoded_passage * sa_tiled_question_passage_vector
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim + 2*selfattention_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)

        # Pad the start logits to a fixed length of 1000 before the no-answer predictor.
        span_start_logits_pad = (0, 1000 - passage_length, 0, 0)
        span_start_logits_w_na = self._span_start_predictor_w_na(
            pad(span_start_logits, span_start_logits_pad)).squeeze(-1)

        # Shape: (batch_size, passage_length + 1)
        span_start_logits_w_na = span_start_logits_w_na[:, :passage_length + 1]

        span_start_na_logits = span_start_logits_w_na[:, 0]
        span_start_logits = span_start_logits_w_na[:, 1:]
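        # Index 0 of the *_w_na logits is a synthetic "no answer" slot; the remaining
        # positions map one-to-one onto passage tokens.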

        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3 + 4*sadim)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim +4*sadim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

        # Pad the end logits to a fixed length of 1000 before the no-answer predictor.
        span_end_logits_pad = (0, 1000 - passage_length, 0, 0)
        span_end_logits_w_na = self._span_end_predictor_w_na(
            pad(span_end_logits, span_end_logits_pad)).squeeze(-1)

        # Shape: (batch_size, passage_length + 1)
        span_end_logits_w_na = span_end_logits_w_na[:, :passage_length + 1]

        span_end_na_logits = span_end_logits_w_na[:, 0]
        span_end_logits = span_end_logits_w_na[:, 1:]

        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        na_gt = (span_start == -1).type(torch.cuda.LongTensor)
        na_inv = (1.0 - na_gt)
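        # na_gt is 1 for unanswerable questions (gold span_start == -1); na_inv masks
        # them out of the span accuracy metrics below.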

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
            # "na_logits": na_logits,
            # "na_probs": na_probs
        }

        # Compute the loss for training.
        if span_start is not None:
            y_start = span_start + 1
            y_end = span_end + 1
            passage_mask_w_na = torch.cat([
                torch.ones([batch_size, 1]).type(torch.cuda.FloatTensor),
                passage_mask
            ], -1)
            loss = 0.0

            # calculate loss if there is answer
            # loss for start
            preds_start = util.masked_log_softmax(
                span_start_logits_w_na.type(torch.cuda.FloatTensor),
                passage_mask_w_na.type(torch.cuda.FloatTensor)).type(
                    torch.cuda.FloatTensor)
            y_start = y_start.squeeze(-1).type(torch.cuda.LongTensor)
            loss += nll_loss(preds_start, y_start)

            # accuracy for start
            acc_p_start = na_inv.type(
                torch.cuda.FloatTensor) * span_start_logits.type(
                    torch.cuda.FloatTensor)
            acc_y_start = na_inv.squeeze(-1).type(
                torch.cuda.FloatTensor) * span_start.squeeze(-1).type(
                    torch.cuda.FloatTensor)
            self._span_start_accuracy(acc_p_start, acc_y_start)

            # loss for end
            preds_end = util.masked_log_softmax(
                span_end_logits_w_na.type(torch.cuda.FloatTensor),
                passage_mask_w_na.type(torch.cuda.FloatTensor)).type(
                    torch.cuda.FloatTensor)
            y_end = y_end.squeeze(-1).type(torch.cuda.LongTensor)
            loss += nll_loss(preds_end, y_end)

            # accuracy for end
            acc_p_end = na_inv.type(
                torch.cuda.FloatTensor) * span_end_logits.type(
                    torch.cuda.FloatTensor)
            acc_y_end = na_inv.squeeze(-1).type(
                torch.cuda.FloatTensor) * span_end.squeeze(-1).type(
                    torch.cuda.FloatTensor)
            self._span_end_accuracy(acc_p_end, acc_y_end)

            # accuracy for span
            acc_p = na_inv.type(torch.cuda.FloatTensor) * best_span.type(
                torch.cuda.FloatTensor)
            acc_y = na_inv.type(torch.cuda.FloatTensor) * torch.cat([
                span_start.type(torch.cuda.FloatTensor),
                span_end.type(torch.cuda.FloatTensor)
            ], -1)
            self._span_accuracy(acc_p, acc_y)

            output_dict["loss"] = loss

            preds_start = util.masked_softmax(
                span_start_logits_w_na.type(torch.cuda.FloatTensor),
                passage_mask_w_na.type(torch.cuda.FloatTensor)).type(
                    torch.cuda.FloatTensor)
            preds_end = util.masked_softmax(
                span_end_logits_w_na.type(torch.cuda.FloatTensor),
                passage_mask_w_na.type(torch.cuda.FloatTensor)).type(
                    torch.cuda.FloatTensor)

            output_dict["na_logits"] = preds_start[:, 0] * preds_end[:, 0]
            output_dict["na_probs"] = torch.stack(
                [1.0 - output_dict["na_logits"], output_dict["na_logits"]], -1)
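            # Despite the name, "na_logits" holds the probability of no answer: the mass
            # that both the start and end distributions place on the synthetic index-0
            # slot.  "na_probs" stacks (answerable, unanswerable) for the accuracy metric.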

            # track accuracy for answer existence (no-answer prediction)
            self._na_accuracy(
                output_dict["na_probs"].type(torch.cuda.FloatTensor),
                na_gt.squeeze(-1).type(torch.cuda.FloatTensor))

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
    def forward(
        self,
        tokens_list: Dict[str, torch.LongTensor],
        positions_list: Dict[str, torch.LongTensor],
        sent_positions_list: Dict[str, torch.LongTensor],
        before_loc_start: torch.IntTensor = None,
        before_loc_end: torch.IntTensor = None,
        after_loc_start_list: torch.IntTensor = None,
        after_loc_end_list: torch.IntTensor = None,
        before_category: torch.IntTensor = None,
        after_category_list: torch.IntTensor = None,
        before_category_mask: torch.IntTensor = None,
        after_category_mask_list: torch.IntTensor = None
    ) -> Dict[str, torch.Tensor]:
        """
        :param tokens_list: Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        :param positions_list: same as tokens_list
        :param sent_positions_list: same as tokens_list
        :param before_loc_start: torch.IntTensor = None, required
            An integer ``IndexField`` representation of the before location start
        :param before_loc_end: torch.IntTensor = None, required
            An integer ``IndexField`` representation of the before location end
        :param after_loc_start_list: torch.IntTensor = None, required
            A list of integers ``ListField (IndexField)`` representation of the list of after location starts
            along the sequence of steps
        :param after_loc_end_list: torch.IntTensor = None, required
            A list of integers ``ListField (IndexField)`` representation of the list of after location ends
            along the sequence of steps
        :param before_category: torch.IntTensor = None, required
            An integer ``IndexField`` representation of the before location category
        :param after_category_list: torch.IntTensor = None, required
            A list of integers ``ListField (IndexField)`` representation of the list of after location categories
            along the sequence of steps
        :param before_category_mask: torch.IntTensor = None, required
            An integer ``IndexField`` representation of whether the before location is known or not (0/1)
        :param after_category_mask_list: torch.IntTensor = None, required
            A list of integers ``ListField (IndexField)`` representation of the list of whether after location is
            known or not for each step along the sequence of steps
        :return:
        An output dictionary consisting of:
        best_span: torch.FloatTensor
            A tensor of shape ``()``
        true_span: torch.FloatTensor
        loss: torch.FloatTensor
        """

        # batchsize * listLength * paragraphSize * embeddingSize
        input_embedding_paragraph = self._text_field_embedder(tokens_list)
        input_pos_embedding_paragraph = self._pos_field_embedder(
            positions_list)
        input_sent_pos_embedding_paragraph = self._sent_pos_field_embedder(
            sent_positions_list)
        # batchsize * listLength * paragraphSize * (word + position + sentence-position embedding dims)
        embedding_paragraph = torch.cat([
            input_embedding_paragraph, input_pos_embedding_paragraph,
            input_sent_pos_embedding_paragraph
        ],
                                        dim=-1)

        # batchsize * listLength * paragraphSize,  this mask is shared with the text fields and sequence label fields
        para_mask = util.get_text_field_mask(tokens_list,
                                             num_wrapping_dims=1).float()

        # batchsize * listLength ,  this mask is shared with the index fields
        para_index_mask, para_index_mask_indices = torch.max(para_mask, 2)
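        # Max-pooling the token mask per step gives 1 for steps with at least one real
        # token and 0 for fully padded steps; the returned indices are unused.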

        # apply mask to update the index values,  padded instances will be 0
        after_loc_start_list = (after_loc_start_list.float() *
                                para_index_mask.unsqueeze(2)).long()
        after_loc_end_list = (after_loc_end_list.float() *
                              para_index_mask.unsqueeze(2)).long()
        after_category_list = (after_category_list.float() *
                               para_index_mask.unsqueeze(2)).long()
        after_category_mask_list = (after_category_mask_list.float() *
                                    para_index_mask.unsqueeze(2)).long()

        batch_size, list_size, paragraph_size, input_dim = embedding_paragraph.size(
        )

        # to store the values passed to next step
        tmp_category_probability = torch.zeros(batch_size, 3)
        tmp_start_probability = torch.zeros(batch_size, paragraph_size)

        loss = 0

        # store the predict logits for the whole lists
        category_predict_logits_after_list = torch.rand(
            batch_size, list_size, 3)
        best_span_after_list = torch.rand(batch_size, list_size, 2)

        for index in range(list_size):
            # get one slice of step for prediction
            embedding_paragraph_slice = embedding_paragraph[:,
                                                            index, :, :].squeeze(
                                                                1)
            para_mask_slice = para_mask[:, index, :].squeeze(1)
            para_lstm_mask_slice = para_mask_slice if self._mask_lstms else None
            para_index_mask_slice = para_index_mask[:, index]
            after_category_mask_slice = after_category_mask_list[:,
                                                                 index, :].squeeze(
                                                                 )

            # bi-LSTM: generate the contextual embeddings for the current step
            # size: batchsize * paragraph_size * modeling_layer_hidden_size
            encoded_paragraph = self._dropout(
                self._modeling_layer(embedding_paragraph_slice,
                                     para_lstm_mask_slice))

            # max-pooling output for three category classification
            category_input, category_input_indices = torch.max(
                encoded_paragraph, 1)

            modeling_dim = encoded_paragraph.size(-1)
            span_start_input = encoded_paragraph

            # predict the initial before location state
            if index == 0:

                # three category classification for initial before location
                category_predict_logits_before = self._category_before_predictor(
                    category_input)
                tmp_category_probability = category_predict_logits_before
                '''Model the before_loc prediction'''
                # predict the initial before location start scores
                # shape:  batchsize * paragraph_size
                span_start_logits_before = self._span_start_predictor_before(
                    span_start_input).squeeze(-1)
                # shape:  batchsize * paragraph_size
                span_start_probs_before = util.masked_softmax(
                    span_start_logits_before, para_mask_slice)
                tmp_start_probability = span_start_probs_before

                # shape:  batchsize * hiddensize
                span_start_representation_before = util.weighted_sum(
                    encoded_paragraph, span_start_probs_before)

                # Shape: (batch_size, passage_length, modeling_dim)
                tiled_start_representation_before = span_start_representation_before.unsqueeze(
                    1).expand(batch_size, paragraph_size, modeling_dim)

                # incorporate the original contextual embeddings and weighted sum vector from location start prediction
                # shape: batchsize * paragraph_size * 2hiddensize
                span_end_representation_before = torch.cat(
                    [encoded_paragraph, tiled_start_representation_before],
                    dim=-1)
                # Shape: (batch_size, passage_length, encoding_dim)
                encoded_span_end_before = self._dropout(
                    self._span_end_encoder_before(
                        span_end_representation_before, para_lstm_mask_slice))

                # initial before location end prediction
                encoded_span_end_before = torch.cat(
                    [encoded_paragraph, encoded_span_end_before], dim=-1)
                # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
                span_end_logits_before = self._span_end_predictor_before(
                    encoded_span_end_before).squeeze(-1)
                span_end_probs_before = util.masked_softmax(
                    span_end_logits_before, para_mask_slice)

                # best_span_bef = self._get_best_span(span_start_logits_bef, span_end_logits_bef)
                best_span_before, best_span_before_start, best_span_before_end, best_span_before_real = \
                    self._get_best_span_single_extend(span_start_logits_before, span_end_logits_before,
                                                      category_predict_logits_before, before_category_mask)

                # compute the loss for initial bef location three-category classification
                before_null_pred = softmax(category_predict_logits_before)
                before_null_pred_values, before_null_pred_indices = torch.max(
                    before_null_pred, 1)
                loss += nll_loss(before_null_pred, before_category.squeeze(-1))
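                # Note: ``nll_loss`` expects log-probabilities; ``softmax`` (rather than
                # ``log_softmax``) is applied above, so this term is not a standard NLL.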

                # compute the loss for initial bef location start/end prediction
                before_loc_start_pred = util.masked_softmax(
                    span_start_logits_before, para_mask_slice)
                logpy_before_start = torch.gather(
                    before_loc_start_pred, 1,
                    before_loc_start).view(-1).float()
                before_category_mask = before_category_mask.float()
                loss += -(logpy_before_start * before_category_mask).mean()
                before_loc_end_pred = util.masked_softmax(
                    span_end_logits_before, para_mask_slice)
                logpy_before_end = torch.gather(before_loc_end_pred, 1,
                                                before_loc_end).view(-1)
                loss += -(logpy_before_end * before_category_mask).mean()
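                # Despite the ``logpy_*`` names, these gathered values are probabilities,
                # not log-probabilities; a standard NLL would take their log first (see
                # the hedged helper sketch after this method).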

                # get the real predicted location spans
                # convert category output (Null and Unk) into spans ((-2,-2) or (-1, -1))
                before_loc_start_real = self._get_real_spans_extend(
                    before_loc_start, before_category, before_category_mask)
                before_loc_end_real = self._get_real_spans_extend(
                    before_loc_end, before_category, before_category_mask)
                true_span_before = torch.stack(
                    [before_loc_start_real, before_loc_end_real], dim=-1)
                true_span_before = true_span_before.squeeze(1)

            # input for (after location) three category classification
            category_input_after = torch.cat(
                (category_input, tmp_category_probability), dim=1)
            category_predict_logits_after = self._category_after_predictor(
                category_input_after)
            tmp_category_probability = category_predict_logits_after

            # copy the predict logits for the index of the list
            category_predict_logits_after_tmp = category_predict_logits_after.unsqueeze(
                1)
            category_predict_logits_after_list[:,
                                               index, :] = category_predict_logits_after_tmp.data
            '''  Model the after_loc prediction  '''
            # after location start prediction: takes contextual embeddings and weighted sum vector as input
            # shape:  batchsize * hiddensize
            prev_start = util.weighted_sum(category_input,
                                           tmp_start_probability)
            tiled_prev_start = prev_start.unsqueeze(1).expand(
                batch_size, paragraph_size, modeling_dim)
            span_start_input_after = torch.cat(
                (span_start_input, tiled_prev_start), dim=2)
            encoded_start_input_after = self._dropout(
                self._span_start_encoder_after(span_start_input_after,
                                               para_lstm_mask_slice))
            span_start_input_after_cat = torch.cat(
                [encoded_paragraph, encoded_start_input_after], dim=-1)

            # predict the after location start
            span_start_logits_after = self._span_start_predictor_after(
                span_start_input_after_cat).squeeze(-1)
            # shape:  batchsize * paragraph_size
            span_start_probs_after = util.masked_softmax(
                span_start_logits_after, para_mask_slice)
            tmp_start_probability = span_start_probs_after

            # after location end prediction: takes contextual embeddings and weighted sum vector as input
            # shape:  batchsize * hiddensize
            span_start_representation_after = util.weighted_sum(
                encoded_paragraph, span_start_probs_after)
            # Tensor Shape: (batch_size, passage_length, modeling_dim)
            tiled_start_representation_after = span_start_representation_after.unsqueeze(
                1).expand(batch_size, paragraph_size, modeling_dim)
            # shape: batchsize * paragraph_size * 2hiddensize
            span_end_representation_after = torch.cat(
                [encoded_paragraph, tiled_start_representation_after], dim=-1)
            # Tensor Shape: (batch_size, passage_length, encoding_dim)
            encoded_span_end_after = self._dropout(
                self._span_end_encoder_after(span_end_representation_after,
                                             para_lstm_mask_slice))
            encoded_span_end_after = torch.cat(
                [encoded_paragraph, encoded_span_end_after], dim=-1)
            # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
            span_end_logits_after = self._span_end_predictor_after(
                encoded_span_end_after).squeeze(-1)
            span_end_probs_after = util.masked_softmax(span_end_logits_after,
                                                       para_mask_slice)

            # get the best span for after location prediction
            best_span_after, best_span_after_start, best_span_after_end, best_span_after_real = \
                self._get_best_span_single_extend(span_start_logits_after, span_end_logits_after,
                                                  category_predict_logits_after, after_category_mask_slice)

            # copy current best span to the list for final evaluation
            best_span_after_list[:, index, :] = best_span_after.data.view(
                batch_size, 1, 2)
            """ Compute the Loss for this slice """
            after_category_mask = after_category_mask_slice.float().squeeze(
                -1)  # batchsize
            after_category_slice = after_category_list[:,
                                                       index, :]  # batchsize * 1
            after_loc_start_slice = after_loc_start_list[:, index, :]
            after_loc_end_slice = after_loc_end_list[:, index, :]

            # compute the loss for (after location) three category classification
            para_index_mask_slice_tiled = para_index_mask_slice.unsqueeze(
                1).expand(para_index_mask_slice.size(0), 3)
            after_category_pred = util.masked_softmax(
                category_predict_logits_after, para_index_mask_slice_tiled)
            logpy_after_category = torch.gather(after_category_pred, 1,
                                                after_category_slice).view(-1)
            loss += -(logpy_after_category * para_index_mask_slice).mean()

            # compute the loss for location start/end prediction
            after_loc_start_pred = util.masked_softmax(span_start_logits_after,
                                                       para_mask_slice)
            logpy_after_start = torch.gather(after_loc_start_pred, 1,
                                             after_loc_start_slice).view(-1)
            loss += -(logpy_after_start * after_category_mask).mean()
            after_loc_end_pred = util.masked_softmax(span_end_logits_after,
                                                     para_mask_slice)
            logpy_after_end = torch.gather(after_loc_end_pred, 1,
                                           after_loc_end_slice).view(-1)
            loss += -(logpy_after_end * after_category_mask).mean()

        # for evaluation (combine all the annotations)
        after_loc_start_real = self._get_real_spans_extend_list(
            after_loc_start_list, after_category_list,
            after_category_mask_list)
        after_loc_end_real = self._get_real_spans_extend_list(
            after_loc_end_list, after_category_list, after_category_mask_list)

        true_span_after = torch.stack(
            [after_loc_start_real, after_loc_end_real], dim=-1)
        true_span_after = true_span_after.squeeze(2)
        best_span_after_list = Variable(best_span_after_list)

        true_span_after = true_span_after.view(
            true_span_after.size(0) * true_span_after.size(1),
            true_span_after.size(2)).float()

        para_index_mask_tiled = para_index_mask.view(-1, 1)
        para_index_mask_tiled = para_index_mask_tiled.expand(
            para_index_mask_tiled.size(0), 2)

        para_index_mask_tiled2 = para_index_mask.unsqueeze(2).expand(
            para_index_mask.size(0), para_index_mask.size(1), 2)
        after_category_mask_list_tiled = after_category_mask_list.expand(
            batch_size, list_size, 2)
        after_category_mask_list_tiled = after_category_mask_list_tiled * para_index_mask_tiled2.long(
        )

        # merge all the best spans predicted for the current batch, filter out the padded instances
        merged_sys_span, merged_gold_span = self._get_merged_spans(
            true_span_before, best_span_before, true_span_after,
            best_span_after_list, para_index_mask_tiled)

        output_dict = {}
        output_dict["best_span"] = merged_sys_span.view(
            1,
            merged_sys_span.size(0) * merged_sys_span.size(1))
        output_dict["true_span"] = merged_gold_span.view(
            1,
            merged_gold_span.size(0) * merged_gold_span.size(1))
        output_dict["loss"] = loss
        return output_dict
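    # --- Hedged sketch (not part of the original snippet) ----------------------
    # The per-step loss terms above repeatedly gather the probability assigned to
    # the gold index and average it under a mask.  A small helper capturing that
    # pattern; note that it takes the log of the gathered probability (a standard
    # masked NLL), whereas the original accumulates the raw probability itself.
    # The name ``_masked_gold_nll`` is chosen here for illustration only.
    @staticmethod
    def _masked_gold_nll(probs: torch.Tensor,
                         gold_index: torch.Tensor,
                         mask: torch.Tensor) -> torch.Tensor:
        # probs: (batch, num_positions), gold_index: (batch, 1), mask: (batch,).
        gold_prob = torch.gather(probs, 1, gold_index).view(-1)
        return -(torch.log(gold_prob + 1e-13) * mask.float()).mean()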
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
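        """
        BERT-style extractive QA forward pass.  ``passage['bert']`` is assumed to already
        contain the question and passage word-pieces packed together by the dataset reader
        (``question`` is accepted for interface compatibility but is not used directly here).
        A linear layer (``qa_outputs``) on top of the BERT embeddings produces per-token
        start and end logits; the loss is computed only on instances that contain a gold
        answer, and predictions are aggregated per question across passage chunks.
        """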

        batch_size, num_of_passage_tokens = passage['bert'].size()

        # BERT for QA adds a fully connected linear layer on top of BERT that produces
        # two vectors of per-token start and end logits.
        embedded_passage = self._text_field_embedder(passage)
        passage_length = embedded_passage.size(1)
        logits = self.qa_outputs(embedded_passage)
        start_logits, end_logits = logits.split(1, dim=-1)
        span_start_logits = start_logits.squeeze(-1)
        span_end_logits = end_logits.squeeze(-1)

        # Adding some masks with numerically stable values
        passage_mask = util.get_text_field_mask(passage).float()
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, 1, 1)
        repeated_passage_mask = repeated_passage_mask.view(
            batch_size, passage_length)
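        # The unsqueeze/repeat/view sequence above leaves ``passage_mask`` unchanged;
        # it appears to be a leftover from a variant that repeated the mask per chunk.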
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       repeated_passage_mask,
                                                       -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     repeated_passage_mask,
                                                     -1e7)

        # Add span start and end logits to the output for knowledge distillation.
        output_dict: Dict[str, Any] = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
        }

        # We may have multiple instances per question, so regroup them per question.
        instances_question_id = [
            insta_meta['question_id'] for insta_meta in metadata
        ]
        question_instances_split_inds = np.cumsum(
            np.unique(instances_question_id, return_counts=True)[1])[:-1]
        per_question_inds = np.split(range(batch_size),
                                     question_instances_split_inds)
        metadata = np.split(metadata, question_instances_split_inds)

        # Compute the loss.
        # if span_start is not None and len(np.argwhere(span_start.squeeze().cpu() >= 0)) > 0:
        if span_start is not None and len(
                np.argwhere(
                    span_start.squeeze(-1).squeeze(-1).cpu() >= 0)) > 0:
            # in evaluation some instances may not contain the gold answer, so we need to compute
            # loss only on those that do.
            inds_with_gold_answer = np.argwhere(
                span_start.view(-1).cpu().numpy() >= 0)
            inds_with_gold_answer = inds_with_gold_answer.squeeze(
            ) if len(inds_with_gold_answer) > 1 else inds_with_gold_answer
            if len(inds_with_gold_answer) > 0:
                loss = nll_loss(util.masked_log_softmax(span_start_logits[inds_with_gold_answer], \
                                                    repeated_passage_mask[inds_with_gold_answer]),\
                                span_start.view(-1)[inds_with_gold_answer], ignore_index=-1)
                output_dict["loss_start"] = loss
                loss += nll_loss(util.masked_log_softmax(span_end_logits[inds_with_gold_answer], \
                                                    repeated_passage_mask[inds_with_gold_answer]),\
                                span_end.view(-1)[inds_with_gold_answer], ignore_index=-1)
                output_dict["loss"] = loss
                output_dict["loss_end"] = loss - output_dict["loss_start"]

        # This is a hack for cases in which gold answer is not provided so we cannot compute loss...
        if 'loss' not in output_dict:
            output_dict["loss"] = torch.cuda.FloatTensor([0], device=span_end_logits.device) \
                if torch.cuda.is_available() else torch.FloatTensor([0])

        # Compute F1 and preparing the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict["start_bias_weight"] = []
        output_dict["end_bias_weight"] = []

        # Get the best span prediction for each instance (chunk).
        best_span = self._get_example_predications(span_start_logits,
                                                   span_end_logits,
                                                   self._max_span_length)
        best_span_cpu = best_span.detach().cpu().numpy()

        span_start_logits_numpy = span_start_logits.data.cpu().numpy()
        span_end_logits_numpy = span_end_logits.data.cpu().numpy()
        # Iterating over every question (which may contain multiple instances, one per chunk)
        for question_inds, question_instances_metadata in zip(
                per_question_inds, metadata):
            best_span_ind = np.argmax(
                span_start_logits_numpy[question_inds,
                                        best_span_cpu[question_inds][:, 0]] +
                span_end_logits_numpy[question_inds,
                                      best_span_cpu[question_inds][:, 1]])
            best_span_logit = np.max(
                span_start_logits_numpy[question_inds,
                                        best_span_cpu[question_inds][:, 0]] +
                span_end_logits_numpy[question_inds,
                                      best_span_cpu[question_inds][:, 1]])

            passage_str = question_instances_metadata[best_span_ind][
                'original_passage']
            offsets = question_instances_metadata[best_span_ind][
                'token_offsets']

            predicted_span = best_span_cpu[question_inds[best_span_ind]]
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]

            # Note: this is a hack, because AllenNLP, when predicting, expects a value for each instance.
            # But we may have more than one chunk per question, and thus fewer output strings than instances.
            for i in range(len(question_inds)):
                output_dict['best_span_str'].append(best_span_string)
                output_dict['qid'].append(
                    question_instances_metadata[best_span_ind]['question_id'])

                # get the scalar logit value of the predicted span start and end index as bias weight.
                output_dict["start_bias_weight"].append(
                    util.masked_softmax(span_start_logits[best_span_ind],
                                        repeated_passage_mask[best_span_ind])[
                                            best_span_cpu[best_span_ind][0]])
                output_dict["end_bias_weight"].append(
                    util.masked_softmax(span_end_logits[best_span_ind],
                                        repeated_passage_mask[best_span_ind])[
                                            best_span_cpu[best_span_ind][1]])

            f1_score = 0.0
            EM_score = 0.0
            gold_answer_texts = question_instances_metadata[best_span_ind][
                'answer_texts_list']
            if gold_answer_texts:
                f1_score = squad_eval.metric_max_over_ground_truths(
                    squad_eval.f1_score, best_span_string, gold_answer_texts)
                EM_score = squad_eval.metric_max_over_ground_truths(
                    squad_eval.exact_match_score, best_span_string,
                    gold_answer_texts)
            self._official_f1(100 * f1_score)
            self._official_EM(100 * EM_score)

            # TODO move to predict
            if self._predictions_file is not None:
                with open(self._predictions_file, 'a') as f:
                    f.write(json.dumps({'question_id':question_instances_metadata[best_span_ind]['question_id'], \
                                'best_span_logit':float(best_span_logit), \
                                'f1':100 * f1_score,
                                'EM':100 * EM_score,
                                'best_span_string':best_span_string,\
                                'gold_answer_texts':gold_answer_texts, \
                                'qas_used_fraction':1.0}) + '\n')

        return output_dict
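    # --- Hedged sketch (not part of the original snippet) ----------------------
    # The cumsum/split regrouping above assumes that all chunks of a question are
    # adjacent in the batch and that question ids arrive in sorted order (since
    # ``np.unique`` sorts its output).  A standalone version of that pattern,
    # assuming ``numpy as np`` and ``typing.List`` are imported as the snippet
    # implies; the helper name is chosen here for illustration only.
    @staticmethod
    def _group_indices_by_question(question_ids: List[str]) -> List[np.ndarray]:
        # Number of chunks per unique (sorted) question id, e.g. ['q1', 'q1', 'q2'] -> [2, 1].
        counts = np.unique(question_ids, return_counts=True)[1]
        # Boundaries between consecutive questions; the final end index is dropped.
        split_inds = np.cumsum(counts)[:-1]
        # One array of instance indices per question.
        return np.split(np.arange(len(question_ids)), split_inds)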
Example #33
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        # encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))

        # # v5:
        # # remember to set token embeddings in the CONFIG JSON
        encoded_question = self._dropout(embedded_question)

        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length) -- SIMILARITY MATRIX
        similarity_matrix = self._matrix_attention(encoded_passage,
                                                   encoded_question)

        # Shape: (batch_size, passage_length, question_length) -- CONTEXT2QUERY
        passage_question_attention = util.last_dim_softmax(
            similarity_matrix, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Our custom query2context
        q2c_attention = util.masked_softmax(similarity_matrix,
                                            question_mask,
                                            dim=1).transpose(-1, -2)
        q2c_vecs = util.weighted_sum(encoded_passage, q2c_attention)
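        # q2c_attention has shape (batch_size, question_length, passage_length) after the
        # transpose, so q2c_vecs is (batch_size, question_length, encoding_dim).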

        # Now we try the various variants
        # v1:
        # tiled_question_passage_vector = util.weighted_sum(q2c_vecs, passage_question_attention)

        # v2:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], encoded_passage.shape[1]))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).transpose(-1, -2)

        # v3:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], 1))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).squeeze().unsqueeze(1).expand(batch_size, passage_length, encoding_dim)

        # v4:
        # Re-application of query2context attention
        # new_similarity_matrix = self._matrix_attention(encoded_passage, q2c_vecs)
        # masked_similarity = util.replace_masked_values(new_similarity_matrix,
        #                                                question_mask.unsqueeze(1),
        #                                                -1e7)
        # # Shape: (batch_size, passage_length)
        # question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # # Shape: (batch_size, passage_length)
        # question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # # Shape: (batch_size, encoding_dim)
        # question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # # Shape: (batch_size, passage_length, encoding_dim)
        # tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
        #                                                                             passage_length,
        #                                                                             encoding_dim)

        # ------- Original variant
        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            similarity_matrix, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # ------- END

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        # original beta combination function
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        # # v6:
        # final_merged_passage = torch.cat([tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v7:
        # final_merged_passage = torch.cat([passage_question_vectors],
        #                                  dim=-1)
        #
        # # v8:
        # final_merged_passage = torch.cat([passage_question_vectors,
        #                                   tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v9:
        # final_merged_passage = torch.cat([encoded_passage,
        #                                   passage_question_vectors,
        #                                   encoded_passage * passage_question_vectors],
        #                                  dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
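        # Replace masked positions with a very negative value so that padding tokens can
        # never be selected as span endpoints by the decoding below.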
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
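The ``best_span`` entry above is produced by ``self.get_best_span``, whose body is not part of this snippet. As a rough illustration of the constrained decoding it performs (choose the (start, end) pair that maximizes the sum of start and end scores, subject to end >= start), here is a minimal, self-contained sketch; the function name, the greedy left-to-right scan and the tie-breaking are assumptions, not the original implementation.

import torch

def get_best_span_sketch(span_start_logits: torch.Tensor,
                         span_end_logits: torch.Tensor) -> torch.Tensor:
    # Both inputs have shape (batch_size, passage_length); returns (batch_size, 2)
    # with inclusive (start, end) token indices.
    batch_size, passage_length = span_start_logits.size()
    best_spans = torch.zeros(batch_size, 2, dtype=torch.long)
    for b in range(batch_size):
        best_score = float("-inf")
        best_start = 0
        best_start_score = span_start_logits[b, 0].item()
        for j in range(passage_length):
            # Track the best start position seen so far, so end >= start by construction.
            if span_start_logits[b, j].item() > best_start_score:
                best_start_score = span_start_logits[b, j].item()
                best_start = j
            score = best_start_score + span_end_logits[b, j].item()
            if score > best_score:
                best_score = score
                best_spans[b, 0] = best_start
                best_spans[b, 1] = j
    return best_spans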
Example #34
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None,
                get_sample_level_information: bool = True) -> Dict[str, torch.Tensor]:
        """
        Ensemble forward pass: the submodels are loaded onto the GPU one at a time.
        """
        
        subresults = []
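        # Move each submodel onto its configured device, run it on the batch, then move it
        # back to the CPU so that only one submodel occupies GPU memory at a time.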
        for submodel in self.submodels:
            submodel.to(device = submodel.cf_a.device)
            subres = submodel(question, passage, span_start, span_end, metadata, get_sample_level_information)
            submodel.to(device = torch.device("cpu"))
            subresults.append(subres)

        batch_size = len(subresults[0]["best_span"])

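        # Combine the submodels' span distributions into one best span per instance
        # (a sketch of merge_span_probs is given after this example).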
        best_span = merge_span_probs(subresults)
        output = {
                "best_span": best_span,
                "best_span_str": [],
                "models_output": subresults
        }
        if get_sample_level_information:
            output["em_samples"] = []
            output["f1_samples"] = []
                
        for index in range(batch_size):
            if metadata is not None:
                passage_str = metadata[index]['original_passage']
                offsets = metadata[index]['token_offsets']
                predicted_span = tuple(best_span[index].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output["best_span_str"].append(best_span_string)

                answer_texts = metadata[index].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
                    if get_sample_level_information:
                        em_sample, f1_sample = bidut.get_em_f1_metrics(best_span_string, answer_texts)
                        output["em_samples"].append(em_sample)
                        output["f1_samples"].append(f1_sample)
        if get_sample_level_information:
            # Add per-sample losses for future analysis. The ensemble distribution is the
            # average of the submodels' span probabilities; since nll_loss expects
            # log-probabilities, we take the log of the averaged distribution.
            output["span_start_sample_loss"] = []
            output["span_end_sample_loss"] = []
            span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults)
            span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults)
            for i in range(batch_size):
                span_start_loss = nll_loss(torch.log(span_start_probs[[i], :] + 1e-12),
                                           span_start.squeeze(-1)[[i]])
                span_end_loss = nll_loss(torch.log(span_end_probs[[i], :] + 1e-12),
                                         span_end.squeeze(-1)[[i]])
                output["span_start_sample_loss"].append(float(span_start_loss.detach().cpu().numpy()))
                output["span_end_sample_loss"].append(float(span_end_loss.detach().cpu().numpy()))
        return output
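``merge_span_probs`` is imported from elsewhere and is not shown in this snippet. Below is a minimal, self-contained sketch of what such a merge could do, assuming each subresult exposes ``span_start_probs`` and ``span_end_probs`` of shape (batch_size, passage_length); the averaging rule, the constrained arg-max decoding and the function body are assumptions, not the original implementation.

import torch

def merge_span_probs_sketch(subresults):
    # Average the per-model span distributions into a single ensemble distribution.
    span_start_probs = sum(r["span_start_probs"] for r in subresults) / len(subresults)
    span_end_probs = sum(r["span_end_probs"] for r in subresults) / len(subresults)
    batch_size, passage_length = span_start_probs.size()
    best_spans = torch.zeros(batch_size, 2, dtype=torch.long)
    for b in range(batch_size):
        best_score = float("-inf")
        best_start = 0
        best_start_score = span_start_probs[b, 0].item()
        for j in range(passage_length):
            if span_start_probs[b, j].item() > best_start_score:
                best_start_score = span_start_probs[b, j].item()
                best_start = j
            # Constrain the end index to be >= the start index (both inclusive token offsets).
            score = best_start_score + span_end_probs[b, j].item()
            if score > best_score:
                best_score = score
                best_spans[b, 0] = best_start
                best_spans[b, 1] = j
    return best_spans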