def forward(self,
                context: Dict[str, torch.LongTensor],
                response: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        Classify a (context, response) pair with attend / compare / aggregate steps.

        Parameters
        ----------
        context, response : Dict[str, torch.LongTensor]
            Indexed text fields; both go through the same ``text_field_embedder``.
        label : torch.LongTensor, optional
            Gold labels. When given, ``loss`` is added to the output and every metric in
            ``self.metrics`` is updated with ``label.squeeze(-1)``.

        Returns
        -------
        Dict with ``class_logits``, ``class_probabilities`` (softmax over the last axis)
        and, when ``label`` is given, ``loss``.
        """
        # Embed and mask both sides (context first, then response).
        context_vectors = self.text_field_embedder(context)
        context_mask = get_text_field_mask(context).float()
        response_vectors = self.text_field_embedder(response)
        response_mask = get_text_field_mask(response).float()

        # Optional contextual encoders.
        if self.context_encoder:
            context_vectors = self.context_encoder(context_vectors, context_mask)
        if self.response_encoder:
            response_vectors = self.response_encoder(response_vectors, response_mask)

        # Attend. Shape: batch x context_length x response_length
        similarity = self.matrix_attention(self.attend_feedforward(context_vectors),
                                           self.attend_feedforward(response_vectors))

        # Soft-align each context token over the response, and vice versa (same scores, transposed).
        context_to_response = last_dim_softmax(similarity, response_mask)
        aligned_response = weighted_sum(response_vectors, context_to_response)
        response_to_context = last_dim_softmax(similarity.transpose(1, 2).contiguous(), context_mask)
        aligned_context = weighted_sum(context_vectors, response_to_context)

        def compare_and_pool(tokens, aligned, mask):
            # Compare each token with its aligned summary, zero padded positions,
            # then sum over the sequence. Shape: batch x compare_dim
            compared = self.compare_feedforward(torch.cat([tokens, aligned], dim=-1))
            return (compared * mask.unsqueeze(-1)).sum(dim=1)

        # Aggregate. Shape: batch x (2 * compare_dim)
        aggregate_input = torch.cat([compare_and_pool(context_vectors, aligned_response, context_mask),
                                     compare_and_pool(response_vectors, aligned_context, response_mask)],
                                    dim=-1)

        class_logits = self.classifier_feedforward(aggregate_input)
        output_dict = {"class_logits": class_logits,
                       "class_probabilities": F.softmax(class_logits, dim=-1)}

        if label is not None:
            gold = label.squeeze(-1)
            output_dict['loss'] = self.loss(class_logits, gold)
            for metric in self.metrics.values():
                metric(class_logits, gold)

        return output_dict
    def forward(
            self,  # type: ignore
            sentence: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        Classify ``sentence`` with a biattentive encoder and four pooling heads.

        Pipeline: embed -> embedding dropout -> pre-encode feedforward -> encoder ->
        self-biattention -> integrator -> [max, min, mean, self-attentive] pooling ->
        integrator dropout -> output layer.

        Parameters
        ----------
        sentence : Dict[str, torch.LongTensor]
            Indexed text field for the sentence.
        label : torch.LongTensor, optional
            Gold labels. When given, ``loss`` is added and every metric is updated
            with ``label.squeeze(-1)``.

        Returns
        -------
        Dict with ``logits`` and, when ``label`` is given, ``loss``.
        """
        sentence_mask = util.get_text_field_mask(sentence).float()
        # Mask with a trailing singleton axis so it broadcasts over the feature dimension.
        token_mask = sentence_mask.unsqueeze(2)
        embedded_sentence = self._text_field_embedder(sentence)

        dropped_embedded_sent = self._embedding_dropout(embedded_sentence)
        pre_encoded_sent = self._pre_encode_feedforward(dropped_embedded_sent)
        encoded_tokens = self._encoder(pre_encoded_sent, sentence_mask)

        # Compute biattention. This is a special case since the inputs are the same.
        attention_logits = encoded_tokens.bmm(
            encoded_tokens.permute(0, 2, 1).contiguous())
        attention_weights = util.last_dim_softmax(attention_logits,
                                                  sentence_mask)
        encoded_sentence = util.weighted_sum(encoded_tokens, attention_weights)

        # Build the input to the integrator
        integrator_input = torch.cat([
            encoded_tokens, encoded_tokens - encoded_sentence,
            encoded_tokens * encoded_sentence
        ], 2)
        integrated_encodings = self._integrator(integrator_input,
                                                sentence_mask)

        # Simple Pooling layers. Padded positions are pushed to -/+1e7 so they never
        # win the max/min.
        max_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, token_mask, -1e7)
        max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
        min_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, token_mask, +1e7)
        min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
        # Mean pooling. Zero padded positions explicitly instead of relying on the
        # integrator to emit zeros there. Clamp the token count to 1 so an all-padding
        # row pools to zeros instead of NaN.
        token_counts = torch.sum(sentence_mask, 1, keepdim=True).clamp(min=1.0)
        mean_pool = torch.sum(integrated_encodings * token_mask, 1) / token_counts

        # Self-attentive pooling layer
        # Run through linear projection. Shape: (batch_size, sequence length, 1)
        # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
        self_attentive_logits = self._self_attentive_pooling_projection(
            integrated_encodings).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits,
                                           sentence_mask)
        self_attentive_pool = util.weighted_sum(integrated_encodings,
                                                self_weights)

        pooled_representations = torch.cat(
            [max_pool, min_pool, mean_pool, self_attentive_pool], 1)
        pooled_representations_dropped = self._integrator_dropout(
            pooled_representations).squeeze(1)

        logits = self._output_layer(pooled_representations_dropped)
        output_dict = {'logits': logits}
        if label is not None:
            loss = self.loss(logits, label.squeeze(-1))
            for metric in self.metrics.values():
                metric(logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
Example #3
0
    def posAttnConv(self, sentence, other_sen, interaction, sentence_mask,
                    other_sen_mask, matrix_mask):
        """
        @brief      Position-aware attentive convolution of ``sentence`` against
                    ``other_sen``.

        @param      self            The object
        @param      sentence        The embedded sentence (n x s x d)
        @param      other_sen       The other sentence (n x s' x d)
        @param      interaction     The interaction matrix (n x s x s')
        @param      sentence_mask   The mask of the sentence (n x s)
        @param      other_sen_mask  The mask of the other sentence (n x s')
        @param      matrix_mask     The mask of the interaction matrix (n x s x s')

        @return     ``self._pos_attn_encoder`` applied to the concatenation of the
                    aligned other-sentence summary, the sentence itself and the
                    embedding of the best-matching position, masked by
                    ``sentence_mask``.
        """
        # Soft alignment of every sentence token over the other sentence.
        attention = last_dim_softmax(interaction, other_sen_mask)  # (n x s x s')
        aligned_other = weighted_sum(other_sen, attention)  # (n x s x d)

        # Index (in the other sentence) of the highest-scoring unmasked match.
        masked_interaction = replace_masked_values(interaction, matrix_mask, -1e7)
        best_match_index = masked_interaction.max(dim=-1)[1]  # (n x s)
        position_embedding = self._pos_embedder(best_match_index)  # (n x s x dm)

        # (n x s x (2d + dm))
        combined = torch.cat((aligned_other, sentence, position_embedding), dim=2)
        return self._pos_attn_encoder(combined, sentence_mask)
Example #4
0
    def _get_node_probabilities(self, embedded_nodes, embedded_premise,
                                nodes_mask, premise_mask,
                                metadata) -> Tuple[FloatTensor, FloatTensor]:
        """
        Compute the average entailment distribution based on the nodes in the hypothesis.
        Returns a tuple of (attention of each node over the premise, average entailment
        distribution) with dimensions batch x nodes x premise words and batch x num classes
        respectively.

        Parameters
        ----------
        embedded_nodes :
            Embedded hypothesis nodes; presumably batch x nodes x node words x emb. dim
            (averaged over axis 2 below) -- TODO confirm.
        embedded_premise :
            Embedded premise; presumably batch x premise words x emb. dim -- TODO confirm.
        nodes_mask :
            Token mask for the node words; summed over its last axis to find nodes that
            contain any text.
        premise_mask :
            Token mask for the premise words.
        metadata :
            Not used in this method.

        NOTE(review): the first returned element is ``mean_node_premise_attention``, i.e.
        the attention *before* empty nodes / padded premise positions are zeroed. The
        masked copy (``masked_mean_node_premise_attention``) is only used for the weighted
        sum below. Confirm that returning the unmasked version is intended.
        """
        # attention for each node. dim: batch x nodes x node words x premise words
        node_premise_attention = self._nodes_attention(embedded_nodes,
                                                       embedded_premise)

        # Normalise each node word's scores over the (unmasked) premise words.
        normalized_node_premise_attention = last_dim_softmax(
            node_premise_attention, premise_mask)

        # Node-word mask broadcast across the premise axis.
        expanded_nodes_mask_premise = nodes_mask.unsqueeze(-1).expand_as(
            normalized_node_premise_attention).float()

        # aggregate representation (mean over node words). dim: batch x nodes x premise words
        mean_node_premise_attention = masked_mean(
            normalized_node_premise_attention, 2, expanded_nodes_mask_premise)

        # convert batch x nodes and batch x premise to batch x nodes x premise mask.
        # A node counts as present if any of its words is unmasked.
        nodes_only_mask = (torch.sum(nodes_mask, -1) > 0).float()
        node_premise_mask = nodes_only_mask.unsqueeze(-1).expand_as(mean_node_premise_attention) \
                            * premise_mask.unsqueeze(1).expand_as(mean_node_premise_attention)
        # Zero out attention of empty nodes and padded premise positions.
        masked_mean_node_premise_attention = replace_masked_values(
            mean_node_premise_attention, node_premise_mask, 0)
        # aggregate node representation over premise. dim: batch x nodes x emb. dim
        aggregate_node_premise_representation = weighted_sum(
            embedded_premise, masked_mean_node_premise_attention)
        expanded_nodes_mask_embedding = nodes_mask.unsqueeze(-1).expand_as(
            embedded_nodes).float()
        # mean over node words. dim: batch x nodes x emb. dim
        aggregate_node_representation = masked_mean(
            embedded_nodes, 2, expanded_nodes_mask_embedding)

        # ESIM-style matching features: difference and element-wise product.
        sub_representation = aggregate_node_premise_representation - aggregate_node_representation
        dot_representation = aggregate_node_premise_representation * aggregate_node_representation
        # dim: batch x nodes x emb. dim * 4
        combined_node_representation = torch.cat([
            aggregate_node_premise_representation,
            aggregate_node_representation, sub_representation,
            dot_representation
        ], 2)
        # dim: batch x nodes x num_classes
        phrase_prob_distribution = self._phrase_probability(
            combined_node_representation)

        # ignore nodes with no text and expand to num of output classes
        # dim: batch x node x node words -> batch x node  -> batch x node x num_classes
        nodes_class_mask = nodes_only_mask.unsqueeze(-1).expand_as(
            phrase_prob_distribution).float()

        # Average the per-node distributions over non-empty nodes. dim: batch x num_classes
        mean_phrase_distribution = masked_mean(phrase_prob_distribution, 1,
                                               nodes_class_mask)
        return mean_node_premise_attention, mean_phrase_distribution
Example #5
0
    def forward(self,
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        Align the hypothesis against the premise and max-pool the matching states.

        Parameters
        ----------
        premise, hypothesis : Dict[str, torch.LongTensor]
            Indexed text fields, embedded with the shared ``_text_field_embedder`` and
            encoded with the shared ``_encoder``.
        label : torch.IntTensor, optional
            Gold labels. When given, ``loss`` is returned and accuracy is updated.

        Returns
        -------
        Dict with ``final_hidden`` (the max-pooled matching states). It also contains
        ``label_logits`` / ``label_probs`` when an output feedforward is configured, and
        ``loss`` when ``label`` is given.

        Raises
        ------
        ValueError
            If ``label`` is given but there is no output feedforward to produce logits.
        """
        # Fail fast and clearly. Without this check, the loss branch below would reference
        # an unbound ``label_logits`` and raise UnboundLocalError.
        if label is not None and self._output_feedforward is None:
            raise ValueError("A label was given but the model has no output feedforward, "
                             "so there are no logits to compute a loss from.")

        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # Embed premise and hypothesis
        premise = self._text_field_embedder(premise)  # (n x p x d)
        hypothesis = self._text_field_embedder(hypothesis)  # (n x h x d)

        # encode premise and hypothesis
        # (n x p x 2d) if bidirectional else (n x p x d)
        premise = self._encoder(premise, premise_mask)
        # (n x h x 2d) if bidirectional else (n x h x d)
        hypothesis = self._encoder(hypothesis, hypothesis_mask)

        # calculate matrix attention
        similarity_matrix = self._inter_attention(hypothesis,
                                                  premise)  # (n x h x p)

        attention_softmax = last_dim_softmax(similarity_matrix,
                                             premise_mask)  # (n x h x p)
        hypothesis_tilda = weighted_sum(
            premise, attention_softmax
        )  # (n x h x 2d) assuming encoder is bidirectional

        # Four encoder-widths per token: (n x h x 8d) assuming encoder is bidirectional
        hypothesis_matching_states = torch.cat([
            hypothesis, hypothesis_tilda, hypothesis - hypothesis_tilda,
            hypothesis * hypothesis_tilda
        ], dim=-1)

        # max pool over hypothesis tokens; padded positions are pushed to -1e7 so they
        # never win. (n x 8d) assuming encoder is bidirectional
        hypothesis_max, _ = replace_masked_values(
            hypothesis_matching_states, hypothesis_mask.unsqueeze(-1),
            -1e7).max(dim=1)

        output_dict = {"final_hidden": hypothesis_max}

        # ``is not None`` rather than truthiness: container modules define __len__.
        if self._output_feedforward is not None:
            label_logits = self._output_feedforward(hypothesis_max)
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)
            output_dict["label_logits"] = label_logits
            output_dict["label_probs"] = label_probs

        if label is not None:
            # Feed the same flattened long labels to both the loss and the metric.
            gold_labels = label.long().view(-1)
            loss = self._loss(label_logits, gold_labels)
            self._accuracy(label_logits, gold_labels)
            output_dict["loss"] = loss

        return output_dict
Example #6
0
    def forward(self, inputs, lengths):
        """
        Encode a padded batch with an LSTM, attention-pool it and classify it.

        Parameters
        ----------
        inputs : torch.Tensor
            Shape: (batch_size, sequence_length, embedding_dim).
        lengths : torch.Tensor
            Number of real (non-padding) timesteps per example, in the original batch
            order. Presumably shape (batch_size,); it is handed to
            ``sort_batch_by_length`` and also used to build the attention mask.

        Returns
        -------
        torch.Tensor
            ``log_softmax`` over the last axis of ``self.output_projection``'s output,
            one row per example.
        """
        # 1. run LSTM
        # apply dropout to the input
        # Shape of inputs: (batch_size, sequence_length, embedding_dim)
        embedded_input = self.dropout_on_input_to_LSTM(inputs)
        # Sort the embedded inputs by decreasing order of input length.
        # sorted_input shape: (batch_size, sequence_length, embedding_dim)
        (sorted_input, sorted_lengths, input_unsort_indices,
         _) = sort_batch_by_length(embedded_input, lengths)
        # Pack the sorted inputs with pack_padded_sequence.
        packed_input = pack_padded_sequence(sorted_input,
                                            sorted_lengths.data.tolist(),
                                            batch_first=True)
        # Run the input through the RNN.
        packed_sorted_output, _ = self.rnn(packed_input)
        # Unpack (pad) the input with pad_packed_sequence.
        # Shape: (batch_size, max_length, hidden_size), where max_length is the longest
        # sequence in the batch (may be shorter than the padded input length).
        sorted_output, _ = pad_packed_sequence(packed_sorted_output,
                                               batch_first=True)
        # Re-sort the packed sequence to restore the initial ordering
        # Shape: (batch_size, max_length, hidden_size)
        output = sorted_output[input_unsort_indices]

        # 2. use attention
        # Shape: (batch_size, max_length, 1)
        # Shape: (batch_size, max_length) after squeeze
        attention_logits = self.attention_weights(output).squeeze(dim=-1)
        # Build the mask from the true lengths. ``pad_packed_sequence`` fills padded
        # timesteps with zero vectors, but ``attention_weights`` maps those to its bias,
        # not to 0. A "logit != 0" test therefore lets padding through, and it wrongly
        # drops real positions whose logit happens to be exactly 0.
        positions = torch.arange(output.size(1), device=output.device)
        mask_attention_logits = (
            positions.unsqueeze(0) < lengths.to(output.device).unsqueeze(1)).float()
        # Shape: (batch_size, max_length)
        softmax_attention_logits = last_dim_softmax(attention_logits,
                                                    mask_attention_logits)
        # Shape: (batch_size, 1, max_length)
        softmax_attention_logits = softmax_attention_logits.unsqueeze(dim=1)
        # Shape of input_encoding: (batch_size, 1, hidden_size)
        #    output: (batch_size, max_length, hidden_size)
        #    softmax_attention_logits: (batch_size, 1, max_length)
        input_encoding = torch.bmm(softmax_attention_logits, output)
        # Shape: (batch_size, hidden_size)
        input_encoding = input_encoding.squeeze(dim=1)

        # 3. run linear layer
        # apply dropout to input to the linear layer
        input_encoding = self.dropout_on_input_to_linear_layer(input_encoding)
        # Run the RNN encoding of the input through the output projection
        # to get scores for each of the classes.
        unnormalized_output = self.output_projection(input_encoding)
        # Normalize with log softmax
        output_distribution = F.log_softmax(unnormalized_output, dim=-1)
        return output_distribution
Example #7
0
    def forward(self, s1, s2, s1_mask, s2_mask):  # pylint: disable=arguments-differ
        """
        Co-attend two encoded sequences and run the shared modeling layer over each.

        One similarity matrix ``self._matrix_attention(s2, s1)`` is used in both
        directions: as-is to summarise ``s1`` for every ``s2`` token, and transposed to
        summarise ``s2`` for every ``s1`` token. Each sequence is concatenated with its
        summary along axis 2, passed through ``self._modeling_layer`` with its own mask,
        then through ``self._dropout``.

        Returns
        -------
        ``(modeled_s1, modeled_s2)``
        """
        def attend(scores, values, values_mask):
            # Softmax ``scores`` over their last axis, ignoring masked ``values``
            # positions, and return the attention-weighted sum of ``values``.
            return util.weighted_sum(values, util.last_dim_softmax(scores, values_mask))

        # Shape: (batch_size, s2_length, s1_length)
        similarity_mat = self._matrix_attention(s2, s1)

        # s2 tokens with their s1 summaries:
        # (batch_size, s2_length, 2 * encoding_dim), assuming s1 and s2 share encoding_dim.
        s2_w_context = torch.cat([s2, attend(similarity_mat, s1, s1_mask)], 2)
        # s1 tokens with their s2 summaries (same scores, transposed):
        # (batch_size, s1_length, 2 * encoding_dim), under the same assumption.
        s1_w_context = torch.cat(
            [s1, attend(similarity_mat.transpose(1, 2).contiguous(), s2, s2_mask)], 2)

        # s1 first, then s2, matching the order in which dropout is sampled.
        modeled = [self._dropout(self._modeling_layer(with_context, mask))
                   for with_context, mask in ((s1_w_context, s1_mask), (s2_w_context, s2_mask))]
        return modeled[0], modeled[1]
Example #8
0
    def forward(self, tokens, mask):  # pylint: disable=arguments-differ
        """
        Apply (optionally multi-head) intra-sentence attention to ``tokens``.

        Parameters
        ----------
        tokens : torch.Tensor
            Shape: (batch_size, sequence_length, input_dim).
        mask : torch.Tensor
            Passed to ``util.last_dim_softmax`` to exclude padded key positions.

        Returns
        -------
        ``self._output_projection`` applied to
        ``util.combine_tensors(self._combination, [tokens, attended_sentence])``.
        """
        batch_size, sequence_length, _ = tokens.size()
        multi_head = self._num_attention_heads > 1

        # Shape: (batch_size, sequence_length, sequence_length[, num_heads])
        similarity_matrix = self._matrix_attention(tokens, tokens)
        if multi_head:
            # Put the head axis before the key axis so the softmax runs over keys:
            # (batch_size, sequence_length, num_heads, sequence_length).
            similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)
        attention = util.last_dim_softmax(similarity_matrix.contiguous(), mask)

        # Shape: (batch_size, sequence_length, projection_dim)
        values = self._projection(tokens)
        if multi_head:
            # Split the projection into heads and move the head axis forward:
            # (batch_size, num_heads, sequence_length, projection_dim / num_heads).
            split_shape = list(values.size())[:-1] + [self._num_attention_heads, -1]
            values = values.view(*split_shape).permute(0, 2, 1, 3)

        # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
        attended_sentence = util.weighted_sum(values, attention)
        if multi_head:
            # Concatenate the per-head outputs back into one vector per token.
            attended_sentence = attended_sentence.view(batch_size, sequence_length, -1)

        # Shape: (batch_size, sequence_length, combination_dim)
        combined = util.combine_tensors(self._combination, [tokens, attended_sentence])
        return self._output_projection(combined)
Example #9
0
 def test_last_dim_softmax_does_softmax_on_last_dim(self):
     """Every (i, j) row holds the same logits, so every row must get the same softmax."""
     batch_size, length_1, length_2, num_options = 1, 5, 3, 4
     options_array = numpy.zeros((batch_size, length_1, length_2, num_options))
     # Broadcast the same logits into every (i, j) slot of the single batch element.
     options_array[0, :, :] = [2, 4, 0, 1]
     options_tensor = Variable(torch.from_numpy(options_array))
     softmax_tensor = util.last_dim_softmax(options_tensor).data.numpy()
     assert softmax_tensor.shape == (batch_size, length_1, length_2, num_options)
     expected_row = [0.112457, 0.830953, 0.015219, 0.041371]
     for i, j in numpy.ndindex(length_1, length_2):
         assert_almost_equal(softmax_tensor[0, i, j], expected_row, decimal=5)
Example #10
0
 def test_last_dim_softmax_handles_mask_correctly(self):
     """A masked-out last option gets probability 0; the rest renormalise among themselves."""
     batch_size, length_1, length_2, num_options = 1, 4, 3, 5
     options_array = numpy.zeros((batch_size, length_1, length_2, num_options))
     # Broadcast the same logits into every (i, j) slot of the single batch element.
     options_array[0, :, :] = [2, 4, 0, 1, 6]
     # The last option (which has the largest logit) is masked out.
     mask = Variable(torch.IntTensor([[1, 1, 1, 1, 0]]))
     options_tensor = Variable(torch.from_numpy(options_array).float())
     softmax_tensor = util.last_dim_softmax(options_tensor, mask).data.numpy()
     assert softmax_tensor.shape == (batch_size, length_1, length_2, num_options)
     expected_row = [0.112457, 0.830953, 0.015219, 0.041371, 0.0]
     for i, j in numpy.ndindex(length_1, length_2):
         assert_almost_equal(softmax_tensor[0, i, j], expected_row, decimal=5)
    def forward(self, tokens: torch.Tensor, mask: torch.Tensor):  # pylint: disable=arguments-differ
        """
        Apply (optionally multi-head) intra-sentence attention to ``tokens``.

        Parameters
        ----------
        tokens : torch.Tensor
            Shape: (batch_size, sequence_length, input_dim).
        mask : torch.Tensor
            Passed to ``util.last_dim_softmax`` to exclude padded key positions.

        Returns
        -------
        ``self._output_projection`` applied to
        ``util.combine_tensors(self._combination, [tokens, attended_sentence])``.
        """
        batch_size, sequence_length, _ = tokens.size()
        multi_head = self._num_attention_heads > 1

        # Shape: (batch_size, sequence_length, sequence_length[, num_heads])
        similarity_matrix = self._matrix_attention(tokens, tokens)
        if multi_head:
            # Put the head axis before the key axis so the softmax runs over keys:
            # (batch_size, sequence_length, num_heads, sequence_length).
            similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)
        attention = util.last_dim_softmax(similarity_matrix.contiguous(), mask)

        # Shape: (batch_size, sequence_length, projection_dim)
        values = self._projection(tokens)
        if multi_head:
            # Split the projection into heads and move the head axis forward:
            # (batch_size, num_heads, sequence_length, projection_dim / num_heads).
            split_shape = list(values.size())[:-1] + [self._num_attention_heads, -1]
            values = values.view(*split_shape).permute(0, 2, 1, 3)

        # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
        attended_sentence = util.weighted_sum(values, attention)
        if multi_head:
            # Concatenate the per-head outputs back into one vector per token.
            attended_sentence = attended_sentence.view(batch_size, sequence_length, -1)

        # Shape: (batch_size, sequence_length, combination_dim)
        combined = util.combine_tensors(self._combination, [tokens, attended_sentence])
        return self._output_projection(combined)
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        real_batch_size = embedded_question.size(0)
        batch_size = embedded_passage.size(0)
        max_p_num = batch_size // real_batch_size
        assert batch_size % real_batch_size == 0, 'fake:{}, real:{}'.format(batch_size, real_batch_size)
        q_shape = embedded_question.shape
        embedded_question = embedded_question.unsqueeze(dim=1).expand(
            q_shape[0], max_p_num, q_shape[1], q_shape[2]).contiguous().view(-1, q_shape[1], q_shape[2])
        assert embedded_question.shape[0] == batch_size
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        question_mask = question_mask.unsqueeze(dim=1).expand(q_shape[0], max_p_num, q_shape[1]).contiguous().view(-1, q_shape[1])
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        #logger.info('p shape:{} q.shape:{}'.format(encoded_passage.shape, encoded_question.shape))
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        #logger.info('shape:{} vs {}'.format(question_passage_vector.shape, [batch_size, passage_length, encoding_dim]))
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        # concat passages belongs to the same question
        # span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_start_probs = self.apply_concat_mask_fn(span_start_logits, passage_mask, real_batch_size, softmax)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = self.apply_concat_mask_fn(span_end_logits, passage_mask, real_batch_size, softmax)

        # concat passages of the same quesiton into one passage
        if passage_mask is not None:
            passage_mask = passage_mask.view(batch_size, -1)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        if self.training == False:
            best_span, best_score = self.get_best_span(span_start_logits, span_end_logits)
        else:
            best_span, best_score = None, None

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                "best_score": best_score
                }

        # Compute the loss for training.
        #if self.training:
        if span_start is not None:
            log_start_logits = self.apply_concat_mask_fn(span_start_logits, passage_mask, real_batch_size, log_softmax).view(real_batch_size, -1)
            log_end_logits = self.apply_concat_mask_fn(span_end_logits, passage_mask, real_batch_size, log_softmax).view(real_batch_size, -1)

            loss = nll_loss(log_start_logits, span_start.squeeze(-1))
            #self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(log_end_logits, span_end.squeeze(-1))
            #self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            #self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        # if metadata is not None:
        #     output_dict['best_span_str'] = []
        #     question_tokens = []
        #     passage_tokens = []
        #     for i in range(batch_size):
        #         question_tokens.append(metadata[i]['question_tokens'])
        #         passage_tokens.append(metadata[i]['passage_tokens'])
        #         passage_str = metadata[i]['original_passage']
        #         offsets = metadata[i]['token_offsets']
        #         predicted_span = tuple(best_span[i].data.cpu().numpy())
        #         start_offset = offsets[predicted_span[0]][0]
        #         end_offset = offsets[predicted_span[1]][1]
        #         best_span_string = passage_str[start_offset:end_offset]
        #         output_dict['best_span_str'].append(best_span_string)
        #         answer_texts = metadata[i].get('answer_texts', [])
        #         if answer_texts:
        #             self._squad_metrics(best_span_string, answer_texts)
        #     output_dict['question_tokens'] = question_tokens
        #     output_dict['passage_tokens'] = passage_tokens
        return output_dict
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Apply multi-head scaled dot-product self-attention to ``inputs``.

        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim).
        mask : ``torch.Tensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps). When omitted, every timestep is
            attended to. The same mask is shared by all heads.

        Returns
        -------
        ``self._output_projection`` applied to the re-merged head outputs of shape
        (batch_size, timesteps, values_dim); the original docs give the result as
        (batch_size, timesteps, output_projection_dim), output_projection_dim = input_dim by default.
        """
        heads = self._num_heads
        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps)

        def to_heads(tensor, total_dim):
            # (batch, timesteps, total_dim) -> (batch * heads, timesteps, total_dim / heads).
            # Rows are laid out batch-major, head-minor: row b * heads + h is head h of item b.
            per_head = int(total_dim / heads)
            return (tensor.view(batch_size, timesteps, heads, per_head)
                    .transpose(1, 2)
                    .contiguous()
                    .view(batch_size * heads, timesteps, per_head))

        # A single projection yields queries, keys and values side by side:
        # (batch_size, timesteps, 2 * attention_dim + values_dim).
        projected = self._combined_projection(inputs)
        # Chunks of attention_dim: queries, keys, then one or more value chunks (more than one
        # when values_dim > attention_dim), which are glued back together.
        query_part, key_part, *value_parts = projected.split(self._attention_dim, -1)
        value_part = torch.cat(value_parts, -1).contiguous()

        query_heads = to_heads(query_part.contiguous(), self._attention_dim)
        key_heads = to_heads(key_part.contiguous(), self._attention_dim)
        value_heads = to_heads(value_part, self._values_dim)

        # (batch * heads, timesteps, timesteps)
        scores = torch.bmm(query_heads, key_heads.transpose(1, 2)) / self._scale

        # repeat(1, heads) then view makes row b * heads + h equal to mask[b], matching the
        # row layout produced by to_heads.
        head_mask = mask.repeat(1, heads).view(batch_size * heads, timesteps)
        weights = self._attention_dropout(last_dim_softmax(scores, head_mask))

        # (batch * heads, timesteps, values_dim / heads)
        attended = weighted_sum(value_heads, weights)

        # Undo the head split -> (batch_size, timesteps, values_dim)
        merged = (attended.view(batch_size, heads, timesteps, int(self._values_dim / heads))
                  .transpose(1, 2)
                  .contiguous()
                  .view(batch_size, timesteps, self._values_dim))

        return self._output_projection(merged)
Exemple #14
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                answer_impossible:torch.LongTensor = None,
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        BiDAF-style span prediction with independent per-position start/end scores, so that a
        question may also be predicted to have no answer.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``. The model predicts the beginning and ending positions of the
            answer within the passage, or that there is none.
        answer_impossible : ``torch.LongTensor``, optional
            One value per batch element (it is ``unsqueeze(1)``-ed below). Presumably 1 when the
            question is unanswerable and 0 otherwise (TODO confirm against the dataset reader).
            Gold start/end targets are multiplied by ``1 - answer_impossible``. The loss and the
            accuracy metrics are computed only when this is given.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``. The `inclusive` gold start token index, compared against every
            passage position to build the start targets.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``. The `inclusive` gold end token index, used like ``span_start``.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, one dict per instance with the keys ``question_tokens``,
            ``passage_tokens``, ``original_passage`` and ``token_offsets``, and optionally
            ``answer_texts`` (used for the SQuAD metrics).

        Returns
        -------
        An output dictionary consisting of:
        passage_question_attention : torch.FloatTensor
            Shape ``(batch_size, passage_length, question_length)``.
        span_start_logits : torch.FloatTensor
            Shape ``(batch_size, passage_length)``. Independent per-position start logits.
            Masked positions are set to -1e7.
        span_start_probs : torch.FloatTensor
            ``sigmoid(span_start_logits)``. These are independent per position and need not sum
            to 1.
        span_end_logits : torch.FloatTensor
            Like ``span_start_logits``, for the (inclusive) end position.
        span_end_probs : torch.FloatTensor
            ``sigmoid(span_end_logits)``.
        best_span : torch.Tensor
            ``self.get_best_span(span_start_probs, span_end_probs)``. Shape ``(batch_size, 2)``.
            An index of -1 is treated below as "no span".
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised. Present when ``answer_impossible`` is given.
        best_span_str : List[str]
            When ``metadata`` is given, one string per instance. It is empty when no span is
            predicted or the span could not be mapped back to the passage text.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length). max(dim=-1) already drops the question axis.
        # An extra squeeze(-1) would also collapse a length-1 passage to (batch_size,).
        question_passage_similarity = masked_similarity.max(dim=-1)[0]
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length). These are unmasked sigmoid weights, used only to
        # pool the start representation. The final probs are recomputed after masking below.
        span_start_probs = sigmoid(span_start_logits)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

        # Push padded positions to a very negative logit, so their sigmoid probability is ~0.
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        span_start_probs = sigmoid(span_start_logits)
        span_end_probs = sigmoid(span_end_logits)
        best_span = self.get_best_span(span_start_probs, span_end_probs)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if answer_impossible is not None:
            # 1 for answerable questions, 0 for impossible ones. Shape: (batch_size, 1)
            answerable = (1 - answer_impossible).unsqueeze(1)

            # Shape: (batch_size, passage_length). unsqueeze(0) (not squeeze(0)) keeps the
            # position axis even for a length-1 passage, so expand(batch, -1) is always valid.
            positions = torch.arange(0, span_start_logits.size(1),
                                     device=span_start_logits.device, dtype=torch.long)
            positions = positions.unsqueeze(0).expand(span_start_logits.size(0), -1)

            # 0/1 indicator of the gold start/end position, all zeros for impossible questions.
            target_start = (positions == span_start).long() * answerable
            target_end = (positions == span_end).long() * answerable

            # Stacking [-x, x] as two-class logits turns the sequence cross-entropy into a
            # per-position binary loss. NOTE(review): softmax([-x, x])[1] == sigmoid(2x), while
            # the reported probs are sigmoid(x). Both share the same 0 threshold.
            span_start_logits_for_loss = torch.stack([-1 * span_start_logits, span_start_logits], dim=-1)
            loss = util.sequence_cross_entropy_with_logits(span_start_logits_for_loss, target_start, passage_mask)

            span_end_logits_for_loss = torch.stack([-1 * span_end_logits, span_end_logits], dim=-1)
            loss += util.sequence_cross_entropy_with_logits(span_end_logits_for_loss, target_end, passage_mask)

            self._span_start_accuracy((span_start_logits > 0).long(), target_start)
            self._span_end_accuracy((span_end_logits > 0).long(), target_end)
            # "No answer" is predicted when both best_span indices are -1.
            no_span_predicted = ((best_span.narrow(1, 0, 1) == -1) * (best_span.narrow(1, 1, 1) == -1)).long()
            self._answer_impossible_accuracy(no_span_predicted, answer_impossible)
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Default to "" and append after the try, so best_span_str always holds exactly
                # one entry per instance, aligned with the batch index, even if the lookup fails.
                best_span_string = ""
                try:
                    start_offset = offsets[predicted_span[0]][0] if predicted_span[0] != -1 else -1
                    end_offset = offsets[predicted_span[1]][1] if predicted_span[1] != -1 else -1
                    if end_offset != -1 and start_offset != -1:
                        best_span_string = passage_str[start_offset:end_offset]
                    answer_texts = metadata[i].get('answer_texts', [])
                    if answer_texts:
                        self._squad_metrics(best_span_string, answer_texts)
                except Exception as e:  # pylint: disable=broad-except
                    # Best-effort: a bad offset must not abort the whole batch.
                    print(str(e))
                output_dict['best_span_str'].append(best_span_string)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
Exemple #15
0
    def forward(self,
                context: Dict[str, torch.LongTensor],
                length: torch.LongTensor = None,
                repeat: torch.FloatTensor = None,
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        """
        Classify a dialogue by comparing each utterance with the one that follows it.

        Consecutive utterances are soft-aligned in both directions. The per-pair classifier
        probabilities are flattened, and the ``length`` and ``repeat`` features are appended
        before the final classifier.

        Parameters
        ----------
        context : Dict[str, torch.LongTensor]
            Text-field tensors. ``context['tokens']`` is indexed as
            ``(batch_size, dials_len, sentences_len)``. NOTE: when the dialogue has fewer
            utterances than the final classifier expects, ``context['tokens']`` is zero-padded
            *in place*, so the caller's dict is modified.
        length : list or tensor of numbers, one per batch element
            Required in practice despite the ``None`` default. It is converted to a float column.
        repeat : list or tensor of numbers, one per batch element
            Same handling as ``length``.
        label : torch.LongTensor, optional
            Gold labels. When given, ``loss`` is added to the output and ``self.metrics`` are
            updated.

        Returns
        -------
        Dict with ``class_logits`` and ``class_probabilities`` from the final classifier, plus
        ``loss`` when ``label`` is given.
        """
        tokens = context['tokens']
        # The final classifier gets 2 probabilities for each of the (dials_len - 1) utterance
        # pairs, plus the length and repeat columns, i.e. 2 * dials_len inputs. This assumes
        # classifier_feedforward has 2 outputs, as the shape comments below state. Half of its
        # input dim is therefore the expected number of utterances. Floor division keeps the
        # padding size an int; a float size makes tensor allocation raise.
        expected_dim = self.final_classifier_feedforward.get_input_dim() // 2
        dia_len = tokens.size(1)
        if expected_dim > dia_len:
            # Match dtype and device of the tokens so torch.cat also works on GPU.
            padding = tokens.new_zeros(tokens.size(0), expected_dim - dia_len, tokens.size(2))
            context['tokens'] = torch.cat([tokens, padding], dim=1)

        # context: batch_size * dials_len * sentences_len
        # embedded_context: batch_size * dials_len * sentences_len * emb_dim
        embedded_context = self.text_field_embedder(context)
        # utterances_mask: batch_size * dials_len * sentences_len
        utterances_mask = get_text_field_mask(context, 1).float()
        # encoded_utterances: batch_size * dials_len * emb_dim
        encoded_utterances = self.utterances_encoder(embedded_context, utterances_mask)

        # Utterance i (context side) is paired with utterance i + 1 (response side).
        # embedded_context / embedded_response: batch_size * (dials_len - 1) * emb_dim
        embedded_context = encoded_utterances[:, :-1, :]
        embedded_response = encoded_utterances[:, 1:, :]
        # Utterance-level mask, computed once: batch_size * dials_len
        utterance_level_mask = get_text_field_mask(context).float()
        # context_mask / response_mask: batch_size * (dials_len - 1)
        context_mask = utterance_level_mask[:, :-1]
        response_mask = utterance_level_mask[:, 1:]

        projected_context = self.attend_feedforward(embedded_context)
        projected_response = self.attend_feedforward(embedded_response)

        # similarity_matrix: batch_size * (dials_len - 1) * (dials_len - 1)
        similarity_matrix = self.matrix_attention(projected_context, projected_response)

        # c2r_attention: batch * (dials_len - 1) * (dials_len - 1)
        c2r_attention = last_dim_softmax(similarity_matrix, response_mask)
        # attended_response: batch * (dials_len - 1) * emb_dim
        attended_response = weighted_sum(embedded_response, c2r_attention)

        # r2c_attention: batch * (dials_len - 1) * (dials_len - 1)
        r2c_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), context_mask)
        # attended_context: batch * (dials_len - 1) * emb_dim
        attended_context = weighted_sum(embedded_context, r2c_attention)

        # context_compare_input: batch * (dials_len - 1) * (emb_dim + emb_dim)
        context_compare_input = torch.cat([embedded_context, attended_response], dim=-1)
        # response_compare_input: batch * (dials_len - 1) * (emb_dim + emb_dim)
        response_compare_input = torch.cat([embedded_response, attended_context], dim=-1)

        # compared_context: batch * (dials_len - 1) * emb_dim
        compared_context = self.compare_feedforward(context_compare_input)
        compared_context = compared_context * context_mask.unsqueeze(-1)

        # compared_response: batch * (dials_len - 1) * emb_dim
        compared_response = self.compare_feedforward(response_compare_input)
        compared_response = compared_response * response_mask.unsqueeze(-1)

        # aggregate_input: batch * (dials_len - 1) * (compare_context_dim + compared_response_dim)
        aggregate_input = torch.cat([compared_context, compared_response], dim=-1)

        # class_logits: batch * (dials_len - 1) * 2; class_probs flattened to batch * (2 * (dials_len - 1))
        class_logits = self.classifier_feedforward(aggregate_input)
        class_probs = F.softmax(class_logits, dim=-1).reshape(class_logits.size()[0], -1)
        # torch.FloatTensor(x) rejects non-float tensors (e.g. a LongTensor length) and always
        # builds a CPU tensor. as_tensor accepts lists or tensors and follows class_probs' device.
        length_tensor = torch.as_tensor(length, dtype=torch.float, device=class_probs.device).reshape(-1, 1)
        repeat_tensor = torch.as_tensor(repeat, dtype=torch.float, device=class_probs.device).reshape(-1, 1)
        class_probs = torch.cat([class_probs, length_tensor, repeat_tensor], dim=1)

        full_logits = self.final_classifier_feedforward(class_probs)
        full_probs = F.softmax(full_logits, dim=-1)
        output_dict = {"class_logits": full_logits, "class_probabilities": full_probs}

        if label is not None:
            loss = self.loss(full_logits, label.squeeze(-1))
            for metric in self.metrics.values():
                metric(full_logits, label.squeeze(-1))
            output_dict['loss'] = loss

        return output_dict
    def forward(
            self,  # type: ignore
            tokens,
            spans,
            metadata,
            pos_tags=None,
            span_labels=None):
        # pylint: disable=arguments-differ
        u"""
        Score every candidate span of each sentence with a distribution over span labels.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            Output of ``TextField.as_array()``, passed straight to ``self.text_field_embedder``.
            Its keys are those of the ``TokenIndexers`` used to build the ``TextField``.
        spans : ``torch.LongTensor``, required.
            Shape ``(batch_size, num_spans, 2)``, holding the inclusive start and end index of
            each candidate span. Padding spans have a negative start index.
        metadata : List[Dict[str, Any]], required.
            One dict per batch element with:
                tokens : ``List[str]``, required.
                    The original string tokens of the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold tree, used for the EVALB score.
                pos_tags : ``List[str]``, optional.
                    POS tags of the sentence, copied into the output.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            Output of a ``SequenceLabelField`` of POS tags. It is required whenever the model
            has a POS embedding.
        span_labels : ``torch.LongTensor``, optional (default = None)
            Gold class label of each span, shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            Shape ``(batch_size, num_spans, span_label_vocab_size)``. The label distribution
            for each span.
        spans : ``torch.LongTensor``
            The input ``spans`` tensor.
        tokens : ``List[List[str]]``
            The tokens of each sentence in the batch.
        pos_tags : ``List[List[str]]``
            The POS tags of each sentence in the batch (``None`` where absent).
        num_spans : ``torch.LongTensor``
            Shape ``(batch_size,)``. The number of non-padding spans in each sentence.
        loss : ``torch.FloatTensor``, optional
            A scalar loss, present when ``span_labels`` is given.
        """
        text_input = self.text_field_embedder(tokens)
        if self.pos_tag_embedding is not None:
            if pos_tags is None:
                raise ConfigurationError(
                    u"Model uses a POS embedding, but no POS tags were passed.")
            text_input = torch.cat([text_input, self.pos_tag_embedding(pos_tags)], -1)

        token_mask = get_text_field_mask(tokens)

        # A span is real iff its start index is non-negative. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        # The squeeze above drops the span axis when there is a single span (e.g. batch_size 1
        # and a length-1 PTB sentence). Restore it here, and do the same for the labels.
        if span_mask.dim() == 1:
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(text_input, token_mask)
        span_representations = self.span_extractor(encoded_text, spans, token_mask, span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)

        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        sentence_tokens = [meta[u"tokens"] for meta in metadata]
        output_dict = {
            u"class_probabilities": class_probabilities,
            u"spans": spans,
            u"tokens": sentence_tokens,
            u"pos_tags": [meta.get(u"pos_tags") for meta in metadata],
            u"num_spans": num_spans,
        }

        if span_labels is not None:
            output_dict[u"loss"] = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)

        # EVALB is expensive to compute, so it only runs outside training and when every
        # instance in the batch carries a gold tree.
        gold_trees = [meta.get(u"gold_tree") for meta in metadata]
        should_evaluate = (all(gold_trees)
                           and self._evalb_score is not None
                           and not self.training)
        if should_evaluate:
            gold_pos_tags = [list(izip(*gold.pos()))[1] for gold in gold_trees]
            predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                                   spans.cpu().data,
                                                   num_spans.data,
                                                   sentence_tokens,
                                                   gold_pos_tags)
            self._evalb_score(predicted_trees, gold_trees)

        return output_dict
Exemple #17
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Predict the start and end of the answer span in each passage using bidirectional
        (passage-to-question and question-to-passage) attention.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token
            index.  If this is given, we compute a loss that gets included in the output
            dictionary, and ``span_end`` must be given as well.  A trailing singleton dimension
            is squeezed away before use (presumably the shape is ``(batch_size, 1)`` - confirm).
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            Only read when ``span_start`` is given.
        metadata : ``List[Dict[str, Any]]``, optional
            One dictionary per instance in the batch (its length should be the batch size).
            Each dictionary must have the keys ``question_tokens``, ``passage_tokens``,
            ``original_passage`` and ``token_offsets``.  ``token_offsets[k]`` is a
            ``(start, end)`` pair of character offsets of passage token ``k`` into
            ``original_passage``.  If an instance also has a non-empty ``answer_texts`` list,
            the predicted answer string is scored against it with the SQuAD EM/F1 metric.

        Returns
        -------
        An output dictionary consisting of:
        passage_question_attention : torch.FloatTensor
            Shape ``(batch_size, passage_length, question_length)``: the passage-to-question
            attention distribution over question tokens.
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.  Padded positions are set to ``-1e7``.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)`` over the non-padded positions.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).  Padded positions are set to
            ``-1e7``.
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)`` over the non-padded positions.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.  Present only when ``span_start`` is given.
        best_span_str : List[str]
            Present only when ``metadata`` is given: for each instance, the substring of
            ``original_passage`` that the model thinks is the best answer.
        question_tokens : List
            Present only when ``metadata`` is given: ``metadata[i]['question_tokens']`` per
            instance.
        passage_tokens : List
            Present only when ``metadata`` is given: ``metadata[i]['passage_tokens']`` per
            instance.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        # The sequence encoders only receive masks when configured to do so.
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        # ``max(dim=-1)`` already removes the question dimension.  An additional
        # ``squeeze(-1)`` would also remove the passage dimension whenever the padded passage
        # length is 1, breaking the masked softmax and weighted sum below.
        question_passage_similarity = masked_similarity.max(dim=-1)[0]
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Push padded positions far down so best-span search never selects them.
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                # Map the inclusive token span back to a character slice of the raw passage.
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                spans: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                pos_tags: Dict[str, torch.LongTensor] = None,
                span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Score every enumerated span of each sentence with a distribution over constituent
        labels.  Optionally compute the training loss and, outside training, EVALB scores.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
            Padding spans have a negative start index.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.  It is passed as-is to
            ``self.pos_tag_embedding``.  It must be given whenever that embedding is configured,
            and it is ignored otherwise.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch
            (``None`` for elements whose metadata has no ``pos_tags``).
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.  Present only when ``span_labels`` is given.

        Raises
        ------
        ConfigurationError
            If the model has a POS tag embedding but ``pos_tags`` is ``None``.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if self.pos_tag_embedding is not None:
            if pos_tags is None:
                raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if this is padding or not.
        # Selecting along the last axis keeps the result 2-D, so no squeeze/unsqueeze
        # round-trip is needed even when every sentence in the batch has a single span.
        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).long()
        if span_labels is not None and span_labels.dim() == 1:
            # NOTE(review): presumably labels can arrive squeezed to 1-D for single-span
            # batches; restore the trailing axis so they line up with span_mask.
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)
        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
                "class_probabilities": class_probabilities,
                "spans": spans,
                "tokens": [meta["tokens"] for meta in metadata],
                "pos_tags": [meta.get("pos_tags") for meta in metadata],
                "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            # ``Tree.pos()`` yields (word, tag) pairs; keep only the tags, as a real list
            # (this also avoids an IndexError on a tree with no leaves).
            gold_pos_tags: List[List[str]] = [[tag for _, tag in tree.pos()]
                                              for tree in batch_gold_trees]
            predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                                   spans.cpu().data,
                                                   num_spans.data,
                                                   output_dict["tokens"],
                                                   gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
Exemple #19
0
    def _create_attended_span_representations(
            self, head_scores: torch.FloatTensor,
            text_embeddings: torch.FloatTensor, span_ends: torch.IntTensor,
            span_widths: torch.IntTensor) -> torch.FloatTensor:
        """
        Build an attention-pooled representation of every candidate span.

        Every word carries one unnormalised "headedness" score, shared by all spans.  For a given
        span, those scores are softmax-normalised over only the words inside the span.  The
        resulting distribution weights a sum of the words' embeddings.

        Parameters
        ----------
        head_scores : ``torch.FloatTensor``, required.
            Unnormalised headedness score for every word in the document.  Spans differ only in
            which of these scores they normalise over.
            NOTE(review): the ``squeeze(-1)`` after selecting these scores suggests a shape of
            (batch_size, document_length, 1) - confirm with the caller.
        text_embeddings : ``torch.FloatTensor``, required.
            Word embeddings of shape (batch_size, document_length, embedding_size) that are
            pooled per span.
        span_ends : ``torch.IntTensor``, required.
            Shape (batch_size, num_spans, 1): the index of the last word of each span.
        span_widths : ``torch.IntTensor``, required.
            Shape (batch_size, num_spans, 1).  Offsets ``0 .. span_widths`` (inclusive) back from
            ``span_ends`` count as inside the span.  This is presumably ``end - start`` rather
            than a token count - confirm with the caller.

        Returns
        -------
        attended_text_embeddings : ``torch.FloatTensor``
            Shape (batch_size, num_spans, embedding_size): the head-attention-weighted sum of
            the embeddings of the words inside each span.
        """
        # Backward offsets 0 .. max_span_width - 1, shaped (1, 1, max_span_width) so they
        # broadcast against the per-span (batch_size, num_spans, 1) tensors.
        offsets = util.get_range_vector(self._max_span_width,
                                        text_embeddings.is_cuda).view(1, 1, -1)

        # Step backwards from each span end.  Shape: (batch_size, num_spans, max_span_width)
        candidate_indices = span_ends - offsets

        # An offset is kept only if it lies within the span's width AND does not run before the
        # start of the document (spans near the beginning can be shorter than max_span_width).
        within_width = (offsets <= span_widths).float()
        inside_document = (candidate_indices >= 0).float()
        span_mask = within_width * inside_document

        # Negative (masked-out) positions are clamped to 0 purely so the gather stays in
        # bounds; the mask keeps them out of the softmax.
        safe_indices = candidate_indices.clamp(min=0).long()

        # Shape: (batch_size * num_spans * max_span_width)
        flat_indices = util.flatten_and_batch_shift_indices(safe_indices,
                                                            text_embeddings.size(1))

        # Shape: (batch_size, num_spans, max_span_width, embedding_size)
        span_word_embeddings = util.batched_index_select(text_embeddings,
                                                         safe_indices,
                                                         flat_indices)
        # Shape: (batch_size, num_spans, max_span_width)
        span_word_scores = util.batched_index_select(head_scores,
                                                     safe_indices,
                                                     flat_indices).squeeze(-1)

        # Normalise the head scores over each span's real words only.
        head_weights = util.last_dim_softmax(span_word_scores, span_mask)

        # Shape: (batch_size, num_spans, embedding_size)
        return util.weighted_sum(span_word_embeddings, head_weights)
Exemple #20
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        #import pdb; pdb.set_trace()
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        # answer_len for masking
        answer_len = [len(elem['answer_texts'])
                      for elem in metadata] if metadata is not None else []
        if answer_len:
            mask = torch.zeros((batch_size, max(answer_len), 2)).long()
            for index, length in enumerate(answer_len):
                mask[index, :length] = 1
        else:
            mask = None

        best_span, top_span_logits = self.get_best_span(
            span_start_logits, span_end_logits, answer_len)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            span_start = span_start.squeeze(-1)  #batch X max_answer_L
            span_end = span_end.squeeze(-1)  #batch X max_answer_L

            # a batch_size x passage_length tensor with 1's indicating right
            # answer at that position/index
            span_start_pos = torch.zeros((batch_size, passage_length))
            span_end_pos = torch.zeros((batch_size, passage_length))

            for row_id, row in enumerate(span_start):
                for span_index in row:
                    span_index = span_index.data[0]
                    if span_index == -1:
                        break
                    span_start_pos[row_id][span_index] = 1

            for row_id, row in enumerate(span_end):
                for span_index in row:
                    span_index = span_index.data[0]
                    if span_index == -1:
                        break
                    span_end_pos[row_id][span_index] = 1

            span_start_ground = to_variable(
                span_start_pos)  # batch x passage_len
            span_end_ground = to_variable(span_end_pos)  # batch x passage_len

            # at this point, we have a 2 - 2d matrix for start, end respectively
            # each matrix has the index of the right answer set to 1

            flattened_start_pred = flatten_answer(span_start_logits,
                                                  passage_mask)
            flattened_end_pred = flatten_answer(span_end_logits, passage_mask)
            flattened_start_ground = flatten_answer(span_start_ground,
                                                    passage_mask)
            flattened_end_ground = flatten_answer(span_end_ground,
                                                  passage_mask)

            loss = binary_cross_entropy_with_logits(flattened_start_pred,
                                                    flattened_start_ground)
            loss += binary_cross_entropy_with_logits(flattened_end_pred,
                                                     flattened_end_ground)
            """
            #TODO for better reporting only
            self._span_start_accuracy(flattened_start_pred, flattened_start_ground)
            self._span_end_accuracy(flattened_end_pred, flattened_end_ground)
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1), mask)
            """
            """
            # OLD CODE - ONLY REFERENCE
            # TODO answer padding needs to be ignored
            step = 0
            span_start_1D = span_start[ : , step:step + 1] #batch X 1 
            span_end_1D = span_end[ : , step:step + 1] #batch X 1 
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start_1D.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start_1D.squeeze(-1)) #TODO
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end_1D.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end_1D.squeeze(-1)) #TODO
            # self._span_accuracy(best_span, torch.stack([span_start_1D, span_end_1D], -1))#TODO

            for step in range(1, span_start.size(1)):
                span_start_1D = span_start[ : , step:step + 1] #batch X 1 
                span_end_1D = span_end[ : , step:step + 1] #batch X 1 
                loss += nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start_1D.squeeze(-1), ignore_index=-1)
                self._span_start_accuracy(span_start_logits, span_start_1D.squeeze(-1)) #TODO
                loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end_1D.squeeze(-1), ignore_index=-1)
                self._span_end_accuracy(span_end_logits, span_end_1D.squeeze(-1)) #TODO
                # self._span_accuracy(best_span, torch.stack([span_start_1D, span_end_1D], -1))#TODO
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1), mask)
            """
            output_dict["loss"] = loss

        pscores = top_span_logits[:, :, 0]  # 40 X 12
        span_starts = top_span_logits[:, :, 1]  # 40 X 12
        span_ends = top_span_logits[:, :, 2]  # 40 X 12
        best_span_starts = best_span[:, :, 0]  # 40 X 12 # to check for -1

        lr_list = []  #TODO: Place this is in the right spot
        label = 0
        pscore = 0

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                best_span_strings = []
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_spans = tuple(best_span[i].data.cpu().numpy())
                for predicted_span in predicted_spans:
                    if predicted_span[0] == -1:
                        break
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    best_span_string = passage_str[start_offset:end_offset]
                    best_span_strings.append(best_span_string)
                output_dict['best_span_str'].append(best_span_strings)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_strings, answer_texts)
                for j in range(span_starts.shape[1]):
                    pscore = pscores.data[i][j]
                    if best_span_starts.data[i][j] != -1:
                        label = 1
                    else:
                        label = 0

                    question_comp = metadata[i]['qID'].split(',')[1].replace(
                        '@', '-'
                    )  #TODO: COnsidering only 1 question entity, what if no entity
                    answer_comp = passage_str[int(span_starts.data[i][j]):int(
                        span_ends.data[i]
                        [j])]  #TODO: this will need some further processing
                    dijkstra_comp = metadata[i]['dijkstra']
                    import pdb
                    pdb.set_trace()
                    dscore = dijkstra_comp[question_comp][
                        answer_comp] if question_comp in dijkstra_comp and answer_comp in dijkstra_comp[
                            question_comp] else None
                    if dscore is not None:
                        lr_list.append((pscore, dscore, label))

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens

        return output_dict
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Classify a premise/hypothesis pair with attend, compare and aggregate steps.

        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``. It is flattened to one dimension before use, so
            ``(batch_size,)`` and ``(batch_size, 1)`` shapes both work, including batch_size == 1.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised. Only present when ``label`` is given.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)

        # Both sides share one projection before computing pairwise similarities.
        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1)

        # Zero the padded positions before summing over the sequence dimension.
        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            # Use the same flattened target for the loss and the metric. ``squeeze(-1)``
            # would turn a size-1 batch of shape (1,) into a 0-d tensor.
            flat_label = label.long().view(-1)
            loss = self._loss(label_logits, flat_label)
            self._accuracy(label_logits, flat_label)
            output_dict["loss"] = loss

        return output_dict
Exemple #22
0
    def forward(self, s1, s2):
        # pylint: disable=arguments-differ
        """
        Encode a sentence pair with BiDAF-style bidirectional attention and max pooling.

        Parameters
        ----------
        s1 : Dict[str, torch.LongTensor]
            From a ``TextField``.
        s2 : Dict[str, torch.LongTensor]
            From a ``TextField``.

        Returns
        -------
        pair_rep : torch.FloatTensor
            ``[s1_rep, s2_rep, |s1_rep - s2_rep|, s1_rep * s2_rep]`` concatenated along
            dimension 1. Each ``*_rep`` is the modeling-layer output max-pooled over the
            non-padded timesteps. The result is meant to be plugged into a downstream module.
        """
        s1_embs = self._highway_layer(self._text_field_embedder(s1))
        s2_embs = self._highway_layer(self._text_field_embedder(s2))
        if self._elmo is not None:
            s1_elmo_embs = self._elmo(s1['elmo'])
            s2_elmo_embs = self._elmo(s2['elmo'])
            if "words" in s1:
                s1_embs = torch.cat([s1_embs, s1_elmo_embs['elmo_representations'][0]], dim=-1)
                s2_embs = torch.cat([s2_embs, s2_elmo_embs['elmo_representations'][0]], dim=-1)
            else:
                # No word-level input: ELMo is the only embedding.
                s1_embs = s1_elmo_embs['elmo_representations'][0]
                s2_embs = s2_elmo_embs['elmo_representations'][0]
        if self._cove is not None:
            # Lengths are counted as the number of non-pad word ids.
            s1_lens = torch.ne(s1['words'], self.pad_idx).long().sum(dim=-1).data
            s2_lens = torch.ne(s2['words'], self.pad_idx).long().sum(dim=-1).data
            s1_cove_embs = self._cove(s1['words'], s1_lens)
            s1_embs = torch.cat([s1_embs, s1_cove_embs], dim=-1)
            s2_cove_embs = self._cove(s2['words'], s2_lens)
            s2_embs = torch.cat([s2_embs, s2_cove_embs], dim=-1)
        s1_embs = self._dropout(s1_embs)
        s2_embs = self._dropout(s2_embs)

        # The padding masks are always needed: the attention softmaxes and the max pooling
        # below must ignore padding. ``_mask_lstms`` only controls whether the LSTMs also
        # receive them. Previously the masks were left undefined when it was False.
        s1_mask = util.get_text_field_mask(s1).float()
        s2_mask = util.get_text_field_mask(s2).float()
        if self._mask_lstms:
            s1_lstm_mask, s2_lstm_mask = s1_mask, s2_mask
        else:
            s1_lstm_mask, s2_lstm_mask = None, None

        s1_enc = self._phrase_layer(s1_embs, s1_lstm_mask)
        s2_enc = self._phrase_layer(s2_embs, s2_lstm_mask)

        # Similarity matrix
        # Shape: (batch_size, s2_length, s1_length)
        similarity_mat = self._matrix_attention(s2_enc, s1_enc)

        # s2 representation
        # Shape: (batch_size, s2_length, s1_length)
        s2_s1_attention = util.last_dim_softmax(similarity_mat, s1_mask)
        # Shape: (batch_size, s2_length, encoding_dim)
        s2_s1_vectors = util.weighted_sum(s1_enc, s2_s1_attention)
        s2_w_context = torch.cat([s2_enc, s2_s1_vectors], 2)
        # s1 representation, using the same attention method as for s2
        s1_s2_attention = util.last_dim_softmax(similarity_mat.transpose(1, 2).contiguous(), s2_mask)
        # Shape: (batch_size, s1_length, encoding_dim)
        s1_s2_vectors = util.weighted_sum(s2_enc, s1_s2_attention)
        s1_w_context = torch.cat([s1_enc, s1_s2_vectors], 2)
        if self._elmo is not None and self._deep_elmo:
            s1_w_context = torch.cat([s1_w_context, s1_elmo_embs['elmo_representations'][1]], dim=-1)
            s2_w_context = torch.cat([s2_w_context, s2_elmo_embs['elmo_representations'][1]], dim=-1)
        s1_w_context = self._dropout(s1_w_context)
        s2_w_context = self._dropout(s2_w_context)

        # Max-pool over time. Padded positions are set to -inf so they can never win the
        # max. The fill is out-of-place instead of mutating ``.data``.
        modeled_s2 = self._dropout(self._modeling_layer(s2_w_context, s2_lstm_mask))
        s2_padding = (s2_mask == 0).unsqueeze(dim=-1)
        s2_enc_attn = modeled_s2.masked_fill(s2_padding, -float('inf')).max(dim=1)[0]
        modeled_s1 = self._dropout(self._modeling_layer(s1_w_context, s1_lstm_mask))
        s1_padding = (s1_mask == 0).unsqueeze(dim=-1)
        s1_enc_attn = modeled_s1.masked_fill(s1_padding, -float('inf')).max(dim=1)[0]

        return torch.cat([s1_enc_attn, s2_enc_attn, torch.abs(s1_enc_attn - s2_enc_attn),
                          s1_enc_attn * s2_enc_attn], 1)
Exemple #23
0
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        choices_list: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
        facts_list: Dict[str, torch.LongTensor] = None,
        question2facts_map: Dict[str, torch.LongTensor] = None,
        choice2facts_map: Dict[str, torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Score every answer choice against the question, optionally using retrieved facts.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``
        choices_list : Dict[str, torch.LongTensor]
            From a ``List[TextField]``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``: the index of the correct choice.
        metadata : List[Dict[str, Any]], optional (default = None)
            Not read by this method.
        facts_list : Dict[str, torch.LongTensor], optional (default = None)
            From a ``List[TextField]``. Only used when ``self._use_knowledge`` is set.
        question2facts_map : optional (default = None)
            Only used in training when ``self._use_ctx2facts_retrieval_map_as_mask`` is set.
            It is multiplied elementwise into the question-to-facts mask.
            NOTE(review): it is used as a tensor, not a dict, despite the annotation.
        choice2facts_map : optional (default = None)
            Same as ``question2facts_map``, but for the choices-to-facts mask.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : list
            Nested Python list (via ``.tolist()``) of shape ``(batch_size, num_choices)``
            holding unnormalised choice scores.
        label_probs : list
            Nested Python list of shape ``(batch_size, num_choices)`` with the softmax of
            ``label_logits``.
        attentions, know_inter_weights : dict, optional
            Present when any of the ``self._return_*_att`` flags is set.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised. Only present when ``label`` is given.
        """

        encoded_question_aggregated, question_last_hidden_states = embedd_encode_and_aggregate_text_field(
            question,
            self._text_field_embedder,
            self._embeddings_dropout,
            self._question_encoder,
            self._question_aggregate,
            get_last_states=(self._choices_init_from_question_states or
                             self._facts_init_from_question_states))  # bs, hs

        encoded_choices_aggregated = embedd_encode_and_aggregate_list_text_field(
            choices_list,
            self._text_field_embedder,
            self._embeddings_dropout,
            self._choice_encoder,
            self._choice_aggregate,
            init_hidden_states=question_last_hidden_states
            if self._choices_init_from_question_states else
            None)  # bs, choices, hs

        bs = encoded_question_aggregated.shape[0]
        choices_cnt = encoded_choices_aggregated.shape[1]

        if self._use_knowledge:
            # encode facts
            encoded_facts_aggregated = embedd_encode_and_aggregate_list_text_field(
                facts_list,
                self._text_field_embedder,
                self._embeddings_dropout,
                self._facts_encoder,
                self._facts_aggregate,
                init_hidden_states=question_last_hidden_states
                if self._facts_init_from_question_states else
                None)  # bs, facts, hs

            # Fact-level mask, used below to mask the (bs, facts) attention scores.
            facts_aggregated_mask = get_text_field_mask(
                facts_list, num_wrapping_dims=0).float()

            facts_aggregated_mask_q_to_facts = facts_aggregated_mask
            if self._use_ctx2facts_retrieval_map_as_mask and self.training:
                # During training, attend only to facts that are both non-padded and
                # marked in the question-to-facts retrieval map.
                facts_aggregated_mask_q_to_facts = facts_aggregated_mask_q_to_facts * question2facts_map
                facts_aggregated_mask_q_to_facts = (
                    facts_aggregated_mask_q_to_facts > 0.00).float()

            facts_cnt = encoded_facts_aggregated.shape[1]

            # question to knowledge
            q_to_facts_att = self._matrix_attention_text_to_facts(
                encoded_question_aggregated.unsqueeze(1),
                encoded_facts_aggregated).view([bs, facts_cnt])
            q_to_facts_att_softmax = util.last_dim_softmax(
                q_to_facts_att, facts_aggregated_mask_q_to_facts)
            q_to_facts_weighted_sum = util.weighted_sum(
                encoded_facts_aggregated, q_to_facts_att_softmax)

            assert encoded_question_aggregated.shape == q_to_facts_weighted_sum.shape

            # choices to knowledge
            choices_to_facts_att = self._matrix_attention_text_to_facts(
                encoded_choices_aggregated,
                encoded_facts_aggregated).view([bs, choices_cnt, facts_cnt])  # bs, choices, facts

            facts_aggregated_mask_ch_to_facts = facts_aggregated_mask.unsqueeze(
                1).expand(choices_to_facts_att.shape)
            if self._use_ctx2facts_retrieval_map_as_mask and self.training:
                facts_aggregated_mask_ch_to_facts = facts_aggregated_mask_ch_to_facts * choice2facts_map
                facts_aggregated_mask_ch_to_facts = (
                    facts_aggregated_mask_ch_to_facts > 0.00).float()

            choices_to_facts_att_softmax = util.last_dim_softmax(
                choices_to_facts_att, facts_aggregated_mask_ch_to_facts)
            choices_to_facts_weighted_sum = util.weighted_sum(
                encoded_facts_aggregated, choices_to_facts_att_softmax)

            assert encoded_choices_aggregated.shape == choices_to_facts_weighted_sum.shape

            # combine with knowledge
            question_ctx_plus_know = self._text_plus_knowledge_repr_funciton(
                q_to_facts_weighted_sum, encoded_question_aggregated)
            choices_ctx_plus_know = self._text_plus_knowledge_repr_funciton(
                choices_to_facts_weighted_sum, encoded_choices_aggregated)

            # question to choices interactions (one score per interaction in the last dim)
            q_to_choices_combined_att = attention_interaction_combinations(
                quest_ctx=encoded_question_aggregated,
                choices_ctx=encoded_choices_aggregated,
                quest_ctx_plus_kn=question_ctx_plus_know,
                choices_ctx_plus_kn=choices_ctx_plus_know,
                quest_kn=q_to_facts_weighted_sum,
                choices_kn=choices_to_facts_weighted_sum,
                inter_to_include=self._know_interactions,
                att_matrix_mappings=self._matrix_attention_question_to_choice)

            # With several interactions, a feed-forward layer mixes them into one score.
            if q_to_choices_combined_att.shape[-1] > 1:
                q_to_choices_att = self._know_aggregate_feedforward(
                    q_to_choices_combined_att).squeeze(-1)
            else:
                q_to_choices_att = q_to_choices_combined_att.squeeze(-1)
        else:
            # Knowledge disabled: score the choices directly against the question.
            # squeeze(1) drops only the singleton query row. A bare squeeze() would also
            # drop the batch dimension when bs == 1, or the choices dimension when there
            # is a single choice.
            q_to_choices_att = self._matrix_attention_question_to_choice(
                encoded_question_aggregated.unsqueeze(1),
                encoded_choices_aggregated).squeeze(1)
        label_logits = q_to_choices_att
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits.data.tolist(),
            "label_probs": label_probs.data.tolist()
        }

        if self._return_question_to_choices_att \
                or self._return_question_to_facts_att \
                or self._return_choice_to_facts_att:

            attentions_dict = {}
            know_interactions_weights_dict = {}
            if self._return_question_to_choices_att:
                # Keep also the interaction weights used for the final prediction

                # attentions
                att_to_export_q_to_ch = {}
                q_to_ch_raw_type = "__".join(["ctx", "ctx"])
                if self._use_knowledge:

                    try:
                        # Get interaction weights.
                        # These are currently static but can be replaced with dynamic gating later.
                        know_interactions_weights = self._know_aggregate_feedforward._linear_layers[
                            0].weight.data.tolist()[0]
                    except (AttributeError, IndexError, TypeError):
                        # No usable aggregation feed-forward layer: report zero weights
                        # rather than failing the forward pass.
                        know_interactions_weights = [0.0] * len(
                            self._know_interactions)

                    # Shape after permute: (num_interactions, bs, choices).
                    q_to_choices_combined_att_transposed = torch.nn.functional.softmax(
                        q_to_choices_combined_att.permute([2, 0, 1]), dim=-1)

                    # Get the interaction attentions
                    for inter_id, interaction in enumerate(
                            self._know_interactions):
                        interaction_name = "__".join(interaction)
                        att_to_export_q_to_ch[
                            interaction_name] = q_to_choices_combined_att_transposed[
                                inter_id].data.tolist()
                        know_interactions_weights_dict[
                            interaction_name] = know_interactions_weights[
                                inter_id]

                    # In this case "ctx__ctx" is not included in the knowledge interactions for the final prediction,
                    # so we set the weight to 0.0
                    if q_to_ch_raw_type not in know_interactions_weights_dict:
                        know_interactions_weights_dict[q_to_ch_raw_type] = 0.0
                else:
                    # In this case we do not use multiple interactions and the only prediction is for ctx__ctx
                    if q_to_ch_raw_type not in know_interactions_weights_dict:
                        know_interactions_weights_dict[q_to_ch_raw_type] = 1.0

                if q_to_ch_raw_type not in att_to_export_q_to_ch:
                    q_to_ch_att_ctx_ctx = self._matrix_attention_question_to_choice(
                        encoded_question_aggregated.unsqueeze(1),
                        encoded_choices_aggregated).squeeze(1)
                    q_to_ch_att_ctx_ctx = torch.nn.functional.softmax(
                        q_to_ch_att_ctx_ctx, dim=-1)
                    att_to_export_q_to_ch[
                        q_to_ch_raw_type] = q_to_ch_att_ctx_ctx.data.tolist()

                att_to_export_q_to_ch["final"] = label_probs.data.tolist()
                attentions_dict["att_q_to_ch"] = att_to_export_q_to_ch

            if self._use_knowledge:
                if self._return_question_to_facts_att:
                    att_to_export_q_to_f = {}

                    # TO DO: Update when more sources are added
                    att_to_export_q_to_f[
                        "src1"] = q_to_facts_att_softmax.data.tolist()
                    attentions_dict["att_q_to_f"] = att_to_export_q_to_f

                if self._return_choice_to_facts_att:
                    att_to_export_ch_to_f = {}

                    # TO DO: Update when more sources are added
                    att_to_export_ch_to_f[
                        "src1"] = choices_to_facts_att_softmax.data.tolist()
                    attentions_dict["att_ch_to_f"] = att_to_export_ch_to_f

            output_dict["attentions"] = attentions_dict
            output_dict["know_inter_weights"] = know_interactions_weights_dict

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.long().view([bs]))
            output_dict["loss"] = loss

        return output_dict
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Classify a text with a biattentive encoder and max/min/mean/self-attentive pooling.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``. An optional ``"elmo"`` entry is removed
            temporarily and restored before ELMo runs, so the caller's dict is unchanged.
        label : torch.LongTensor, optional (default = None)
            A variable representing the label for each instance in the batch.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            Unnormalised class scores from the output layer.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_classes)`` representing a
            distribution over the label classes for each instance.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised. Only present when ``label`` is given.

        Raises
        ------
        ConfigurationError
            If the model uses ELMo but ``tokens`` has no ``"elmo"`` entry.
        """
        text_mask = util.get_text_field_mask(tokens).float()
        # Remove the "elmo" entry so the text field embedder does not see it; ELMo is
        # applied separately below.
        elmo_tokens = tokens.pop("elmo", None)
        embedded_text = self._text_field_embedder(tokens)

        # Put the "elmo" key back, since the tests and later training epochs rely on
        # ``tokens`` not being modified by forward().
        if elmo_tokens is not None:
            tokens["elmo"] = elmo_tokens

        # Create ELMo embeddings if applicable
        if self._elmo:
            if elmo_tokens is not None:
                elmo_representations = self._elmo(elmo_tokens)["elmo_representations"]
                # Pop from the end, which is cheaper for a list. The order of these
                # pops fixes which representation goes where.
                if self._use_integrator_output_elmo:
                    integrator_output_elmo = elmo_representations.pop()
                if self._use_input_elmo:
                    input_elmo = elmo_representations.pop()
                assert not elmo_representations
            else:
                raise ConfigurationError(
                        "Model was built to use Elmo, but input text is not tokenized for Elmo.")

        if self._use_input_elmo:
            embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)

        dropped_embedded_text = self._embedding_dropout(embedded_text)
        pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
        encoded_tokens = self._encoder(pre_encoded_text, text_mask)

        # Compute biattention. This is a special case since the inputs are the same.
        attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
        attention_weights = util.last_dim_softmax(attention_logits, text_mask)
        encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

        # Build the input to the integrator
        integrator_input = torch.cat([encoded_tokens,
                                      encoded_tokens - encoded_text,
                                      encoded_tokens * encoded_text], 2)
        integrated_encodings = self._integrator(integrator_input, text_mask)

        # Concatenate ELMo representations to integrated_encodings if specified
        if self._use_integrator_output_elmo:
            integrated_encodings = torch.cat([integrated_encodings,
                                              integrator_output_elmo], dim=-1)

        # Simple pooling layers. Padding is replaced with a large negative/positive value
        # so it never wins the max/min.
        max_masked_integrated_encodings = util.replace_masked_values(
                integrated_encodings, text_mask.unsqueeze(2), -1e7)
        max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
        min_masked_integrated_encodings = util.replace_masked_values(
                integrated_encodings, text_mask.unsqueeze(2), +1e7)
        min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
        # Mask explicitly, as the other pools do, instead of relying on the integrator
        # (and the optional ELMo concat) to emit exact zeros at padded positions.
        mean_pool = (torch.sum(integrated_encodings * text_mask.unsqueeze(2), 1) /
                     torch.sum(text_mask, 1, keepdim=True))

        # Self-attentive pooling layer
        # Run through linear projection. Shape: (batch_size, sequence length, 1)
        # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
        self_attentive_logits = self._self_attentive_pooling_projection(
                integrated_encodings).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits, text_mask)
        self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)

        pooled_representations = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
        pooled_representations_dropped = self._integrator_dropout(pooled_representations)

        logits = self._output_layer(pooled_representations_dropped)
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {'logits': logits, 'class_probabilities': class_probabilities}
        if label is not None:
            loss = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            output_dict["loss"] = loss

        return output_dict
Exemple #25
0
    def forward(
            self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Multi-head self-attention with a causal mask over the time dimension.

        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps): 1 for real positions, 0 for padding.
            When omitted, every position is treated as real.

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, hidden_dim = inputs.size()
        if mask is None:
            mask = Variable(inputs.data.new(batch_size, timesteps).fill_(1.0))

        # Treat the queries, keys and values each as a ``num_heads`` size batch.
        # shape (num_heads, batch_size * timesteps, hidden_dim)
        inputs_per_head = inputs.repeat(num_heads, 1,
                                        1).view(num_heads,
                                                batch_size * timesteps,
                                                hidden_dim)
        # Do the projections for all the heads at once.
        # Then reshape the result as though it had a
        # (num_heads * batch_size) sized batch.
        queries_per_head = torch.bmm(inputs_per_head, self._query_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        queries_per_head = queries_per_head.view(num_heads * batch_size,
                                                 timesteps,
                                                 self._attention_dim)

        keys_per_head = torch.bmm(inputs_per_head, self._key_projections)
        # shape (num_heads * batch_size, timesteps, attention_dim)
        keys_per_head = keys_per_head.view(num_heads * batch_size, timesteps,
                                           self._attention_dim)

        values_per_head = torch.bmm(inputs_per_head, self._value_projections)
        # shape (num_heads * batch_size, timesteps, values_dim)
        values_per_head = values_per_head.view(num_heads * batch_size,
                                               timesteps, self._values_dim)

        # shape (num_heads * batch_size, timesteps, timesteps)
        scaled_similarities = torch.bmm(
            queries_per_head, keys_per_head.transpose(1, 2)) / self._scale

        # Causal masking: entries where ``causality_mask`` is 0 get a large negative score.
        # Keep the mask on the same device as ``inputs``. An unconditional ``.cuda()``
        # fails on CPU and picks the wrong GPU when inputs live on a non-default device.
        causality_mask = subsequent_mask(timesteps)
        if inputs.is_cuda:
            causality_mask = causality_mask.cuda(inputs.get_device())
        masked_scaled_similarities = scaled_similarities.masked_fill(
            causality_mask == 0, -1e9)

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same padding mask for all heads. The
        # head-major repeat matches the head-major layout of ``inputs_per_head``.
        attention = last_dim_softmax(masked_scaled_similarities,
                                     mask.repeat(num_heads, 1))
        attention = self._attention_dropout(attention)
        # This is doing the following batch-wise matrix multiplication:
        # (num_heads * batch_size, timesteps, timesteps) *
        # (num_heads * batch_size, timesteps, values_dim)
        # which is equivalent to a weighted sum of the values with respect to
        # the attention distributions for each element in the num_heads * batch_size
        # dimension.
        # shape (num_heads * batch_size, timesteps, values_dim)
        outputs = torch.bmm(attention, values_per_head)

        # Reshape back to original shape (batch_size, timesteps, num_heads * values_dim)
        # Note that we _cannot_ use a reshape here, because this tensor was created
        # with num_heads being the first dimension, so reshaping naively would not
        # throw an error, but give an incorrect result.
        outputs = torch.cat(torch.split(outputs, batch_size, dim=0), dim=-1)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
Exemple #26
0
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Classify the relation between ``premise`` and ``hypothesis`` with an ESIM-style
        pipeline: encode, cross-attend, enhance, re-encode (inference), pool, classify.

        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``. It is flattened with ``view(-1)`` and the same
            flattened tensor is given to both the loss and the accuracy metric.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
            Not used by this method.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # Encode premise and hypothesis with the same (shared) encoder.
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # The "enhancement" layer: [a; a~; a - a~; a * a~] along the feature dimension.
        premise_enhanced = torch.cat(
                [encoded_premise, attended_hypothesis,
                 encoded_premise - attended_hypothesis,
                 encoded_premise * attended_hypothesis],
                dim=-1
        )
        hypothesis_enhanced = torch.cat(
                [encoded_hypothesis, attended_premise,
                 encoded_hypothesis - attended_premise,
                 encoded_hypothesis * attended_premise],
                dim=-1
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling over the time dimension.
        # Padded positions are replaced by a very negative value so they never win the max.
        # Shape: (batch_size, inference_output_dim)
        v_a_max, _ = replace_masked_values(
                v_ai, premise_mask.unsqueeze(-1), -1e7
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
                v_bi, hypothesis_mask.unsqueeze(-1), -1e7
        ).max(dim=1)

        # Masked mean: sum of the unpadded positions divided by the number of unpadded positions.
        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
                premise_mask, 1, keepdim=True
        )
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
                hypothesis_mask, 1, keepdim=True
        )

        # Now concat
        # Shape: (batch_size, 4 * inference_output_dim)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            # Flatten once so the loss and the metric see identical gold labels,
            # whether the batched label arrives as (batch_size,) or (batch_size, 1).
            flat_label = label.long().view(-1)
            loss = self._loss(label_logits, flat_label)
            self._accuracy(label_logits, flat_label)
            output_dict["loss"] = loss

        return output_dict
Exemple #27
0
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Encode the sentence, build a representation for every candidate span and
        classify each span.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
            Padding spans are recognised by a negative start index.
        span_labels : torch.LongTensor, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        logits : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing unnormalised log probabilities of the label classes for each span.
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if this is padding
        # or not. Shape: (batch_size, num_spans)
        # No ``squeeze(-1)``: the indexing already drops the last axis, and a
        # squeeze would collapse ``num_spans`` when a batch contains a single span.
        span_mask = (spans[:, :, 0] >= 0).long()

        encoded_text = self.encoder(embedded_text_input, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask,
                                                   span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)
        # Shape: (batch_size, num_spans, span_label_vocab_size)
        logits = self.tag_projection_layer(span_representations)
        # The trailing singleton axis lets the per-span mask cover every label of a span.
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
            "class_probabilities": class_probabilities,
            "spans": spans,
            # TODO(Mark): This relies on having tokens represented with a SingleIdTokenIndexer...
            "tokens": tokens["tokens"],
            "token_mask": mask
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels,
                                                      span_mask)
            for metric in self.metrics.values():
                metric(logits, span_labels, span_mask)
            output_dict["loss"] = loss

        return output_dict
Exemple #28
0
    def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        choices_list: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Score every answer choice against the question with an ESIM-style
        encode / cross-attend / enhance / infer / pool pipeline.

        Each (question, choice) pair is processed independently by folding the
        choices axis into the batch axis, so every intermediate tensor between the
        encoders and the final logit layer has a leading
        ``batch_size * choices_cnt`` dimension.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``
        choices_list : Dict[str, torch.LongTensor]
            From a ``List[TextField]``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField`` (presumably the index of the correct choice).
            It is flattened with ``view(-1)`` before it is given to the loss and
            the accuracy metric.
        metadata : ``List[Dict[str, Any]]``, optional (default = None)
            Accepted for interface compatibility; not used.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, choices_cnt)`` representing unnormalised log
            probabilities of each answer choice.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, choices_cnt)`` representing probabilities of
            each answer choice.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_question = self._text_field_embedder(question)
        embedded_choices = self._text_field_embedder(choices_list)
        question_mask = get_text_field_mask(question).float()
        choices_mask_3d = get_text_field_mask(choices_list,
                                              num_wrapping_dims=1).float()

        # apply dropout for LSTM
        if self._embeddings_dropout:
            embedded_question = self._embeddings_dropout(embedded_question)
            embedded_choices = self._embeddings_dropout(embedded_choices)

        # The embedding size is not needed here; the ``view`` below infers it with -1.
        batch_size, choices_cnt, choices_tokens_cnt, _ = tuple(
            embedded_choices.shape)
        # From here on every (question, choice) pair is its own row.
        flat_size = batch_size * choices_cnt

        # Shape: (batch_size * choices_cnt, choices_tokens_cnt)
        choices_mask = choices_mask_3d.view([flat_size, choices_tokens_cnt])

        # Shape: (batch_size * choices_cnt, choices_tokens_cnt, embedding_size)
        embedded_choices_flattened = embedded_choices.view(
            [flat_size, choices_tokens_cnt, -1])

        # Encode the question once, with its un-tiled mask, before tiling it.
        # Shape: (batch_size, question_tokens_cnt, encoder_out_size)
        encoded_question = self._question_encoder(embedded_question,
                                                  question_mask)
        question_tokens_cnt = encoded_question.shape[1]
        encoder_out_size = encoded_question.shape[2]

        # Tile the encoded question once per choice.
        # Shape: (batch_size * choices_cnt, question_tokens_cnt, encoder_out_size)
        encoded_question = encoded_question.unsqueeze(1).expand(
            batch_size, choices_cnt, question_tokens_cnt,
            encoder_out_size).contiguous()
        encoded_question = encoded_question.view(
            [flat_size, question_tokens_cnt, encoder_out_size])

        # Tile the question mask the same way.
        # Shape: (batch_size * choices_cnt, question_tokens_cnt)
        question_mask = question_mask.unsqueeze(1).expand(
            batch_size, choices_cnt, question_tokens_cnt).contiguous()
        question_mask = question_mask.view([flat_size, question_tokens_cnt])

        # The choice encoder's output size must match encoder_out_size for the
        # subtraction/product in the enhancement layer below.
        # Shape: (batch_size * choices_cnt, choices_tokens_cnt, encoder_out_size)
        encoded_choices = self._choice_encoder(embedded_choices_flattened,
                                               choices_mask)

        # Shape: (batch_size * choices_cnt, question_tokens_cnt, choices_tokens_cnt)
        similarity_matrix = self._matrix_attention(encoded_question,
                                                   encoded_choices)

        # Question-to-choice attention.
        # Shape: (batch_size * choices_cnt, question_tokens_cnt, choices_tokens_cnt)
        q2c_attention = last_dim_softmax(similarity_matrix, choices_mask)
        # Shape: (batch_size * choices_cnt, question_tokens_cnt, encoder_out_size)
        attended_choices = weighted_sum(encoded_choices, q2c_attention)

        # Choice-to-question attention.
        # Shape: (batch_size * choices_cnt, choices_tokens_cnt, question_tokens_cnt)
        c2q_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), question_mask)
        # Shape: (batch_size * choices_cnt, choices_tokens_cnt, encoder_out_size)
        attended_question = weighted_sum(encoded_question, c2q_attention)

        # The "enhancement" layer: [a; a~; a - a~; a * a~] along the feature dimension.
        question_enhanced = torch.cat([
            encoded_question, attended_choices,
            encoded_question - attended_choices,
            encoded_question * attended_choices
        ], dim=-1)
        choices_enhanced = torch.cat([
            encoded_choices, attended_question,
            encoded_choices - attended_question,
            encoded_choices * attended_question
        ], dim=-1)

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_question = self._projection_feedforward(
            question_enhanced)
        projected_enhanced_choices = self._projection_feedforward(
            choices_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_question = self.rnn_input_dropout(
                projected_enhanced_question)
            projected_enhanced_choices = self.rnn_input_dropout(
                projected_enhanced_choices)
        v_ai = self._inference_encoder(projected_enhanced_question,
                                       question_mask)
        v_bi = self._inference_encoder(projected_enhanced_choices,
                                       choices_mask)

        # The pooling layer -- max and avg pooling over the time dimension.
        # Padded positions are replaced by a very negative value so they never win the max.
        # Shape: (batch_size * choices_cnt, inference_output_dim)
        v_a_max, _ = replace_masked_values(v_ai, question_mask.unsqueeze(-1),
                                           -1e7).max(dim=1)
        v_b_max, _ = replace_masked_values(v_bi, choices_mask.unsqueeze(-1),
                                           -1e7).max(dim=1)

        # Masked mean over the unpadded positions.
        v_a_avg = torch.sum(v_ai * question_mask.unsqueeze(-1),
                            dim=1) / torch.sum(question_mask, 1, keepdim=True)
        v_b_avg = torch.sum(v_bi * choices_mask.unsqueeze(-1),
                            dim=1) / torch.sum(choices_mask, 1, keepdim=True)

        # Now concat
        # Shape: (batch_size * choices_cnt, 4 * inference_output_dim)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        # Unfold the choices axis again; assumes _output_logit emits one score per
        # (question, choice) row -- TODO confirm against its configuration.
        # Shape: (batch_size, choices_cnt)
        label_logits = label_logits.view([batch_size, choices_cnt])
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            # Flatten once so the loss and the metric see identical gold labels,
            # whether the batched label arrives as (batch_size,) or (batch_size, 1).
            flat_label = label.long().view(-1)
            loss = self._loss(label_logits, flat_label)
            self._accuracy(label_logits, flat_label)
            output_dict["loss"] = loss

        return output_dict
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decomposable-attention classification of a premise/hypothesis pair:
        optionally encode, attend, compare token-wise, sum-aggregate, classify.

        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``. It is flattened with ``view(-1)`` and the same
            flattened tensor is given to both the loss and the accuracy metric.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        h2p_attention, p2h_attention : torch.FloatTensor
            The two attention distributions (see shape comments in the body).
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # The encoders are optional; without them attention runs on raw embeddings.
        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise,
                                                   projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat(
            [embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat(
            [embedded_hypothesis, attended_premise], dim=-1)

        # Zero out padded positions before summing so they do not contribute.
        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(
            hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(
            -1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs,
            "h2p_attention": h2p_attention,
            "p2h_attention": p2h_attention
        }

        if label is not None:
            # ``squeeze(-1)`` would turn a 1-D label of batch size 1 into a 0-dim
            # scalar; ``view(-1)`` always yields (batch_size,) and matches the loss.
            flat_label = label.long().view(-1)
            loss = self._loss(label_logits, flat_label)
            self._accuracy(label_logits, flat_label)
            output_dict["loss"] = loss

        return output_dict
Exemple #30
0
    def forward(
            self,  # pylint: disable=arguments-differ
            inputs: torch.Tensor,
            mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Multi-head scaled dot-product self-attention over ``inputs``.

        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps) restricting which key
            timesteps may be attended to. When omitted, an all-ones mask is used,
            so every timestep is attendable.

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = Variable(inputs.data.new(batch_size, timesteps).fill_(1.0))

        attention_dim_per_head = self._attention_dim // num_heads
        values_dim_per_head = self._values_dim // num_heads

        def split_heads(tensor: torch.Tensor, dim_per_head: int) -> torch.Tensor:
            # (batch_size, timesteps, num_heads * dim_per_head)
            #   -> (batch_size * num_heads, timesteps, dim_per_head).
            # The layout is batch-major: row ``b * num_heads + h`` holds head ``h`` of
            # batch element ``b``. The mask tiling and the final merge depend on this.
            tensor = tensor.view(batch_size, timesteps, num_heads, dim_per_head)
            tensor = tensor.transpose(1, 2).contiguous()
            return tensor.view(batch_size * num_heads, timesteps, dim_per_head)

        # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
        combined_projection = self._combined_projection(inputs)

        # split by attention dim - if values_dim > attention_dim, we will get more
        # than 3 elements returned. All of the rest are the values vector, so we
        # just concatenate them back together again below.
        queries, keys, *values = combined_projection.split(
            self._attention_dim, -1)
        queries = queries.contiguous()
        keys = keys.contiguous()
        values = torch.cat(values, -1).contiguous()

        # Shape (batch_size * num_heads, timesteps, values_dim / num_heads)
        values_per_head = split_heads(values, values_dim_per_head)
        # Shape (batch_size * num_heads, timesteps, attention_dim / num_heads)
        queries_per_head = split_heads(queries, attention_dim_per_head)
        # Shape (batch_size * num_heads, timesteps, attention_dim / num_heads)
        keys_per_head = split_heads(keys, attention_dim_per_head)

        # shape (batch_size * num_heads, timesteps, timesteps)
        scaled_similarities = torch.bmm(
            queries_per_head, keys_per_head.transpose(1, 2)) / self._scale

        # Tile the mask batch-major to match split_heads: each batch element's mask is
        # repeated num_heads times consecutively, so row b * num_heads + h gets mask[b].
        # (``mask.repeat(num_heads, 1)`` tiles head-major and pairs rows with the
        # wrong batch element's mask.)
        # shape (batch_size * num_heads, timesteps)
        mask_per_head = mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps)

        # shape (batch_size * num_heads, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = last_dim_softmax(scaled_similarities, mask_per_head)
        attention = self._attention_dropout(attention)
        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the batch_size * num_heads dimension.
        # shape (batch_size * num_heads, timesteps, values_dim / num_heads)
        outputs = weighted_sum(values_per_head, attention)

        # Undo split_heads: bring the heads back next to the feature axis, then
        # merge them so head h occupies columns [h * d, (h + 1) * d).
        # shape (batch_size, num_heads, timesteps, values_dim / num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps, values_dim_per_head)
        # shape (batch_size, timesteps, num_heads, values_dim / num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, num_heads * values_dim_per_head)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
Exemple #31
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.  The ending position is `exclusive`, so our
            :class:`~allennlp.data.dataset_readers.SquadReader` adds a special ending token to the
            end of the passage, to allow for the last token to be included in the answer span.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `exclusive` index.  If
            this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalised log
            probabilities of the span end position (exclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self._get_best_span(span_start_logits, span_end_logits)

        output_dict = {"span_start_logits": span_start_logits,
                       "span_start_probs": span_start_probs,
                       "span_end_logits": span_end_logits,
                       "span_end_probs": span_end_probs,
                       "best_span": best_span}
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
        if metadata is not None and self._official_eval_dataset:
            output_dict['best_span_str'] = []
            for i in range(batch_size):
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                best_span_string = self._compute_official_metrics(metadata[i], predicted_span)  # type: ignore
                output_dict['best_span_str'].append(best_span_string)
        return output_dict
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.LongTensor = None,
            span_end: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Multi-passage span-extraction forward pass.

        Every passage of an instance is encoded against the (tiled) question independently,
        refined with a residual self-attention layer, and scored by start/end predictors.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField`` holding the question.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField`` holding several passages per instance; its embedding is
            unpacked as ``(batch_size, num_passage, passage_length, embedding_dim)``.
        span_start, span_end : torch.LongTensor, optional
            Gold span indices into the concatenation of an instance's passages
            (``num_passage * passage_length`` positions).  Column 0 is the training target;
            later columns with a value ``>= 0`` are alternative gold spans, and the first
            negative value ends that list.
        metadata : List[Dict[str, Any]], optional
            Per-instance dicts; the keys ``question_tokens``, ``all_passages``,
            ``passage_offsets`` and optionally ``answer_texts`` are read.

        Returns
        -------
        Dict[str, torch.Tensor]
            ``passage_question_attention``, ``span_start_logits``, ``span_start_probs``,
            ``span_end_logits`` and ``span_end_probs`` (leading dimension
            ``batch_size * num_passage``), plus ``loss``.  When ``metadata`` is given, also
            ``best_span_str`` and ``question_tokens``.
        """
        embedded_passage = self._text_field_embedder(passage)
        passage_mask = util.get_text_field_mask(passage, 1).float()
        # get params
        cuda_device = embedded_passage.get_device()
        batch_size, num_passage, passage_length, embedding_dim = embedded_passage.size()
        # When training, keep passage 0 plus one passage sampled uniformly from passages
        # 1-3 (assumes at least 4 passages per instance -- TODO confirm).
        if self.training:
            num_passage = 2
            probs = torch.Tensor([1, 1, 1]).unsqueeze(0).expand(batch_size, 3)
            indices = torch.multinomial(probs, 1) + 1
            zeros_tensor = torch.zeros(batch_size).long()
            indices = Variable(
                torch.cat([zeros_tensor.unsqueeze(-1), indices],
                          1).cuda(cuda_device))
            embedded_passage = torch.gather(
                embedded_passage, 1,
                indices.unsqueeze(-1).unsqueeze(-1).expand(
                    batch_size, num_passage, passage_length, embedding_dim))
            passage_mask = torch.gather(
                passage_mask, 1,
                indices.unsqueeze(-1).expand(batch_size, num_passage,
                                             passage_length))
        # Shape: (batch_size*num_passage, passage_length, embedding_dim)
        embedded_passage = embedded_passage.view(-1, passage_length,
                                                 embedding_dim)
        embedded_passage = self._dropout(embedded_passage)
        # Shape: (batch_size*num_passage, passage_length)
        passage_mask = passage_mask.view(-1, passage_length)
        # Shape: (batch_size, question_length, embedding_dim)
        embedded_question = self._text_field_embedder(question)
        # Tile the question once per passage.
        # Shape: (batch_size*num_passage, question_length, embedding_dim)
        embedded_question = embedded_question.unsqueeze(1).expand(
            -1, num_passage, -1,
            -1).contiguous().view(batch_size * num_passage, -1, embedding_dim)
        embedded_question = self._dropout(embedded_question)
        # Shape: (batch_size, question_length)
        question_mask = util.get_text_field_mask(question).float()
        # Shape: (batch_size*num_passage, question_length)
        question_mask = question_mask.unsqueeze(1).expand(
            -1, num_passage, -1).contiguous().view(batch_size * num_passage,
                                                   -1)
        # lstm masks
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        # Shape: (batch_size*num_passage, -1, encoding_dim)
        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)
        # Shape: (batch_size*num_passage, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size*num_passage, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)
        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size*num_passage, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size*num_passage, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size*num_passage, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size * num_passage, passage_length, encoding_dim)

        # Shape: (batch_size*num_passage, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ], dim=-1)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        question_attended_passage = relu(
            self._linear_layer(final_merged_passage))
        # attach residual self-attention layer
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        residual_passage = self._dropout(
            self._residual_encoder(self._dropout(question_attended_passage),
                                   passage_lstm_mask))
        # Self-attention mask: a (i, j) pair is valid when both positions are real tokens...
        mask = passage_mask.view(
            batch_size * num_passage, passage_length, 1) * passage_mask.view(
                batch_size * num_passage, 1, passage_length)
        # ...and i != j (a token does not attend to itself).
        self_mask = Variable(
            torch.eye(passage_length,
                      passage_length).cuda(cuda_device)).view(
                          1, passage_length, passage_length)
        mask = mask * (1 - self_mask)
        # Shape: (batch_size*num_passage, passage_length, passage_length)
        passage_self_similarity = self._self_matrix_attention(
            residual_passage, residual_passage)
        # Shape: (batch_size*num_passage, passage_length, passage_length)
        passage_self_attention = util.last_dim_softmax(passage_self_similarity,
                                                       mask)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        passage_vectors = util.weighted_sum(residual_passage,
                                            passage_self_attention)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim * 3)
        merged_passage = torch.cat([
            residual_passage, passage_vectors,
            residual_passage * passage_vectors
        ], dim=-1)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        self_attended_passage = relu(
            self._residual_linear_layer(merged_passage))
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        mixed_passage = question_attended_passage + self_attended_passage
        # add dropout
        mixed_passage = self._dropout(mixed_passage)

        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        encoded_span_start = self._dropout(
            self._span_start_encoder(mixed_passage, passage_lstm_mask))
        span_start_logits = self._span_start_predictor(
            encoded_span_start).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim * 2)
        concatenated_passage = torch.cat([mixed_passage, encoded_span_start],
                                         dim=-1)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(concatenated_passage, passage_lstm_mask))
        span_end_logits = self._span_end_predictor(encoded_span_end).squeeze(
            -1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Shape: (batch_size*num_passage, passage_length)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
        }
        # when training, we need to merge all the paragraphs in the same context before computing loss
        if span_start is not None:
            num_spans = span_start.size(1)
            total_length = num_passage * passage_length
            all_modified_start_logits = []
            all_modified_end_logits = []

            # Shape: (batch_size, num_passage * passage_length).  These use new names so the
            # per-passage logits are still intact for span decoding in the metadata branch.
            flat_start_logits = span_start_logits.view(batch_size, total_length)
            flat_end_logits = span_end_logits.view(batch_size, total_length)
            flat_passage_mask = passage_mask.view(batch_size, total_length)

            start_mask = flat_passage_mask.clone()
            end_mask = flat_passage_mask.clone()
            for b in range(batch_size):
                start_idxs = Variable(
                    torch.LongTensor(range(total_length))).cuda(cuda_device)
                end_idxs = Variable(
                    torch.LongTensor(range(total_length))).cuda(cuda_device)
                for i in range(1, num_spans):
                    if span_start[b, i].data[0] >= 0:
                        # NOTE(review): ``x[k].data = ...`` appears to rebind the data of a
                        # temporary indexing result rather than write into ``x``, which would
                        # leave start_idxs / end_idxs as the identity map so that only the
                        # mask updates below take effect.  Confirm the intended merging.
                        start_idxs[span_start[b, i].data[0]].data = start_idxs[
                            span_start[b, 0].data[0]].data
                        end_idxs[span_end[b, i].data[0]].data = end_idxs[
                            span_end[b, 0].data[0]].data
                        # Hide alternative gold positions from the softmax.
                        start_mask[b, span_start[b, i].data[0]] = 0
                        end_mask[b, span_end[b, i].data[0]] = 0
                    else:
                        # A negative index terminates the list of alternative spans.
                        break
                modified_start_logits = Variable(
                    torch.zeros(total_length)).cuda(cuda_device)
                modified_end_logits = Variable(
                    torch.zeros(total_length)).cuda(cuda_device)

                modified_start_logits.put_(start_idxs, flat_start_logits[b])
                modified_end_logits.put_(end_idxs, flat_end_logits[b])

                all_modified_start_logits.append(modified_start_logits)
                all_modified_end_logits.append(modified_end_logits)

            all_modified_start_logits = torch.stack(all_modified_start_logits,
                                                    dim=0)
            all_modified_end_logits = torch.stack(all_modified_end_logits,
                                                  dim=0)

            loss = nll_loss(
                util.masked_log_softmax(all_modified_start_logits, start_mask),
                span_start[:, 0].squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(all_modified_end_logits, end_mask),
                span_end[:, 0].squeeze(-1))
        else:
            loss = Variable(torch.Tensor([0]).cuda(cuda_device))
        output_dict["loss"] = loss
        # when validating, compute the ROUGE score
        if metadata is not None:
            # Decode each passage separately; the logits here are still per passage,
            # shape (batch_size * num_passage, passage_length).  Assumes get_best_span
            # returns one 3-column row per input row -- TODO confirm.
            best_span = self.get_best_span(span_start_logits, span_end_logits)
            best_span = best_span.view(batch_size, num_passage, 3)
            # extract answer for computing Rouge, F1 and EM
            output_dict['best_span_str'] = []
            question_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                all_passages = metadata[i]['all_passages']
                passage_offsets = metadata[i]['passage_offsets']
                # get the paragraph with highest confidence span
                _, max_id = torch.max(best_span[i, :, 2], dim=0)
                max_id = int(max_id)
                # extract answer text
                predicted_span = tuple(best_span[i, max_id].data.cpu().numpy())
                start_offset = passage_offsets[max_id][int(
                    predicted_span[0])][0]
                end_offset = passage_offsets[max_id][int(predicted_span[1])][1]
                best_span_string = all_passages[max_id][
                    start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._ite += 1
                    if (self._ite % 100 == 0):
                        print("%s || %s" % (best_span_string, answer_texts))
                    self._squad_metrics(best_span_string, answer_texts)
                    self._rouge_metric(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
        return output_dict
Exemple #33
0
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        Siamese ESIM-style encoder trained with a contrastive loss.

        Parameters
        ----------
        premise, hypothesis : Dict[str, torch.LongTensor]
            From ``TextField``s.
        label : torch.IntTensor, optional
            Per-instance pair label.  The loss pulls ``distance`` toward 0 where the label
            is 1 and pushes it beyond ``self._margin`` where the label is 0.

        Returns
        -------
        Dict[str, torch.Tensor]
            ``distance`` (pairwise distance between the two output vectors),
            ``prediction`` (``distance < margin / 2``) and, when ``label`` is given,
            ``loss``.
        """

        # Shape: (batch_size, seq_length, embedding_dim)
        embedded_p = self._text_field_embedder(premise)
        embedded_h = self._text_field_embedder(hypothesis)

        mask_p = get_text_field_mask(premise).float()
        mask_h = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_p = self.rnn_input_dropout(embedded_p)
            embedded_h = self.rnn_input_dropout(embedded_h)

        # encode p and h
        # Shape: (batch_size, seq_length, encoding_direction_num * encoding_hidden_dim)
        encoded_p = self._encoder(embedded_p, mask_p)
        encoded_h = self._encoder(embedded_h, mask_h)

        # Shape: (batch_size, p_length, h_length)
        similarity_matrix = self._matrix_attention(encoded_p, encoded_h)

        # Shape: (batch_size, p_length, h_length)
        p2h_attention = last_dim_softmax(similarity_matrix, mask_h)
        # Shape: (batch_size, p_length, encoding_direction_num * encoding_hidden_dim)
        attended_h = weighted_sum(encoded_h, p2h_attention)

        # Shape: (batch_size, h_length, p_length)
        h2p_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), mask_p)
        # Shape: (batch_size, h_length, encoding_direction_num * encoding_hidden_dim)
        attended_p = weighted_sum(encoded_p, h2p_attention)

        # the "enhancement" layer
        # Shape: (batch_size, p_length, encoding_direction_num * encoding_hidden_dim * 4)
        enhanced_p = torch.cat([
            encoded_p, attended_h, encoded_p - attended_h,
            encoded_p * attended_h
        ], dim=-1)
        # Shape: (batch_size, h_length, encoding_direction_num * encoding_hidden_dim * 4)
        enhanced_h = torch.cat([
            encoded_h, attended_p, encoded_h - attended_p,
            encoded_h * attended_p
        ], dim=-1)

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        # Shape: (batch_size, seq_length, projection_hidden_dim)
        projected_enhanced_p = self._projection_feedforward(enhanced_p)
        projected_enhanced_h = self._projection_feedforward(enhanced_h)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_p = self.rnn_input_dropout(projected_enhanced_p)
            projected_enhanced_h = self.rnn_input_dropout(projected_enhanced_h)

        # Shape: (batch_size, seq_length, inference_direction_num * inference_hidden_dim)
        inferenced_p = self._inference_encoder(projected_enhanced_p, mask_p)
        inferenced_h = self._inference_encoder(projected_enhanced_h, mask_h)

        # The pooling layer -- max and avg pooling.  Padded positions are pushed to -1e7
        # for the max and zeroed for the average.
        # Shape: (batch_size, inference_direction_num * inference_hidden_dim)
        pooled_p_max, _ = replace_masked_values(inferenced_p,
                                                mask_p.unsqueeze(-1),
                                                -1e7).max(dim=1)
        pooled_h_max, _ = replace_masked_values(inferenced_h,
                                                mask_h.unsqueeze(-1),
                                                -1e7).max(dim=1)

        pooled_p_avg = torch.sum(inferenced_p * mask_p.unsqueeze(-1),
                                 dim=1) / torch.sum(mask_p, 1, keepdim=True)
        pooled_h_avg = torch.sum(inferenced_h * mask_h.unsqueeze(-1),
                                 dim=1) / torch.sum(mask_h, 1, keepdim=True)

        # Now concat
        # Shape: (batch_size, inference_direction_num * inference_hidden_dim * 2)
        pooled_p_all = torch.cat([pooled_p_avg, pooled_p_max], dim=1)
        pooled_h_all = torch.cat([pooled_h_avg, pooled_h_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            pooled_p_all = self.dropout(pooled_p_all)
            pooled_h_all = self.dropout(pooled_h_all)

        # Shape: (batch_size, output_feedforward_hidden_dim)
        output_p, output_h = self._output_feedforward(pooled_p_all,
                                                      pooled_h_all)

        distance = F.pairwise_distance(output_p, output_h)
        # Pairs closer than half the margin are predicted as label 1.
        prediction = distance < (self._margin / 2.0)
        output_dict = {'distance': distance, "prediction": prediction}

        if label is not None:
            # Contrastive loss function.
            # Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
            y = label.float()
            # label 1: penalize any distance
            l1 = y * torch.pow(distance, 2) / 2.0
            # label 0: penalize only distances inside the margin
            l2 = (1 - y) * torch.pow(
                torch.clamp(self._margin - distance, min=0.0), 2) / 2.0
            loss = torch.mean(l1 + l2)

            self._accuracy(prediction, label.byte())

            output_dict["loss"] = loss

        return output_dict
    def forward(self, s1, s2):
        # pylint: disable=arguments-differ
        """
        Encode a sentence pair into a single pair representation.

        Parameters
        ----------
        s1 : Dict[str, torch.LongTensor]
            From a ``TextField``.
        s2 : Dict[str, torch.LongTensor]
            From a ``TextField``.

        Returns
        -------
        pair_rep : torch.FloatTensor
            ``cat([s1_vec, s2_vec, |s1_vec - s2_vec|, s1_vec * s2_vec], dim=1)``, where each
            vector is the masked max-pool over time of the modeling-layer output after
            bidirectional attention between ``s1`` and ``s2``.  Shape:
            ``(batch_size, 4 * modeling_dim)``.
        """
        s1_embs = self._highway_layer(self._text_field_embedder(s1))
        s2_embs = self._highway_layer(self._text_field_embedder(s2))
        if self._elmo is not None:
            s1_elmo_embs = self._elmo(s1['elmo'])
            s2_elmo_embs = self._elmo(s2['elmo'])
            if "words" in s1:
                s1_embs = torch.cat(
                    [s1_embs, s1_elmo_embs['elmo_representations'][0]], dim=-1)
                s2_embs = torch.cat(
                    [s2_embs, s2_elmo_embs['elmo_representations'][0]], dim=-1)
            else:
                s1_embs = s1_elmo_embs['elmo_representations'][0]
                s2_embs = s2_elmo_embs['elmo_representations'][0]
        if self._cove is not None:
            # Non-padding token counts, passed to CoVe as sequence lengths.
            s1_lens = torch.ne(s1['words'],
                               self.pad_idx).long().sum(dim=-1).data
            s2_lens = torch.ne(s2['words'],
                               self.pad_idx).long().sum(dim=-1).data
            s1_cove_embs = self._cove(s1['words'], s1_lens)
            s1_embs = torch.cat([s1_embs, s1_cove_embs], dim=-1)
            s2_cove_embs = self._cove(s2['words'], s2_lens)
            s2_embs = torch.cat([s2_embs, s2_cove_embs], dim=-1)
        s1_embs = self._dropout(s1_embs)
        s2_embs = self._dropout(s2_embs)

        # The token masks are needed for attention and max-pooling even when the LSTMs
        # themselves run unmasked, so they are always computed.
        s1_mask = util.get_text_field_mask(s1).float()
        s2_mask = util.get_text_field_mask(s2).float()
        s1_lstm_mask = s1_mask if self._mask_lstms else None
        s2_lstm_mask = s2_mask if self._mask_lstms else None

        s1_enc = self._phrase_layer(s1_embs, s1_lstm_mask)
        s2_enc = self._phrase_layer(s2_embs, s2_lstm_mask)

        # Similarity matrix
        # Shape: (batch_size, s2_length, s1_length)
        similarity_mat = self._matrix_attention(s2_enc, s1_enc)

        # s2 representation
        # Shape: (batch_size, s2_length, s1_length)
        s2_s1_attention = util.last_dim_softmax(similarity_mat, s1_mask)
        # Shape: (batch_size, s2_length, encoding_dim)
        s2_s1_vectors = util.weighted_sum(s1_enc, s2_s1_attention)
        # Shape: (batch_size, s2_length, 2 * encoding_dim)
        s2_w_context = torch.cat([s2_enc, s2_s1_vectors], 2)
        # s1 representation, using same attn method as for the s2 representation
        s1_s2_attention = util.last_dim_softmax(
            similarity_mat.transpose(1, 2).contiguous(), s2_mask)
        # Shape: (batch_size, s1_length, encoding_dim)
        s1_s2_vectors = util.weighted_sum(s2_enc, s1_s2_attention)
        # Shape: (batch_size, s1_length, 2 * encoding_dim)
        s1_w_context = torch.cat([s1_enc, s1_s2_vectors], 2)
        if self._elmo is not None and self._deep_elmo:
            s1_w_context = torch.cat(
                [s1_w_context, s1_elmo_embs['elmo_representations'][1]],
                dim=-1)
            s2_w_context = torch.cat(
                [s2_w_context, s2_elmo_embs['elmo_representations'][1]],
                dim=-1)
        s1_w_context = self._dropout(s1_w_context)
        s2_w_context = self._dropout(s2_w_context)

        # Max-pool over time, with padded positions filled with -inf so they are never
        # selected.  The fill writes through ``.data`` and so bypasses autograd.
        modeled_s2 = self._dropout(
            self._modeling_layer(s2_w_context, s2_lstm_mask))
        s2_pool_mask = s2_mask.unsqueeze(dim=-1)
        modeled_s2.data.masked_fill_(1 - s2_pool_mask.byte().data, -float('inf'))
        s2_enc_attn = modeled_s2.max(dim=1)[0]
        modeled_s1 = self._dropout(
            self._modeling_layer(s1_w_context, s1_lstm_mask))
        s1_pool_mask = s1_mask.unsqueeze(dim=-1)
        modeled_s1.data.masked_fill_(1 - s1_pool_mask.byte().data, -float('inf'))
        s1_enc_attn = modeled_s1.max(dim=1)[0]

        return torch.cat([
            s1_enc_attn, s2_enc_attn,
            torch.abs(s1_enc_attn - s2_enc_attn), s1_enc_attn * s2_enc_attn
        ], 1)
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``.  An ``"elmo"`` entry, if present, is
            temporarily removed while the text field embedder runs and is always restored
            before this method continues.
        label : torch.LongTensor, optional (default = None)
            A variable representing the label for each instance in the batch.
        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            Unnormalized class scores, shape ``(batch_size, num_classes)``.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_classes)`` representing a
            distribution over the label classes for each instance.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        Raises
        ------
        ConfigurationError
            If the model uses ELMo but ``tokens`` has no ``"elmo"`` entry.
        """
        text_mask = util.get_text_field_mask(tokens).float()
        # Pop ELMo tokens, since the text field embedder has no ELMo token embedder.
        elmo_tokens = tokens.pop("elmo", None)
        try:
            embedded_text = self._text_field_embedder(tokens)
        finally:
            # Put the "elmo" key back even if embedding fails: the tests and subsequent
            # training epochs rely on ``tokens`` not being modified during forward().
            if elmo_tokens is not None:
                tokens["elmo"] = elmo_tokens

        # Create ELMo embeddings if applicable
        if self._elmo:
            if elmo_tokens is not None:
                elmo_representations = self._elmo(
                    elmo_tokens)["elmo_representations"]
                # Pop from the end is more performant with list
                if self._use_integrator_output_elmo:
                    integrator_output_elmo = elmo_representations.pop()
                if self._use_input_elmo:
                    input_elmo = elmo_representations.pop()
                assert not elmo_representations
            else:
                raise ConfigurationError(
                    "Model was built to use Elmo, but input text is not tokenized for Elmo."
                )

        if self._use_input_elmo:
            embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)

        dropped_embedded_text = self._embedding_dropout(embedded_text)
        pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
        encoded_tokens = self._encoder(pre_encoded_text, text_mask)

        # Compute biattention. This is a special case since the inputs are the same.
        # Shape: (batch_size, seq_len, seq_len)
        attention_logits = encoded_tokens.bmm(
            encoded_tokens.permute(0, 2, 1).contiguous())
        attention_weights = util.last_dim_softmax(attention_logits, text_mask)
        encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

        # Build the input to the integrator
        integrator_input = torch.cat([
            encoded_tokens, encoded_tokens - encoded_text,
            encoded_tokens * encoded_text
        ], 2)
        integrated_encodings = self._integrator(integrator_input, text_mask)

        # Concatenate ELMo representations to integrated_encodings if specified
        if self._use_integrator_output_elmo:
            integrated_encodings = torch.cat(
                [integrated_encodings, integrator_output_elmo], dim=-1)

        # Simple Pooling layers
        pool_mask = text_mask.unsqueeze(2)
        max_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, pool_mask, -1e7)
        max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
        min_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, pool_mask, +1e7)
        min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
        # Zero padded positions before summing so they cannot leak into the mean
        # (e.g. via concatenated ELMo features); no-op when padding is already zero.
        mean_pool = torch.sum(integrated_encodings * pool_mask, 1) / torch.sum(
            text_mask, 1, keepdim=True)

        # Self-attentive pooling layer
        # Run through linear projection. Shape: (batch_size, sequence length, 1)
        # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
        self_attentive_logits = self._self_attentive_pooling_projection(
            integrated_encodings).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits, text_mask)
        self_attentive_pool = util.weighted_sum(integrated_encodings,
                                                self_weights)

        pooled_representations = torch.cat(
            [max_pool, min_pool, mean_pool, self_attentive_pool], 1)
        pooled_representations_dropped = self._integrator_dropout(
            pooled_representations)

        logits = self._output_layer(pooled_representations_dropped)
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {
            'logits': logits,
            'class_probabilities': class_probabilities
        }
        if label is not None:
            loss = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            output_dict["loss"] = loss

        return output_dict
Exemple #36
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            for_training: bool = False) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        BiDAF-style span prediction over a stochastically selected subset of passage tokens.

        Each passage token gets a selection probability from ``self.linear`` followed by a
        sigmoid. A hard 0/1 selection is sampled with ``bernoulli()``, and the selected tokens
        are gathered by ``helper.get_selected_tensor``. The BiDAF attention / modeling /
        span-prediction stack then runs on that reduced passage. The caller's ``passage`` dict
        is not modified; a shallow copy holds the selected tensors.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.  Must contain the keys ``tokens`` and ``token_characters``.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.
        for_training : ``bool``, optional (default = ``False``)
            If ``True``, also compute selection statistics over non-padding tokens:
            ``logpz`` (log-probability of the sampled selection), ``zsum`` (number of selected
            tokens) and ``zdiff`` (number of selected/unselected transitions).  These values
            are currently neither returned nor used in the loss.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))

        ###################################### selection ######################################
        # Per-token selection probability, computed from the full (unselected) passage.
        # assumes self.linear projects to a single unit so squeeze(2) gives
        # (batch, passage_len) — TODO confirm
        pbx = sigmoid(self.linear(embedded_passage).squeeze(2))

        # Sample a hard 0/1 selection per token (this also samples at inference time).
        selection_x = pbx.bernoulli().long()

        # Word ids of selected tokens; unselected positions become 0.
        result_x = passage['tokens'].mul(selection_x)
        # Apply the same per-token selection to every character slot of each token.
        char_result_x = passage['token_characters'] * selection_x.unsqueeze(
            2).repeat(1, 1, passage['token_characters'].size()[2])

        selected_x, char_selected_x = helper.get_selected_tensor(
            result_x, char_result_x, pbx, passage['tokens'],
            passage['token_characters'], self.cuda_device)

        # NOTE(review): these statistics are computed but not used below (they were returned
        # by an earlier version of this selector).
        logpz = zsum = zdiff = -1.0
        if for_training:
            # 1 for real tokens, 0 for padding (computed on the original, unselected passage).
            mask1 = (
                passage['tokens'] !=
                self._vocab.get_token_index(DEFAULT_PADDING_TOKEN)).long()
            masked_selection_x = selection_x.mul(mask1)

            # Per-token log-probability of the sampled selection, shape (batch x len);
            # computed via the helper because this torch version lacks `reduce=False`.
            logpx = -helper.binary_cross_entropy(
                pbx, selection_x.float().detach(), reduce=False)
            assert logpx.size() == passage['tokens'].size()

            # Sum over non-padding tokens -> shape (batch,)
            logpx = logpx.mul(mask1.float()).sum(1)
            logpz = logpx

            # Number of 0<->1 transitions in the selection, i.e.
            # T.sum(T.abs_(z[1:]-z[:-1]), axis=0) in the original Theano formulation.
            zdiff1 = (masked_selection_x[:, 1:] -
                      masked_selection_x[:, :-1]).abs().sum(1)
            assert zdiff1.size()[0] == passage['tokens'].size()[0]
            zdiff = zdiff1

            # Number of selected non-padding tokens.
            xsum = masked_selection_x.sum(1)
            zsum = xsum
            assert zsum.size()[0] == passage['tokens'].size()[0]

            assert logpz.dim() == zsum.dim()
            assert logpz.dim() == zdiff.dim()

        # Work on a shallow copy so the caller's batch dict is not overwritten with the
        # selected tensors (re-running forward on the same batch must see the full passage).
        passage = dict(passage)
        passage['tokens'] = selected_x
        passage['token_characters'] = char_selected_x

        # NOTE(review): span_start/span_end and metadata token_offsets refer to the original
        # passage; whether they still align with the selected passage depends on
        # helper.get_selected_tensor — TODO confirm.
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))

        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ], dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ], dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
Exemple #37
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        BiDAF span prediction, experimental variant (v5): the question is not run through
        the phrase layer.

        The question representation is its highway-projected embedding, after dropout. The
        elementwise products in the merge step (``encoded_passage * passage_question_vectors``)
        therefore require the embedding dimension to equal the phrase-layer output dimension.
        This must be arranged in the experiment config.

        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        # Original BiDAF question encoding (disabled in v5):
        # encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))

        # v5: use the question embeddings directly
        # (remember to set token embeddings in the CONFIG JSON).
        encoded_question = self._dropout(embedded_question)

        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length) -- SIMILARITY MATRIX
        similarity_matrix = self._matrix_attention(encoded_passage,
                                                   encoded_question)

        # Shape: (batch_size, passage_length, question_length) -- CONTEXT2QUERY
        passage_question_attention = util.last_dim_softmax(
            similarity_matrix, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Experimental query2context variants v1-v4 (all disabled). They need the custom
        # q2c vectors below. Those vectors used to be computed on every call and then thrown
        # away, so they are kept here as comments. The softmax runs over the passage axis
        # (dim=1), so it must be masked with the passage mask (broadcast over questions), not
        # the question mask.
        # q2c_attention = util.masked_softmax(similarity_matrix,
        #                                     passage_mask.unsqueeze(-1),
        #                                     dim=1).transpose(-1, -2)
        # q2c_vecs = util.weighted_sum(encoded_passage, q2c_attention)
        #
        # v1:
        # tiled_question_passage_vector = util.weighted_sum(q2c_vecs, passage_question_attention)
        #
        # v2:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], encoded_passage.shape[1]))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).transpose(-1, -2)
        #
        # v3:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], 1))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).squeeze().unsqueeze(1).expand(batch_size, passage_length, encoding_dim)
        #
        # v4: re-application of query2context attention
        # new_similarity_matrix = self._matrix_attention(encoded_passage, q2c_vecs)
        # masked_similarity = util.replace_masked_values(new_similarity_matrix,
        #                                                question_mask.unsqueeze(1),
        #                                                -1e7)
        # question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
        #                                                                             passage_length,
        #                                                                             encoding_dim)

        # ------- Original variant (active)
        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            similarity_matrix, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)
        # ------- END

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        # original beta combination function
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ], dim=-1)

        # Alternative combination functions (disabled):
        # v6: torch.cat([tiled_question_passage_vector], dim=-1)
        # v7: torch.cat([passage_question_vectors], dim=-1)
        # v8: torch.cat([passage_question_vectors, tiled_question_passage_vector], dim=-1)
        # v9: torch.cat([encoded_passage, passage_question_vectors,
        #                encoded_passage * passage_question_vectors], dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ], dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Pool every span into a single vector using attention over the span's own tokens.

        ``span_indices`` holds (start, end) pairs on its last axis; ends are inclusive.
        Attention logits come from ``self._global_attention`` applied to the whole sequence.
        For each span they are gathered, softmax-normalised over the span's tokens, and used
        to weight the gathered token vectors.

        ``sequence_mask`` is accepted for interface compatibility but is not used.
        If ``span_indices_mask`` is given, its entries are multiplied into the result, which
        zeroes out padded spans.
        """
        starts, ends = span_indices.split(1, dim=-1)
        # Inclusive ends: a single-token span has width 0 here, hence the +1 below.
        widths = ends - starts
        window = int(widths.max().data) + 1

        # One attention logit per token of the full sequence.
        token_logits = self._global_attention(sequence_tensor)

        # Offsets 0..window-1, shaped (1, 1, window) so they broadcast across every span.
        offsets = util.get_range_vector(window,
                                        util.get_device_of(sequence_tensor)).view(1, 1, -1)

        # Walk backwards from each span's end; an offset is kept only while it stays inside
        # the span (<= width, inclusive) and does not run off the start of the sequence.
        within_span = (offsets <= widths).float()
        unclamped = ends - offsets
        in_bounds = (unclamped >= 0).float()
        width_mask = within_span * in_bounds
        # Negative positions are masked above; clamp them to 0 so the gather stays valid.
        gather_indices = torch.nn.functional.relu(unclamped.float()).long()

        flat_gather = util.flatten_and_batch_shift_indices(gather_indices,
                                                           sequence_tensor.size(1))
        token_vectors = util.batched_index_select(sequence_tensor, gather_indices, flat_gather)
        span_logits = util.batched_index_select(token_logits,
                                                gather_indices,
                                                flat_gather).squeeze(-1)
        weights = util.last_dim_softmax(span_logits, width_mask)

        pooled = util.weighted_sum(token_vectors, weights)

        # No span-level padding information: return the pooled vectors as they are.
        if span_indices_mask is None:
            return pooled
        return pooled * span_indices_mask.unsqueeze(-1).float()