Example no. 1
 def test_weighted_sum_works_on_simple_input(self):
     batch_size = 1
     sentence_length = 5
     embedding_dim = 4
     sentence_array = numpy.random.rand(batch_size, sentence_length, embedding_dim)
     sentence_tensor = torch.from_numpy(sentence_array).float()
     attention_tensor = torch.FloatTensor([[.3, .4, .1, 0, 1.2]])
     aggregated_array = util.weighted_sum(sentence_tensor, attention_tensor).data.numpy()
     assert aggregated_array.shape == (batch_size, embedding_dim)
     expected_array = (0.3 * sentence_array[0, 0] +
                       0.4 * sentence_array[0, 1] +
                       0.1 * sentence_array[0, 2] +
                       0.0 * sentence_array[0, 3] +
                       1.2 * sentence_array[0, 4])
     numpy.testing.assert_almost_equal(aggregated_array, [expected_array], decimal=5)
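For the 2-D attention / 3-D matrix case exercised by this test, the weighted sum is just a batched matrix product; a minimal sketch of that equivalence, assuming the same AllenNLP util module used in the test:
import numpy
import torch
from allennlp.nn import util

matrix = torch.rand(1, 5, 4)                          # (batch_size, sentence_length, embedding_dim)
attention = torch.softmax(torch.rand(1, 5), dim=-1)   # (batch_size, sentence_length)

# unsqueeze -> (1, 1, 5), bmm with (1, 5, 4) -> (1, 1, 4), squeeze -> (1, 4)
expected = attention.unsqueeze(1).bmm(matrix).squeeze(1)
actual = util.weighted_sum(matrix, attention)
numpy.testing.assert_almost_equal(actual.numpy(), expected.numpy(), decimal=5)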
Example no. 2
 def test_weighted_sum_handles_higher_order_input(self):
     batch_size = 1
     length_1 = 5
     length_2 = 6
     length_3 = 2
     embedding_dim = 4
     sentence_array = numpy.random.rand(batch_size, length_1, length_2, length_3, embedding_dim)
     attention_array = numpy.random.rand(batch_size, length_1, length_2, length_3)
     sentence_tensor = torch.from_numpy(sentence_array).float()
     attention_tensor = torch.from_numpy(attention_array).float()
     aggregated_array = util.weighted_sum(sentence_tensor, attention_tensor).data.numpy()
     assert aggregated_array.shape == (batch_size, length_1, length_2, embedding_dim)
     expected_array = (attention_array[0, 3, 2, 0] * sentence_array[0, 3, 2, 0] +
                       attention_array[0, 3, 2, 1] * sentence_array[0, 3, 2, 1])
     numpy.testing.assert_almost_equal(aggregated_array[0, 3, 2], expected_array, decimal=5)
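For higher-order inputs like this one, the same result can be written as a broadcast multiply followed by a sum over the next-to-last axis; a small sketch of that equivalence (illustrative shapes only):
import torch
from allennlp.nn import util

matrix = torch.rand(1, 5, 6, 2, 4)    # (batch_size, length_1, length_2, length_3, embedding_dim)
attention = torch.rand(1, 5, 6, 2)    # (batch_size, length_1, length_2, length_3)

# sum over length_3 of attention * matrix -> (batch_size, length_1, length_2, embedding_dim)
by_hand = (attention.unsqueeze(-1) * matrix).sum(dim=-2)
assert torch.allclose(util.weighted_sum(matrix, attention), by_hand, atol=1e-5)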
Example no. 3
 def test_weighted_sum_handles_3d_attention_with_3d_matrix(self):
     batch_size = 1
     length_1 = 5
     length_2 = 2
     embedding_dim = 4
     sentence_array = numpy.random.rand(batch_size, length_2, embedding_dim)
     attention_array = numpy.random.rand(batch_size, length_1, length_2)
     sentence_tensor = torch.from_numpy(sentence_array).float()
     attention_tensor = torch.from_numpy(attention_array).float()
     aggregated_array = util.weighted_sum(sentence_tensor, attention_tensor).data.numpy()
     assert aggregated_array.shape == (batch_size, length_1, embedding_dim)
     for i in range(length_1):
         expected_array = (attention_array[0, i, 0] * sentence_array[0, 0] +
                           attention_array[0, i, 1] * sentence_array[0, 1])
         numpy.testing.assert_almost_equal(aggregated_array[0, i], expected_array,
                                           decimal=5)
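When both the attention and the matrix are 3-D, as in this test, the weighted sum reduces to a plain batched matrix product; sketched below under the same assumptions (AllenNLP util available, illustrative sizes):
import torch
from allennlp.nn import util

matrix = torch.rand(1, 2, 4)       # (batch_size, length_2, embedding_dim)
attention = torch.rand(1, 5, 2)    # (batch_size, length_1, length_2)

assert torch.allclose(util.weighted_sum(matrix, attention), attention.bmm(matrix), atol=1e-5)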
Example no. 4
    def forward(self, X, mask=None, verbose=False):
        '''
        Generate predictions


        Parameters
        ----------
        X: input with shape (batch_size, max_seq_len, input_dim)
        mask: input with shape (batch_size, max_seq_len)

        '''

        # Batch size
        batch_size = X.shape[0]

        # Batch vector (repeat across first dimension)
        vector = self.vector.unsqueeze(0).repeat(batch_size, 1)

        # Optionally project the input before computing attention
        if self.use_ffnn:
            Q = self.ffnn(X)
        else:
            Q = X

        # Attention weights
        # shape: (batch_size, max_seq_len)
        alphas = self.attention(vector=vector,
                                matrix=Q,
                                matrix_mask=mask)

        # Attended input
        # shape: (batch_size, encoder_query_dim)
        output = weighted_sum(X, alphas)

        # Dropout layer
        output = self.drop_layer(output)

        if verbose:
            # NOTE: these values are assumed to be stored as attributes in __init__
            logging.info('Attention')
            logging.info('\tinput_dim:  {}'.format(self.input_dim))
            logging.info('\tquery_dim:  {}'.format(self.query_dim))
            logging.info('\tactivation: {}'.format(self.activation))
            logging.info('\tdropout:    {}'.format(self.dropout))
            logging.info('\tuse_ffnn:   {}'.format(self.use_ffnn))

        return (output, alphas)
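The pattern above (a batched query vector attending over a sequence, followed by weighted_sum) can be reproduced with any AllenNLP Attention module that takes (vector, matrix, matrix_mask); a hedged usage sketch with DotProductAttention and made-up sizes (self.attention and self.ffnn in the snippet above are whatever the module was constructed with):
import torch
from allennlp.modules.attention import DotProductAttention
from allennlp.nn.util import weighted_sum

batch_size, max_seq_len, input_dim = 2, 7, 16
X = torch.randn(batch_size, max_seq_len, input_dim)            # encoder outputs
mask = torch.ones(batch_size, max_seq_len, dtype=torch.bool)   # all positions valid (may need float on older AllenNLP)
query = torch.randn(batch_size, input_dim)                     # one query vector per example

attention = DotProductAttention()        # normalizes the scores with a softmax by default
alphas = attention(query, X, mask)       # shape: (batch_size, max_seq_len)
output = weighted_sum(X, alphas)         # shape: (batch_size, input_dim)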
Example no. 5
    def forward(self, tokens: torch.Tensor, mask: torch.Tensor):  # pylint: disable=arguments-differ
        batch_size, sequence_length, _ = tokens.size()
        # Shape: (batch_size, sequence_length, sequence_length)
        similarity_matrix = self._matrix_attention(tokens, tokens)

        if self._num_attention_heads > 1:
            # In this case, the similarity matrix actually has shape
            # (batch_size, sequence_length, sequence_length, num_heads).  To make the rest of the
            # logic below easier, we'll permute this to
            # (batch_size, sequence_length, num_heads, sequence_length).
            similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)

        # Shape: (batch_size, sequence_length, [num_heads,] sequence_length)
        intra_sentence_attention = util.masked_softmax(
            similarity_matrix.contiguous(), mask)

        # Shape: (batch_size, sequence_length, projection_dim)
        output_token_representation = self._projection(tokens)

        if self._num_attention_heads > 1:
            # We need to split and permute the output representation to be
            # (batch_size, num_heads, sequence_length, projection_dim / num_heads), so that we can
            # do a proper weighted sum with `intra_sentence_attention`.
            shape = list(output_token_representation.size())
            new_shape = shape[:-1] + [self._num_attention_heads, -1]
            # Shape: (batch_size, sequence_length, num_heads, projection_dim / num_heads)
            output_token_representation = output_token_representation.view(
                *new_shape)
            # Shape: (batch_size, num_heads, sequence_length, projection_dim / num_heads)
            output_token_representation = output_token_representation.permute(
                0, 2, 1, 3)

        # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
        attended_sentence = util.weighted_sum(output_token_representation,
                                              intra_sentence_attention)

        if self._num_attention_heads > 1:
            # Here we concatenate the weighted representation for each head.  We'll accomplish this
            # just with a resize.
            # Shape: (batch_size, sequence_length, projection_dim)
            attended_sentence = attended_sentence.view(batch_size,
                                                       sequence_length, -1)

        # Shape: (batch_size, sequence_length, combination_dim)
        combined_tensors = util.combine_tensors(self._combination,
                                                [tokens, attended_sentence])
        return self._output_projection(combined_tensors)
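The multi-head bookkeeping above boils down to a split/permute before the weighted sum and a merge afterwards; a tiny pure-PyTorch sketch of that round trip (shapes are made up):
import torch

batch_size, seq_len, num_heads, projection_dim = 2, 5, 4, 12
x = torch.randn(batch_size, seq_len, projection_dim)

# split: (batch, seq, proj) -> (batch, num_heads, seq, proj / num_heads)
per_head = x.view(batch_size, seq_len, num_heads, -1).permute(0, 2, 1, 3)

# ... the per-head attention and weighted sum would go here ...

# merge: (batch, num_heads, seq, proj / num_heads) -> (batch, seq, proj)
merged = per_head.permute(0, 2, 1, 3).reshape(batch_size, seq_len, projection_dim)
assert torch.equal(x, merged)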
Example no. 6
    def _compute_attention(
            self,
            decoder_hidden_state: torch.LongTensor = None,
            encoder_outputs: torch.LongTensor = None,
            encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state.
        Parameters
        ----------
        decoder_hidden_state : ``torch.LongTensor``
            A tensor of shape ``(batch_size, decoder_output_dim)``, which contains the current decoder hidden state to be used
            as the 'query' to the attention computation
            during the last time step.
        encoder_outputs : ``torch.LongTensor``
            A tensor of shape ``(batch_size, max_input_sequence_length, encoder_output_dim)``, which contains all the
            encoder hidden states of the source tokens, i.e., the 'keys' to the attention computation
        encoder_mask : ``torch.LongTensor``
            A tensor of shape (batch_size, max_input_sequence_length), which contains the mask of the encoded input.
            We want to avoid computing an attention score for positions of the source with zero-values (remember not all
            input sentences have the same length)

        Returns
        -------
        torch.Tensor
            A tensor of shape (batch_size, encoder_output_dim) that contains the attended encoder outputs (aka context vector),
            i.e., we have ``applied`` the attention scores on the encoder hidden states.

        Notes
        -----
            Don't forget to apply the final softmax over the **masked** encoder outputs!
        """

        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # Main body of attention weights computation here
        attention_scores = encoder_outputs.bmm(
            decoder_hidden_state.unsqueeze(-1)).squeeze(-1)
        masked_attention_scores = masked_softmax(attention_scores,
                                                 encoder_outputs_mask)
        attended_output = util.weighted_sum(encoder_outputs,
                                            masked_attention_scores)

        return attended_output, masked_attention_scores
Example no. 7
    def _prepare_decode_step_input(
            self,
            input_indices: torch.LongTensor,
            decoder_hidden_state: torch.LongTensor = None,
            encoder_outputs: torch.LongTensor = None,
            encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
        """
        Given the input indices for the current timestep of the decoder, and all the encoder
        outputs, compute the input at the current timestep.  Note: This method is agnostic to
        whether the indices are gold indices or the predictions made by the decoder at the last
        timestep. So, this can be used even if we're doing some kind of scheduled sampling.

        If we're not using attention, the output of this method is just an embedding of the input
        indices.  If we are, the output will be a concatenation of the embedding and an attended
        average of the encoder inputs.

        Parameters
        ----------
        input_indices : torch.LongTensor
            Indices of either the gold inputs to the decoder or the predicted labels from the
            previous timestep.
        decoder_hidden_state : torch.LongTensor, optional (not needed if no attention)
            Output from the decoder at the last time step. Needed only if using attention.
        encoder_outputs : torch.LongTensor, optional (not needed if no attention)
            Encoder outputs from all time steps. Needed only if using attention.
        encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
            Masks on encoder outputs. Needed only if using attention.
        """
        # input_indices : (batch_size,)  since we are processing these one timestep at a time.
        # (batch_size, target_embedding_dim)
        embedded_input = self._target_embedder(input_indices)
        if self._attention_function:
            # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
            # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
            # complain.
            encoder_outputs_mask = encoder_outputs_mask.float()
            # (batch_size, input_sequence_length)
            input_weights = self._decoder_attention(decoder_hidden_state,
                                                    encoder_outputs,
                                                    encoder_outputs_mask)
            # (batch_size, encoder_output_dim)
            attended_input = weighted_sum(encoder_outputs, input_weights)
            # (batch_size, encoder_output_dim + target_embedding_dim)
            return torch.cat((attended_input, embedded_input), -1)
        else:
            return embedded_input
Example no. 8
    def _compute_attention(self, encoder_outputs: torch.Tensor,
                           encoder_mask: torch.Tensor,
                           decoder_outputs: torch.Tensor) -> torch.Tensor:
        """
        Computes the attention-based decoder hidden representation by first
        computing the attention scores between the encoder and decoder hidden
        states, computing the attention context via a weighted average over
        the encoder hidden states, concatenating the decoder state with the
        context, and passing the result through the attention layer to project
        it back down to the decoder hidden state size.

        Parameters
        ----------
        encoder_outputs: ``torch.Tensor``, ``(batch_size, num_document_tokens, encoder_hidden_size)``
            The output from the encoder.
        encoder_mask: ``torch.Tensor``, ``(batch_size, num_document_tokens)``
            The document token mask.
        decoder_outputs: ``torch.Tensor``, ``(batch_size, num_summary_tokens, decoder_hidden_size)``
            The output from the decoder.

        Returns
        -------
        hidden: ``torch.Tensor``, ``(batch_size, num_summary_tokens, decoder_hidden_size)``
            The new decoder hidden state representation.
        attention_probabilities: ``torch.Tensor``, ``(batch_size, num_summary_tokens, num_document_tokens)``
            The attention probabilities over the document tokens for each summary token
        """
        # Compute the attention context
        # shape: (group_size, num_summary_tokens, num_document_tokens)
        attention_scores = self.attention(decoder_outputs, encoder_outputs)
        # shape: (group_size, num_summary_tokens, num_document_tokens)
        attention_probabilities = masked_softmax(attention_scores,
                                                 encoder_mask)
        # shape: (group_size, num_summary_tokens, encoder_hidden_size)
        attention_context = weighted_sum(encoder_outputs,
                                         attention_probabilities)

        # Concatenate the attention context with the decoder outputs
        # then project back to the decoder hidden size
        # shape: (group_size, num_summary_tokens, encoder_hidden_size + decoder_hidden_size)
        concat = torch.cat([attention_context, decoder_outputs], dim=2)

        # shape: (group_size, num_summary_tokens, decoder_hidden_size)
        projected_hidden = self.attention_layer(concat)
        return projected_hidden, attention_probabilities
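A hedged, self-contained sketch of the same decoder-over-encoder attention, using DotProductMatrixAttention in place of whatever self.attention is configured to be (sizes are illustrative):
import torch
from allennlp.modules.matrix_attention import DotProductMatrixAttention
from allennlp.nn.util import masked_softmax, weighted_sum

batch_size, num_document_tokens, num_summary_tokens, hidden_size = 2, 9, 4, 8
encoder_outputs = torch.randn(batch_size, num_document_tokens, hidden_size)
decoder_outputs = torch.randn(batch_size, num_summary_tokens, hidden_size)
encoder_mask = torch.ones(batch_size, num_document_tokens)

# shape: (batch_size, num_summary_tokens, num_document_tokens)
attention_scores = DotProductMatrixAttention()(decoder_outputs, encoder_outputs)
attention_probabilities = masked_softmax(attention_scores, encoder_mask)
# shape: (batch_size, num_summary_tokens, hidden_size)
attention_context = weighted_sum(encoder_outputs, attention_probabilities)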
Example no. 9
    def _prepare_attended_input(self,
                                decoder_hidden_state: torch.LongTensor = None,
                                encoder_outputs: torch.LongTensor = None,
                                encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(
                decoder_hidden_state, encoder_outputs, encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input
Example no. 10
 def test_weighted_sum_handles_3d_attention_with_3d_matrix(self):
     batch_size = 1
     length_1 = 5
     length_2 = 2
     embedding_dim = 4
     sentence_array = numpy.random.rand(batch_size, length_2, embedding_dim)
     attention_array = numpy.random.rand(batch_size, length_1, length_2)
     sentence_tensor = Variable(torch.from_numpy(sentence_array).float())
     attention_tensor = Variable(torch.from_numpy(attention_array).float())
     aggregated_array = util.weighted_sum(sentence_tensor,
                                          attention_tensor).data.numpy()
     assert aggregated_array.shape == (batch_size, length_1, embedding_dim)
     for i in range(length_1):
         expected_array = (attention_array[0, i, 0] * sentence_array[0, 0] +
                           attention_array[0, i, 1] * sentence_array[0, 1])
         numpy.testing.assert_almost_equal(aggregated_array[0, i],
                                           expected_array,
                                           decimal=5)
Example no. 11
 def test_weighted_sum_works_on_simple_input(self):
     batch_size = 1
     sentence_length = 5
     embedding_dim = 4
     sentence_array = numpy.random.rand(batch_size, sentence_length,
                                        embedding_dim)
     sentence_tensor = Variable(torch.from_numpy(sentence_array).float())
     attention_tensor = Variable(torch.FloatTensor([[.3, .4, .1, 0, 1.2]]))
     aggregated_array = util.weighted_sum(sentence_tensor,
                                          attention_tensor).data.numpy()
     assert aggregated_array.shape == (batch_size, embedding_dim)
     expected_array = (0.3 * sentence_array[0, 0] +
                       0.4 * sentence_array[0, 1] +
                       0.1 * sentence_array[0, 2] +
                       0.0 * sentence_array[0, 3] +
                       1.2 * sentence_array[0, 4])
     numpy.testing.assert_almost_equal(aggregated_array, [expected_array],
                                       decimal=5)
Example no. 12
 def test_weighted_sum_handles_uneven_higher_order_input(self):
     batch_size = 1
     length_1 = 5
     length_2 = 6
     length_3 = 2
     embedding_dim = 4
     sentence_array = numpy.random.rand(batch_size, length_3, embedding_dim)
     attention_array = numpy.random.rand(batch_size, length_1, length_2, length_3)
     sentence_tensor = torch.from_numpy(sentence_array).float()
     attention_tensor = torch.from_numpy(attention_array).float()
     aggregated_array = util.weighted_sum(sentence_tensor, attention_tensor).data.numpy()
     assert aggregated_array.shape == (batch_size, length_1, length_2, embedding_dim)
     for i in range(length_1):
         for j in range(length_2):
             expected_array = (attention_array[0, i, j, 0] * sentence_array[0, 0] +
                               attention_array[0, i, j, 1] * sentence_array[0, 1])
             numpy.testing.assert_almost_equal(aggregated_array[0, i, j], expected_array,
                                               decimal=5)
Example no. 13
    def _encode(
        self, ref_target_tokens: Dict[str, torch.Tensor],
        ref_source_tokens: Dict[str, torch.Tensor],
        instance_source_tokens: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_ref_target = self._target_embedder(ref_target_tokens)
        ref_target_mask = util.get_text_field_mask(ref_target_tokens)
        encoded_ref_target = self._target_encoder(embedded_ref_target,
                                                  ref_target_mask)

        embedded_ref_source = self._source_embedder(ref_source_tokens)
        ref_source_mask = util.get_text_field_mask(ref_source_tokens)
        encoded_ref_source = self._source_encoder(embedded_ref_source,
                                                  ref_source_mask)

        embedded_instance_source = self._source_embedder(
            instance_source_tokens)
        instance_source_mask = util.get_text_field_mask(instance_source_tokens)
        encoded_instance_source = self._source_encoder(
            embedded_instance_source, instance_source_mask)

        #reduced_encoded_ref_source = torch.sum(encoded_ref_source,dim=1).unsqueeze(1)
        #reduced_encoded_instance_source = torch.sum(encoded_instance_source,dim=1).unsqueeze(1)
        #embedding_matrix = util.weighted_sum(reduced_encoded_instance_source,reduced_encoded_ref_source.permute(0,2,1))
        ref_instance_similarity = self._instance_ref_sim(
            encoded_ref_source, encoded_instance_source)
        #ref_instance_similarity = -1*ref_instance_similarity
        #ref_instance_attention = util.masked_softmax(ref_instance_similarity, instance_source_mask)
        instance_ref_vectors = util.weighted_sum(encoded_instance_source,
                                                 ref_instance_similarity)

        #target_instance_ref_similarity = self._target_instref_sim(encoded_ref_target,instance_ref_vectors)
        #source_target_attention = util.masked_softmax(target_instance_ref_similarity,instance_source_mask)
        #target_vectors = util.weighted_sum(instance_ref_vectors,source_target_attention)

        #print('mask',instance_source_mask.shape)
        #print('out',target_vectors.shape)
        return {
            "source_mask": ref_target_mask,
            "encoder_outputs": encoded_ref_target,
            "attention_over_src": instance_ref_vectors,
            "attention_mask": ref_source_mask
        }
Example no. 14
    def _compute_attention(self, sentence_encodings: torch.Tensor,
                           context_encodings: torch.Tensor,
                           context_mask: torch.Tensor) -> torch.Tensor:
        """
        Computes new sentence encodings using an attention mechanism between
        the original sentence encodings and some context encodings. The context
        encodings are not necessarily the context in the cloze task sense, but
        any vector over which the attention should be computed.

        Parameters
        ----------
        sentence_encodings: (batch_size, num_sents, hidden_dim)
            The original sentence encodings
        context_encodings: (batch_size, num_contexts, hidden_dim)
            The representation of each context item
        context_mask: (batch_size, num_contexts)
            The context item mask

        Returns
        -------
        The new sentence encodings: (batch_size, num_sents, hidden_dim)
        """
        if self.attention is None or self.attention_layer is None:
            raise Exception(
                '`attention` and `attention_layer` must not be `None` to use attention'
            )

        # shape: (batch_size, num_sents, num_context_tokens)
        attention_scores = self.attention(sentence_encodings,
                                          context_encodings)
        # shape: (batch_size, num_sents, num_context_tokens)
        attention_probabilities = masked_softmax(attention_scores,
                                                 context_mask)
        # shape: (batch_size, num_sents, hidden_size)
        attention_context = weighted_sum(context_encodings,
                                         attention_probabilities)

        # Concatenate the attention context with the sentence encodings
        # then project back to the sentence encoder hidden size
        # shape: (batch_size, num_sents, hidden_size * 2)
        concat = torch.cat([attention_context, sentence_encodings], dim=2)
        # shape: (batch_size, num_sents, hidden_size)
        projected_hidden = self.attention_layer(concat)
        return projected_hidden
Example no. 15
    def attend_on_sentence(self, query, encoder_outputs, encoder_output_mask):
        u"""
        This method is almost identical to ``WikiTablesDecoderStep.attend_on_question``. We just
        don't return the attention weights.
        Given a query (which is typically the decoder hidden state), compute an attention over the
        output of the sentence encoder, and return a weighted sum of the sentence representations
        given this attention.  We also return the attention weights themselves.

        This is a simple computation, but we have it as a separate method so that the ``forward``
        method on the main parser module can call it on the initial hidden state, to simplify the
        logic in ``take_step``.
        """
        # (group_size, sentence_length)
        sentence_attention_weights = self._input_attention(
            query, encoder_outputs, encoder_output_mask)
        # (group_size, encoder_output_dim)
        attended_sentence = nn_util.weighted_sum(encoder_outputs,
                                                 sentence_attention_weights)
        return attended_sentence
Example no. 16
    def forward(
        self,
        sequence_tensor: torch.FloatTensor,
        span_indices: torch.LongTensor,
        span_indices_mask: torch.BoolTensor = None,
    ) -> torch.FloatTensor:
        # shape (batch_size, sequence_length, 1)

        global_attention_logits = torch.matmul(
            sequence_tensor,
            torch.zeros(self.input_dim, 1).to(sequence_tensor.device))

        # shape (batch_size, sequence_length, embedding_dim + 1)
        concat_tensor = torch.cat([sequence_tensor, global_attention_logits],
                                  -1)

        concat_output, span_mask = util.batched_span_select(
            concat_tensor, span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = concat_output[:, :, :, :-1]
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = concat_output[:, :, :, -1]

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = util.masked_softmax(span_attention_logits,
                                                     span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_embeddings,
                                                     span_attention_weights)

        if span_indices_mask is not None:
            # Above we were masking the widths of spans with respect to the max
            # span width in the batch. Here we are masking the spans which were
            # originally passed in as padding.
            return attended_text_embeddings * span_indices_mask.unsqueeze(-1)

        return attended_text_embeddings
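The span-pooling idea above (select token embeddings and attention logits per span, softmax within each span, then weighted_sum) can also be sketched directly with the same utilities; everything below is illustrative, not taken from the snippet:
import torch
from allennlp.nn import util

batch_size, seq_len, dim = 2, 12, 8
sequence_tensor = torch.randn(batch_size, seq_len, dim)
attention_logits = torch.randn(batch_size, seq_len, 1)                  # one score per token
span_indices = torch.tensor([[[0, 2], [3, 3], [5, 8]]] * batch_size)    # inclusive (start, end) pairs

# span_embeddings: (batch, num_spans, max_span_width, dim); span_mask: (batch, num_spans, max_span_width)
span_embeddings, span_mask = util.batched_span_select(sequence_tensor, span_indices)
span_logits, _ = util.batched_span_select(attention_logits, span_indices)

span_weights = util.masked_softmax(span_logits.squeeze(-1), span_mask)
span_vectors = util.weighted_sum(span_embeddings, span_weights)          # (batch, num_spans, dim)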
Example no. 17
    def prepare_decode_step_input(
            self, input_indices: torch.LongTensor,
            decoder_hidden: torch.LongTensor,
            encoder_outputs: torch.LongTensor,
            encoder_outputs_mask: torch.LongTensor) -> torch.LongTensor:
        """
        Prepares the current timestep input for the decoder.

        By default, simply embeds and returns the input. If using attention, the default attention
        (BiLinearAttention) is applied to attend on the step input given the encoded source
        sequence and the previous hidden state.

        Parameters
        ----------
        input_indices : torch.LongTensor
            Indices of either the gold inputs to the decoder or the predicted labels from the
            previous timestep.
        decoder_hidden : torch.LongTensor, optional (not needed if no attention)
            Output from the decoder at the last time step. Needed only if using attention.
        encoder_outputs : torch.LongTensor, optional (not needed if no attention)
            Encoder outputs from all time steps. Needed only if using attention.
        encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
            Masks on encoder outputs. Needed only if using attention.
        """
        # input_indices : (batch_size,)  since we are processing these one timestep at a time.
        # (batch_size, target_embedding_dim)
        embedded_input = self.target_embedder(input_indices)
        if self.apply_attention:
            if isinstance(decoder_hidden, tuple):
                decoder_hidden = decoder_hidden[0]
            # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
            # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
            # complain.
            encoder_outputs_mask = encoder_outputs_mask.float()
            # (batch_size, input_sequence_length)
            input_weights = self.decoder_attention_function(
                decoder_hidden[-1], encoder_outputs, encoder_outputs_mask)
            # (batch_size, encoder_output_dim)
            attended_input = util.weighted_sum(encoder_outputs, input_weights)
            # (batch_size, encoder_output_dim + target_embedding_dim)
            return torch.cat((attended_input, embedded_input), -1)
        else:
            return embedded_input
Example no. 18
    def _prepare_decode_step_input(self,
                                   input_indices: torch.LongTensor,
                                   decoder_hidden_state: torch.LongTensor = None,
                                   encoder_outputs: torch.LongTensor = None,
                                   encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
        """
        Given the input indices for the current timestep of the decoder, and all the encoder
        outputs, compute the input at the current timestep.  Note: This method is agnostic to
        whether the indices are gold indices or the predictions made by the decoder at the last
        timestep. So, this can be used even if we're doing some kind of scheduled sampling.

        If we're not using attention, the output of this method is just an embedding of the input
        indices.  If we are, the output will be a concatenation of the embedding and an attended
        average of the encoder inputs.

        Parameters
        ----------
        input_indices : torch.LongTensor
            Indices of either the gold inputs to the decoder or the predicted labels from the
            previous timestep.
        decoder_hidden_state : torch.LongTensor, optional (not needed if no attention)
            Output from the decoder at the last time step. Needed only if using attention.
        encoder_outputs : torch.LongTensor, optional (not needed if no attention)
            Encoder outputs from all time steps. Needed only if using attention.
        encoder_outputs_mask : torch.LongTensor, optional (not needed if no attention)
            Masks on encoder outputs. Needed only if using attention.
        """
        # input_indices : (batch_size,)  since we are processing these one timestep at a time.
        # (batch_size, target_embedding_dim)
        embedded_input = self._target_embedder(input_indices)
        if self._attention_function:
            # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
            # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
            # complain.
            encoder_outputs_mask = encoder_outputs_mask.float()
            # (batch_size, input_sequence_length)
            input_weights = self._decoder_attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
            # (batch_size, encoder_output_dim)
            attended_input = weighted_sum(encoder_outputs, input_weights)
            # (batch_size, encoder_output_dim + target_embedding_dim)
            return torch.cat((attended_input, embedded_input), -1)
        else:
            return embedded_input
Example no. 19
    def attend_on_question(
        self, query: torch.Tensor, encoder_outputs: torch.Tensor, encoder_output_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a query (which is typically the decoder hidden state), compute an attention over the
        output of the question encoder, and return a weighted sum of the question representations
        given this attention.  We also return the attention weights themselves.

        This is a simple computation, but we have it as a separate method so that the ``forward``
        method on the main parser module can call it on the initial hidden state, to simplify the
        logic in ``take_step``.
        """
        # (group_size, question_length)
        question_attention_weights = self._input_attention(
            query, encoder_outputs, encoder_output_mask
        )
        # (group_size, encoder_output_dim)
        attended_question = util.weighted_sum(encoder_outputs, question_attention_weights)
        return attended_question, question_attention_weights
Example no. 20
    def _prepare_output_projections(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]
        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]
        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        is_unk = (last_predictions >= self._target_vocab_size).long()
        last_predictions_fixed = last_predictions - last_predictions * is_unk + self._target_unk_index * is_unk
        embedded_input = self._target_embedder.forward(last_predictions_fixed)

        if not self._use_coverage:
            attn_scores = self._attention.forward(decoder_hidden,
                                                  encoder_outputs, source_mask)
        else:
            coverage = state["coverage"]
            attn_scores = self._attention.forward(decoder_hidden,
                                                  encoder_outputs, source_mask,
                                                  coverage)
            coverage = coverage + attn_scores
            state["coverage"] = coverage
        attn_context = util.weighted_sum(encoder_outputs, attn_scores)
        decoder_input = torch.cat((attn_context, embedded_input), -1)

        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        output_projections = self._output_projection_layer(
            self._hidden_projection_layer(decoder_hidden))

        state["decoder_input"] = decoder_input
        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context
        state["attn_scores"] = attn_scores
        state["attn_context"] = attn_context

        return output_projections, state
Example no. 21
    def forward(self, tokens: torch.Tensor, mask: torch.Tensor):  # pylint: disable=arguments-differ
        batch_size, sequence_length, _ = tokens.size()
        # Shape: (batch_size, sequence_length, sequence_length)
        similarity_matrix = self._matrix_attention(tokens, tokens)

        if self._num_attention_heads > 1:
            # In this case, the similarity matrix actually has shape
            # (batch_size, sequence_length, sequence_length, num_heads).  To make the rest of the
            # logic below easier, we'll permute this to
            # (batch_size, sequence_length, num_heads, sequence_length).
            similarity_matrix = similarity_matrix.permute(0, 1, 3, 2)

        # Shape: (batch_size, sequence_length, [num_heads,] sequence_length)
        intra_sentence_attention = util.masked_softmax(similarity_matrix.contiguous(), mask)

        # Shape: (batch_size, sequence_length, projection_dim)
        output_token_representation = self._projection(tokens)

        if self._num_attention_heads > 1:
            # We need to split and permute the output representation to be
            # (batch_size, num_heads, sequence_length, projection_dim / num_heads), so that we can
            # do a proper weighted sum with `intra_sentence_attention`.
            shape = list(output_token_representation.size())
            new_shape = shape[:-1] + [self._num_attention_heads, -1]
            # Shape: (batch_size, sequence_length, num_heads, projection_dim / num_heads)
            output_token_representation = output_token_representation.view(*new_shape)
            # Shape: (batch_size, num_heads, sequence_length, projection_dim / num_heads)
            output_token_representation = output_token_representation.permute(0, 2, 1, 3)

        # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads])
        attended_sentence = util.weighted_sum(output_token_representation,
                                              intra_sentence_attention)

        if self._num_attention_heads > 1:
            # Here we concatenate the weighted representation for each head.  We'll accomplish this
            # just with a resize.
            # Shape: (batch_size, sequence_length, projection_dim)
            attended_sentence = attended_sentence.view(batch_size, sequence_length, -1)

        # Shape: (batch_size, sequence_length, combination_dim)
        combined_tensors = util.combine_tensors(self._combination, [tokens, attended_sentence])
        return self._output_projection(combined_tensors)
Example no. 22
    def attend_on_question(self,
                           query: torch.Tensor,
                           encoder_outputs: torch.Tensor,
                           encoder_output_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a query (which is typically the decoder hidden state), compute an attention over the
        output of the question encoder, and return a weighted sum of the question representations
        given this attention.  We also return the attention weights themselves.

        This is a simple computation, but we have it as a separate method so that the ``forward``
        method on the main parser module can call it on the initial hidden state, to simplify the
        logic in ``take_step``.
        """
        # (group_size, question_length)
        question_attention_weights = self._input_attention(query,
                                                           encoder_outputs,
                                                           encoder_output_mask)
        # (group_size, encoder_output_dim)
        attended_question = util.weighted_sum(encoder_outputs, question_attention_weights)
        return attended_question, question_attention_weights
Example no. 23
 def _decode_step_output(
         self,
         decoder_hidden_state: torch.LongTensor = None,
         encoder_outputs: torch.LongTensor = None,
         encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
     # encoder_outputs : (batch_size, input_sequence_length, encoder_output_dim)
     # Ensuring mask is also a FloatTensor. Or else the multiplication within attention will
     # complain.
     encoder_outputs_mask = encoder_outputs_mask.float()
     # (batch_size, input_sequence_length)
     input_weights_e = self._decoder_attention(decoder_hidden_state,
                                               encoder_outputs,
                                               encoder_outputs_mask)
     input_weights_a = F.softmax(input_weights_e, dim=-1)
     # (batch_size, encoder_output_dim)
     attended_input = weighted_sum(encoder_outputs, input_weights_a)
      # H*_t = sum(h_i * at_i)
     # (batch_size, encoder_output_dim + decoder_hidden_dim)
     return input_weights_e, input_weights_a, torch.cat(
         (decoder_hidden_state, attended_input), -1)
Example no. 24
 def filter(self, objects: ObjectSet,
            question_attention: torch.Tensor) -> ObjectSet:
     if self.nmn_settings["mask_non_attention"]:
         language_encoding = self.language_encoding[
             self.object_scores_index]
         visual_encoding = self.visual_encoding[self.object_scores_index]
         cross_encoding = self.cross_encoding[self.object_scores_index]
     else:
         language_encoding = self.language_encoding
         visual_encoding = self.visual_encoding
         cross_encoding = self.cross_encoding
     question_attention = torch.nn.functional.pad(
         question_attention.view(-1),
         pad=(0, language_encoding.shape[0] - question_attention.numel()),
     )
     attended_question = util.weighted_sum(language_encoding,
                                           question_attention)
     attended_question = attended_question.view(1, -1).repeat(
         visual_encoding.shape[0], 1)
     attended_objects = visual_encoding
     if self.nmn_settings["filter_find_same_params"]:
         attended_objects = visual_encoding.view(-1,
                                                 visual_encoding.shape[-1])
         find_layer_inputs = torch.cat(
             (attended_objects, attended_question), dim=-1)
         logits = self.parameters.find_layer(find_layer_inputs)
         filter_probs = torch.sigmoid(logits).view(-1)
     elif self.nmn_settings["use_sum_counting"]:
         attended_objects = visual_encoding
         logits = self.parameters.filter_layer(
             torch.cat(
                 (
                     visual_encoding.view(-1, visual_encoding.shape[-1]),
                     attended_question.repeat_interleave(
                         visual_encoding.shape[1], dim=0),
                 ),
                 dim=-1,
             ))
         filter_probs = torch.sigmoid(logits).view(-1)
     self.object_scores[self.object_scores_index] = filter_probs
     return filter_probs * objects
Example no. 25
    def _prepare_attended_input(
        self,
        decoder_hidden_state: torch.LongTensor = None,
        encoder_outputs: torch.LongTensor = None,
        encoder_outputs_mask: torch.LongTensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()
        # shape: (batch_size, max_input_sequence_length)
        input_logits = self._attention(decoder_hidden_state, encoder_outputs,
                                       encoder_outputs_mask)
        # the attention mechanism returns the logits that are necessary for attention supervision loss,
        # so we normalize it here
        input_weights = masked_softmax(input_logits, encoder_outputs_mask)
        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input, input_logits
Example no. 26
    def _forward_internal(self, text: Dict[str, torch.Tensor],
                          relations: Dict[str, torch.Tensor]) -> torch.Tensor:
        t_mask = util.get_text_field_mask(text)
        t_emb = self.word_embeddings(text)

        t_hiddens = self.encoder_dropout(self.text_encoder(t_emb, t_mask))
        t_encoding = self.text_attn(t_hiddens)

        r_masks = util.get_text_field_mask(relations, num_wrapping_dims=1)
        r_embs = self.word_embeddings(relations)

        r_sentence_encodings = self.encoder_dropout(
            seq_over_seq(self.relation_encoder, r_embs, r_masks))
        r_attn = self.relation_attn(vector=t_encoding,
                                    matrix=r_sentence_encodings)
        r_encoding = util.weighted_sum(r_sentence_encodings, r_attn)

        final = torch.cat((t_encoding, r_encoding), dim=-1)

        logit = self.output(final).squeeze(-1)
        return cast(torch.Tensor, logit)
Example no. 27
 def with_relation(self, a: ObjectSet, b: ObjectSet,
                   question_attention: torch.Tensor) -> ObjectSet:
     if self.nmn_settings["mask_non_attention"]:
         language_encoding = self.language_encoding[
             self.object_scores_index]
         visual_encoding = self.visual_encoding[self.object_scores_index]
         cross_encoding = self.cross_encoding[self.object_scores_index]
     else:
         language_encoding = self.language_encoding
         visual_encoding = self.visual_encoding
         cross_encoding = self.cross_encoding
     question_attention = torch.nn.functional.pad(
         question_attention.view(-1),
         pad=(0, language_encoding.shape[0] - question_attention.numel()),
     )
     attended_question = util.weighted_sum(language_encoding,
                                           question_attention)
     attended_question = attended_question.view(1, -1).repeat(
         visual_encoding.shape[0]**2, 1)
     visual_encoding_rows = visual_encoding.repeat(
         1, visual_encoding.shape[0]).view(visual_encoding.shape[0]**2, -1)
     visual_encoding_columns = visual_encoding.repeat(
         visual_encoding.shape[0], 1)
     relate_layer_inputs = torch.cat(
         (visual_encoding_rows, visual_encoding_columns, attended_question),
         dim=-1)
     object_pair_scores = torch.sigmoid(
         self.parameters.relate_layer(relate_layer_inputs))
     object_pair_scores = object_pair_scores.view(visual_encoding.shape[0],
                                                  visual_encoding.shape[0])
     object_pair_scores2 = object_pair_scores * (
         1.0 - torch.eye(visual_encoding.shape[0]).cuda())
     object_pair_scores3 = object_pair_scores2 * a.view(-1, 1).repeat(
         1, visual_encoding.shape[0])
     object_pair_scores4 = object_pair_scores3 * b.view(1, -1).repeat(
         visual_encoding.shape[0], 1)
     object_scores = object_pair_scores4.sum(1).clamp(
         min=0.0 + self.parameters.epsilon,
         max=1.0 - self.parameters.epsilon)
     return object_scores
Example no. 28
    def _decoder_step(
            self, last_actions: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        embed_actions = self.EMB({"tokens": last_actions})
        batch_size = embed_actions.size(0)
        # Update stack given draft pointer information

        draft_head = state["encoded_draft"][torch.arange(batch_size),
                                            state["draft_pointer"]]
        query = torch.cat([state["stream_hidden"], draft_head], dim=1)
        attend = self.ATTN(query, state["encoded_triple"],
                           state["triple_mask"])
        attended_triple = util.weighted_sum(state["encoded_triple"], attend)

        is_added = torch.stack([last_actions != tok
                                for tok in self.SYMBOL]).all(dim=0)
        draft_head[is_added] = self.ADD(embed_actions[is_added])

        hs, cs = self.STREAM(torch.cat((draft_head, attended_triple), dim=-1),
                             (state["stream_hidden"], state["stream_context"]))
        drop_mask = (last_actions != self.DROP).unsqueeze(1).float()
        hx = drop_mask * hs + (-drop_mask + 1) * state["stream_hidden"]
        cx = drop_mask * cs + (-drop_mask + 1) * state["stream_context"]
        state["stream_hidden"], state["stream_context"] = hx, cx

        # Update Pointer
        move_forward = ((last_actions == self.KEEP) |
                        (last_actions == self.DROP)).long()

        state["draft_pointer"] = state["draft_pointer"] + move_forward
        # Simple masking for pointer
        state["draft_pointer"] = torch.min(state["draft_pointer"],
                                           state["end_point"])

        is_ended = state["end_point"] == state["draft_pointer"]
        state["action_mask"][is_ended, self.KEEP] = 0
        state["action_mask"][is_ended, self.DROP] = 0
        state["action_mask"][is_ended, self.END] = 1

        return state
Example no. 29
    def summary_vector(self, encoding, mask, in_type="passage"):
        """
        In NABERT (and in NAQANET), a 'summary_vector' is created for some entities, such as the
        passage or the question. This vector is created as a weighted sum of the elements of the
        entity, e.g. the passage summary vector is a weighted sum of the passage tokens.

        The specific weighting for every entity type is learned.

        Parameters
        ----------
        encoding : BERT's output layer
        mask : a Tensor with 1s only at the positions relevant to ``in_type``
        in_type : the entity we want to summarize, e.g. the passage

        Returns
        -------
        The summary vector according to ``in_type``.
        """
        if in_type == "passage":
            # Shape: (batch_size, seqlen)
            alpha = self._passage_weights_predictor(encoding).squeeze()
        elif in_type == "question":
            # Shape: (batch_size, seqlen)
            alpha = self._question_weights_predictor(encoding).squeeze()
        elif in_type == "arithmetic":
            # Shape: (batch_size, seqlen)
            alpha = self._arithmetic_weights_predictor(encoding).squeeze()
        elif in_type == "multiple_spans":
            # TODO: currently not using it...
            alpha = self._multispan_weights_predictor(encoding).squeeze()
        else:
            # Shape: (batch_size, #num of numbers, seqlen)
            alpha = torch.zeros(encoding.shape[:-1], device=encoding.device)
        # Shape: (batch_size, seqlen)
        # (batch_size, #num of numbers, seqlen) for numbers
        alpha = masked_softmax(alpha, mask)
        # Shape: (batch_size, out)
        # (batch_size, #num of numbers, out) for numbers
        h = util.weighted_sum(encoding, alpha)
        return h
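A compact sketch of the summary-vector pattern described in the docstring, with hypothetical sizes and a stand-in Linear layer playing the role of the learned weights predictor:
import torch
from allennlp.nn.util import masked_softmax, weighted_sum

batch_size, seqlen, bert_dim = 2, 10, 32
encoding = torch.randn(batch_size, seqlen, bert_dim)    # e.g. BERT's output layer
mask = torch.ones(batch_size, seqlen)                   # 1s only at the relevant positions
weights_predictor = torch.nn.Linear(bert_dim, 1)        # learned, one per entity type

alpha = weights_predictor(encoding).squeeze(-1)         # (batch_size, seqlen)
alpha = masked_softmax(alpha, mask)                     # normalized over unmasked tokens only
summary = weighted_sum(encoding, alpha)                 # (batch_size, bert_dim)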
Example no. 30
    def _adv_arithmetic_module(self, summary_vector, maxlen, options,
                               options_mask, bert_out, bert_mask):
        # summary_vector: (batch, bert_dim)
        # options : (batch, opnumlen, bert_dim)
        # options_mask : (batch, opnumlen)
        # bert_out : (batch, seqlen, bert_dim)
        # bert_mask: (batch, seqlen)

        # summary_vector : (batch, explen, bert_dim)
        summary_vector = summary_vector.unsqueeze(1).expand(-1, maxlen, -1)

        # out: (batch, explen, rnn_hdim)
        out, _ = self.rnn(summary_vector)
        out = self.rnndropout(out)
        out = self.Wst(out)

        # alpha : (batch, explen, seqlen)
        alpha = torch.bmm(out, bert_out.transpose(1, 2))
        alpha = util.masked_softmax(alpha, bert_mask)

        # context : (batch, explen, bert_dim)
        context = util.weighted_sum(bert_out, alpha)
        #         context = self.Wcon(context)

        # logits : (batch, explen, opnumlen)
        logits = torch.bmm(context, options.transpose(1, 2))
        logits = util.replace_masked_values(
            logits,
            options_mask.unsqueeze(1).expand_as(logits), -1e7)

        number_mask = options_mask.clone()
        number_mask[:, :self.num_ops] = 0
        op_mask = options_mask.clone()
        op_mask[:, self.num_ops:] = 0

        best_expression = beam_search(self.arithmetic_K, logits.softmax(-1),
                                      number_mask, op_mask, self.END,
                                      self.num_ops)

        return logits, best_expression[0]
Example no. 31
 def summary_vector(self, encoding, mask, in_type="passage"):
     if in_type == "passage":
         # Shape: (batch_size, seqlen)
         alpha = self._passage_weights_predictor(encoding).squeeze()
     elif in_type == "question":
         # Shape: (batch_size, seqlen)
         alpha = self._question_weights_predictor(encoding).squeeze()
     elif in_type == "arithmetic":
         # Shape: (batch_size, seqlen)
         alpha = self._arithmetic_weights_predictor(encoding).squeeze()
     else:
         # Shape: (batch_size, #num of numbers, seqlen)
         alpha = torch.zeros(encoding.shape[:-1], device=encoding.device)
         if self.number_rep == 'attention':
             alpha = self._number_weights_predictor(encoding).squeeze()
     # Shape: (batch_size, seqlen)
     # (batch_size, #num of numbers, seqlen) for numbers
     alpha = masked_softmax(alpha, mask)
     # Shape: (batch_size, out)
     # (batch_size, #num of numbers, out) for numbers
     h = util.weighted_sum(encoding, alpha)
     return h
Example no. 32
 def test_weighted_sum_handles_higher_order_input(self):
     batch_size = 1
     length_1 = 5
     length_2 = 6
     length_3 = 2
     embedding_dim = 4
     sentence_array = numpy.random.rand(batch_size, length_1, length_2,
                                        length_3, embedding_dim)
     attention_array = numpy.random.rand(batch_size, length_1, length_2,
                                         length_3)
     sentence_tensor = Variable(torch.from_numpy(sentence_array).float())
     attention_tensor = Variable(torch.from_numpy(attention_array).float())
     aggregated_array = weighted_sum(sentence_tensor,
                                     attention_tensor).data.numpy()
     assert aggregated_array.shape == (batch_size, length_1, length_2,
                                       embedding_dim)
     expected_array = (
         attention_array[0, 3, 2, 0] * sentence_array[0, 3, 2, 0] +
         attention_array[0, 3, 2, 1] * sentence_array[0, 3, 2, 1])
     numpy.testing.assert_almost_equal(aggregated_array[0, 3, 2],
                                       expected_array,
                                       decimal=5)
Example no. 33
    def _prepare_output_projections(
            self, last_predictions: torch.Tensor,
            state: Dict[str, torch.Tensor], his_sym: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:

        encoder_outputs = state[
            "encoder_outputs"]  # bs, seq_len, encoder_output_dim
        source_mask = state["source_mask"]  # bs * seq_len
        decoder_hidden = state["decoder_hidden"]  # bs, decoder_output_dim
        decoder_context = state["decoder_context"]  # bs * decoder_output

        embedded_input = self._target_embedder(
            last_predictions)  # bs * target_embedding
        decoder_input = torch.cat((embedded_input, state['context_hidden']),
                                  -1)
        if self._attention:  # if a seq-to-seq attention module was configured
            input_weights = self._attention(
                decoder_hidden, encoder_outputs,
                source_mask.float())  # bs * seq_len
            attended_input = util.weighted_sum(
                encoder_outputs, input_weights)  # bs * encoder_output
            decoder_input = torch.cat(
                (attended_input, embedded_input),
                -1)  # bs * (decoder_output + target_embedding)

        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden  # bs * hidden
        state["decoder_context"] = decoder_context

        # output_projections = self._output_projection_layer(torch.cat((decoder_hidden,graph_hidden),-1))
        output_projections = self._output_projection_layer(decoder_hidden)
        # sz = output_projections.size(0)
        # for b in range(sz):
        #     for k,li in enumerate(self.idx_to_vocab_list):
        #         if his_sym[b][k].item() == 1:
        #             output_projections[b][li] = 1e-9
        return output_projections, state
Example no. 34
  def forward_span(self, ds_name, dialog_repr, repeated_ds_embeddings, context_masks, span_labels=None, spans_start = None, spans_end = None):
    batch_size, max_dialog_len = context_masks.size()
    ds_dialog_sim = self._ds_dialog_attention(self._dropout(repeated_ds_embeddings), self._dropout(dialog_repr))
    ds_dialog_att = util.masked_softmax(ds_dialog_sim.view(-1, max_dialog_len), context_masks.view(-1, max_dialog_len))
    ds_dialog_att = ds_dialog_att.view(batch_size, max_dialog_len)
    ds_dialog_repr = util.weighted_sum(dialog_repr, ds_dialog_att)
    ds_dialog_repr = ds_dialog_repr + repeated_ds_embeddings.squeeze(1)
    span_label_logits = self._span_label_predictor(F.relu(self._dropout(ds_dialog_repr)))
    span_label_prediction = torch.argmax(span_label_logits, dim=1)
    span_label_loss = 0.0
    if span_labels is not None:
      span_label_loss = self._cross_entropy(span_label_logits, span_labels) # loss averaged by #turn
      self._accuracy.span_label_acc(ds_name, span_label_logits, span_labels, span_labels != -1)
    loss = span_label_loss

    w = self._span_prediction_layer(self._dropout(ds_dialog_repr)).unsqueeze(1)
    span_start_repr = self._span_start_encoder(self._dropout(dialog_repr))
    span_start_logits = torch.bmm(w, span_start_repr.transpose(1,2)).squeeze(1)
    span_start_probs = util.masked_softmax(span_start_logits, context_masks)
    span_start_logits = util.replace_masked_values(span_start_logits, context_masks.to(dtype=torch.int8), -1e7)

    span_end_repr = self._span_end_encoder(self._dropout(span_start_repr))
    span_end_logits = torch.bmm(w, span_end_repr.transpose(1,2)).squeeze(1)
    span_end_probs = util.masked_softmax(span_end_logits, context_masks)
    span_end_logits = util.replace_masked_values(span_end_logits, context_masks.to(dtype=torch.int8), -1e7)

    best_span = self.get_best_span(span_start_logits, span_end_logits)
    best_span = best_span.view(batch_size, -1)

    spans_loss = 0.0
    if spans_start is not None:
      spans_loss = self._cross_entropy(span_start_logits, spans_start)
      self._accuracy.span_start_acc(ds_name, span_start_logits, spans_start, spans_start != -1)
      spans_loss += self._cross_entropy(span_end_logits, spans_end)
      self._accuracy.span_end_acc(ds_name, span_end_logits, spans_end, spans_end != -1)
    loss += spans_loss

    return loss, (span_label_prediction, best_span)
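get_best_span is not shown in this snippet; the following hedged sketch illustrates one way such a constrained argmax can be done, scoring every (start, end) pair and forbidding ends that precede starts (toy sizes, not necessarily the model's actual helper):

import torch

span_start_logits = torch.randn(2, 10)    # (batch, passage_length)
span_end_logits = torch.randn(2, 10)

pair_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)   # (batch, 10, 10)
valid = torch.triu(torch.ones(10, 10, dtype=torch.bool))                      # keep end >= start
pair_scores = pair_scores.masked_fill(~valid, float("-inf"))

best_flat = pair_scores.view(2, -1).argmax(dim=-1)
best_span = torch.stack([best_flat // 10, best_flat % 10], dim=-1)            # (batch, 2)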
Esempio n. 35
0
    def forward(self,
                tokens: torch.Tensor,
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:

        mask = get_text_field_mask(tokens)

        embedding = self._word_embedding(tokens)
        hidden = self._encoder(embedding, mask)
        attention = self._attention(hidden, mask)
        logit = self._output(weighted_sum(hidden, attention))

        output_dict: Dict[str, torch.Tensor] = {
            "logit": logit,
            "attention": attention
        }

        if label is not None:
            output_dict["loss"] = self._loss(logit, label)

            for eval_metric in self._metrics.values():
                eval_metric(logit, label)

        return output_dict
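The mask in this forward comes from get_text_field_mask; a small hedged sketch (toy token indices, allennlp 0.x-style API as used throughout these snippets) shows what it produces:

import torch
from allennlp.nn.util import get_text_field_mask

tokens = {"tokens": torch.tensor([[4, 9, 2, 0, 0],
                                  [7, 3, 5, 6, 1]])}
mask = get_text_field_mask(tokens)
# index 0 is padding, so the mask is [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]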
Esempio n. 36
0
    def _decode_step_output(
            self,
            decoder_hidden_state: torch.LongTensor = None,
            encoder_outputs: torch.LongTensor = None,
            encoder_outputs_mask: torch.LongTensor = None) -> torch.LongTensor:
        encoder_outputs_mask = encoder_outputs_mask.float()

        #decribe("encoder_outputs_mask", encoder_outputs_mask)

        input_weights_e = self._decoder_attention(decoder_hidden_state,
                                                  encoder_outputs,
                                                  encoder_outputs_mask)

        #decribe("input_weights_e", input_weights_e)

        input_weights_a = masked_softmax(input_weights_e, encoder_outputs_mask)

        #decribe("input_weights_a", input_weights_a)
        attended_input = weighted_sum(encoder_outputs, input_weights_a)

        #decribe("attended_input", attended_input)
        return input_weights_a, torch.cat(
            (decoder_hidden_state, attended_input), -1)
Esempio n. 37
0
similarity_function = LinearSimilarity(
      combination = "x,y,x*y",
      tensor_1_dim =  200,
      tensor_2_dim = 200)

matrix_attention = LegacyMatrixAttention(similarity_function)

passage_question_similarity = matrix_attention(encoded_passage, encoded_question)
# Shape: (batch_size, passage_length, question_length)
print ("passage question similarity: ", passage_question_similarity.shape)


# Shape: (batch_size, passage_length, question_length)
passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
# Shape: (batch_size, passage_length, encoding_dim)
passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

# We replace masked values with something really negative here, so they don't affect the
# max below.
masked_similarity = util.replace_masked_values(passage_question_similarity,
                                               question_mask.unsqueeze(1),
                                               -1e7)
# Shape: (batch_size, passage_length)
question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
# Shape: (batch_size, passage_length)
question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
# Shape: (batch_size, encoding_dim)
question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
# Shape: (batch_size, passage_length, encoding_dim)
tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                            passage_length,
                                                                            encoding_dim)
Esempio n. 38
0
    def _get_initial_rnn_and_grammar_state(self,
                                           question: Dict[str, torch.LongTensor],
                                           table: Dict[str, torch.LongTensor],
                                           world: List[WikiTablesWorld],
                                           actions: List[List[ProductionRule]],
                                           outputs: Dict[str, Any]) -> Tuple[List[RnnStatelet],
                                                                             List[LambdaGrammarStatelet]]:
        """
        Encodes the question and table, computes a linking between the two, and constructs an
        initial RnnStatelet and LambdaGrammarStatelet for each batch instance to pass to the
        decoder.

        We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
        visualize in a demo.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)


        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            entity_types[i])
                                 for i in range(batch_size)]
        if not self.training:
            # We add a few things to the outputs that will be returned from `forward` at evaluation
            # time, for visualization in a demo.
            outputs['linking_scores'] = linking_scores
            if feature_scores is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
        return initial_rnn_state, initial_grammar_state
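A toy illustration (made-up sizes) of the neighbor-index handling used above: padding entries of -1 are made valid with torch.abs before util.batched_index_select, and adding 1 turns the padding into 0 so it can serve as a mask value.

import torch
from allennlp.nn import util

encoded_table = torch.randn(1, 4, 5)          # (batch, num_entities, dim)
neighbor_indices = torch.tensor([[[1, 2], [0, -1], [3, -1], [-1, -1]]])

embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))
assert embedded_neighbors.shape == (1, 4, 2, 5)   # (batch, num_entities, num_neighbors, dim)

neighbor_mask = (neighbor_indices + 1) > 0        # False marks the padded neighbor slots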
Esempio n. 39
0
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits,
                       "label_probs": label_probs,
                       "h2p_attention": h2p_attention,
                       "p2h_attention": p2h_attention}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["premise_tokens"] = [x["premise_tokens"] for x in metadata]
            output_dict["hypothesis_tokens"] = [x["hypothesis_tokens"] for x in metadata]

        return output_dict
Esempio n. 40
0
    def _get_initial_state_and_scores(self,
                                      question: Dict[str, torch.LongTensor],
                                      table: Dict[str, torch.LongTensor],
                                      world: List[WikiTablesWorld],
                                      actions: List[List[ProductionRuleArray]],
                                      example_lisp_string: List[str] = None,
                                      add_world_to_initial_state: bool = False,
                                      checklist_states: List[ChecklistState] = None) -> Dict:
        """
        Does initial preparation and creates an initial state for both the semantic parsers. Note
        that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
        pass it.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.nn.functional.tanh(entity_type_embeddings + projected_neighbor_embeddings)


        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        initial_score = embedded_question.data.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)
        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnState(final_encoder_output[i],
                                              memory_cell[i],
                                              self._first_action_embedding,
                                              self._first_attended_question,
                                              encoder_output_list,
                                              question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               output_action_embeddings=output_action_embeddings,
                                               action_biases=action_biases,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               world=initial_state_world,
                                               example_lisp_string=example_lisp_string,
                                               checklist_state=checklist_states,
                                               debug_info=None)
        return {"initial_state": initial_state,
                "linking_scores": linking_scores,
                "feature_scores": feature_scores,
                "similarity_scores": question_entity_similarity_max_score}
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape (batch_size, sequence_length, 1)
        global_attention_logits = self._global_attention(sequence_tensor)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = util.batched_index_select(global_attention_logits,
                                                          span_indices,
                                                          flat_span_indices).squeeze(-1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

        if span_indices_mask is not None:
            # Above we were masking the widths of spans with respect to the max
            # span width in the batch. Here we are masking the spans which were
            # originally passed in as padding.
            return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float()

        return attended_text_embeddings
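A toy illustration (assumed values) of the span-mask construction above: a range vector is broadcast against the inclusive span widths, giving one 0/1 row per span.

import torch

span_starts = torch.tensor([[0, 3]])                     # (batch, num_spans)
span_ends = torch.tensor([[2, 3]])                       # inclusive end indices
span_widths = (span_ends - span_starts).unsqueeze(-1)    # (batch, num_spans, 1)
max_batch_span_width = span_widths.max().item() + 1      # 3

max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)
span_mask = (max_span_range_indices <= span_widths).float()
# span 0 covers 3 tokens -> [1., 1., 1.]; span 1 covers 1 token -> [1., 0., 0.]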
Esempio n. 42
0
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
                [encoded_premise, attended_hypothesis,
                 encoded_premise - attended_hypothesis,
                 encoded_premise * attended_hypothesis],
                dim=-1
        )
        hypothesis_enhanced = torch.cat(
                [encoded_hypothesis, attended_premise,
                 encoded_hypothesis - attended_premise,
                 encoded_hypothesis * attended_premise],
                dim=-1
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max, _ = replace_masked_values(
                v_ai, premise_mask.unsqueeze(-1), -1e7
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
                v_bi, hypothesis_mask.unsqueeze(-1), -1e7
        ).max(dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
                premise_mask, 1, keepdim=True
        )
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
                hypothesis_mask, 1, keepdim=True
        )

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
Esempio n. 43
0
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``.
        label : torch.LongTensor, optional (default = None)
            A variable representing the label for each instance in the batch.
        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_classes)`` representing a
            distribution over the label classes for each instance.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        text_mask = util.get_text_field_mask(tokens).float()
        # Pop the "elmo" tokens, since the text field embedder should not embed them directly.
        elmo_tokens = tokens.pop("elmo", None)
        embedded_text = self._text_field_embedder(tokens)

        # Add the "elmo" key back to "tokens" if it was present, since the tests and
        # subsequent training epochs rely on "tokens" not being modified during forward().
        if elmo_tokens is not None:
            tokens["elmo"] = elmo_tokens

        # Create ELMo embeddings if applicable
        if self._elmo:
            if elmo_tokens is not None:
                elmo_representations = self._elmo(elmo_tokens)["elmo_representations"]
                # Popping from the end of the list is more performant
                if self._use_integrator_output_elmo:
                    integrator_output_elmo = elmo_representations.pop()
                if self._use_input_elmo:
                    input_elmo = elmo_representations.pop()
                assert not elmo_representations
            else:
                raise ConfigurationError(
                        "Model was built to use Elmo, but input text is not tokenized for Elmo.")

        if self._use_input_elmo:
            embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)

        dropped_embedded_text = self._embedding_dropout(embedded_text)
        pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
        encoded_tokens = self._encoder(pre_encoded_text, text_mask)

        # Compute biattention. This is a special case since the inputs are the same.
        attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
        attention_weights = util.last_dim_softmax(attention_logits, text_mask)
        encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

        # Build the input to the integrator
        integrator_input = torch.cat([encoded_tokens,
                                      encoded_tokens - encoded_text,
                                      encoded_tokens * encoded_text], 2)
        integrated_encodings = self._integrator(integrator_input, text_mask)

        # Concatenate ELMo representations to integrated_encodings if specified
        if self._use_integrator_output_elmo:
            integrated_encodings = torch.cat([integrated_encodings,
                                              integrator_output_elmo], dim=-1)

        # Simple Pooling layers
        max_masked_integrated_encodings = util.replace_masked_values(
                integrated_encodings, text_mask.unsqueeze(2), -1e7)
        max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
        min_masked_integrated_encodings = util.replace_masked_values(
                integrated_encodings, text_mask.unsqueeze(2), +1e7)
        min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
        mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(text_mask, 1, keepdim=True)

        # Self-attentive pooling layer
        # Run through linear projection. Shape: (batch_size, sequence length, 1)
        # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
        self_attentive_logits = self._self_attentive_pooling_projection(
                integrated_encodings).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits, text_mask)
        self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)

        pooled_representations = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
        pooled_representations_dropped = self._integrator_dropout(pooled_representations)

        logits = self._output_layer(pooled_representations_dropped)
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {'logits': logits, 'class_probabilities': class_probabilities}
        if label is not None:
            loss = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            output_dict["loss"] = loss

        return output_dict
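The pooling layers above follow a common masking pattern; here is a short sketch (toy tensors, plain torch masked_fill standing in for util.replace_masked_values): padded timesteps are pushed to -1e7/+1e7 before max/min pooling, and mean pooling divides a masked sum by the number of real timesteps.

import torch

encodings = torch.randn(2, 5, 4)                          # (batch, timesteps, dim)
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 1., 1., 1.]])

pad = (mask == 0).unsqueeze(-1)                           # (batch, timesteps, 1)
max_pool = encodings.masked_fill(pad, -1e7).max(dim=1)[0]
min_pool = encodings.masked_fill(pad, +1e7).min(dim=1)[0]
mean_pool = (encodings * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)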
Esempio n. 44
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
Esempio n. 45
0
    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor,
                mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Parameters
        ----------
        inputs : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, timesteps, input_dim)
        mask : ``torch.FloatTensor``, optional (default = None).
            A tensor of shape (batch_size, timesteps).

        Returns
        -------
        A tensor of shape (batch_size, timesteps, output_projection_dim),
        where output_projection_dim = input_dim by default.
        """
        num_heads = self._num_heads

        batch_size, timesteps, _ = inputs.size()
        if mask is None:
            mask = inputs.new_ones(batch_size, timesteps)

        # Shape (batch_size, timesteps, 2 * attention_dim + values_dim)
        combined_projection = self._combined_projection(inputs)
        # split by attention dim - if values_dim > attention_dim, we will get more
        # than 3 elements returned. All of the rest are the values vector, so we
        # just concatenate them back together again below.
        queries, keys, *values = combined_projection.split(self._attention_dim, -1)
        queries = queries.contiguous()
        keys = keys.contiguous()
        values = torch.cat(values, -1).contiguous()
        # Shape (num_heads * batch_size, timesteps, values_dim / num_heads)
        values_per_head = values.view(batch_size, timesteps, num_heads, int(self._values_dim/num_heads))
        values_per_head = values_per_head.transpose(1, 2).contiguous()
        values_per_head = values_per_head.view(batch_size * num_heads, timesteps, int(self._values_dim/num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        queries_per_head = queries.view(batch_size, timesteps, num_heads, int(self._attention_dim/num_heads))
        queries_per_head = queries_per_head.transpose(1, 2).contiguous()
        queries_per_head = queries_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim/num_heads))

        # Shape (num_heads * batch_size, timesteps, attention_dim / num_heads)
        keys_per_head = keys.view(batch_size, timesteps, num_heads, int(self._attention_dim/num_heads))
        keys_per_head = keys_per_head.transpose(1, 2).contiguous()
        keys_per_head = keys_per_head.view(batch_size * num_heads, timesteps, int(self._attention_dim/num_heads))

        # shape (num_heads * batch_size, timesteps, timesteps)
        scaled_similarities = torch.bmm(queries_per_head, keys_per_head.transpose(1, 2)) / self._scale

        # shape (num_heads * batch_size, timesteps, timesteps)
        # Normalise the distributions, using the same mask for all heads.
        attention = last_dim_softmax(scaled_similarities, mask.repeat(1, num_heads).view(batch_size * num_heads, timesteps))
        attention = self._attention_dropout(attention)

        # Take a weighted sum of the values with respect to the attention
        # distributions for each element in the num_heads * batch_size dimension.
        # shape (num_heads * batch_size, timesteps, values_dim/num_heads)
        outputs = weighted_sum(values_per_head, attention)

        # Reshape back to original shape (batch_size, timesteps, values_dim)
        # shape (batch_size, num_heads, timesteps, values_dim/num_heads)
        outputs = outputs.view(batch_size, num_heads, timesteps, self._values_dim // num_heads)
        # shape (batch_size, timesteps, num_heads, values_dim/num_heads)
        outputs = outputs.transpose(1, 2).contiguous()
        # shape (batch_size, timesteps, values_dim)
        outputs = outputs.view(batch_size, timesteps, self._values_dim)

        # Project back to original input size.
        # shape (batch_size, timesteps, input_size)
        outputs = self._output_projection(outputs)
        return outputs
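To make the per-head reshaping above easier to follow, here is a minimal, self-contained sketch of the same pattern with hypothetical dimensions (plain softmax, no masking, dropout, or output projection):

import torch

batch_size, timesteps, num_heads, head_dim = 2, 7, 4, 16
attention_dim = num_heads * head_dim
queries = torch.randn(batch_size, timesteps, attention_dim)
keys = torch.randn(batch_size, timesteps, attention_dim)
values = torch.randn(batch_size, timesteps, attention_dim)

def split_heads(x):
    # (batch, timesteps, dim) -> (batch * num_heads, timesteps, dim // num_heads)
    x = x.view(batch_size, timesteps, num_heads, head_dim)
    return x.transpose(1, 2).contiguous().view(batch_size * num_heads, timesteps, head_dim)

q, k, v = (split_heads(t) for t in (queries, keys, values))
scores = torch.bmm(q, k.transpose(1, 2)) / head_dim ** 0.5
attention = torch.softmax(scores, dim=-1)
per_head_outputs = torch.bmm(attention, v)
# Merge the heads back into a (batch, timesteps, attention_dim) tensor.
outputs = per_head_outputs.view(batch_size, num_heads, timesteps, head_dim)
outputs = outputs.transpose(1, 2).contiguous().view(batch_size, timesteps, attention_dim)
assert outputs.shape == (batch_size, timesteps, attention_dim)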
Example n. 46
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the tokens that belong to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, see dataset_readers/util/make_reading_comprehension_instance_quac.
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two dialog turns back in the passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three dialog turns back in the passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (yes / no / not a yes-no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three-way classification (follow up / maybe follow up / don't follow up).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following entries.
        Each entry is a nested list: the outer list iterates over the dialogs in the batch, the inner
        list over the questions in each dialog.

        qid : List[List[str]]
            A list of lists of question ids.
        followup : List[List[int]]
            A list of lists of continuation-marker prediction indices
            (y: yes, m: maybe follow up, n: don't follow up).
        yesno : List[List[int]]
            A list of lists of affirmation-marker prediction indices
            (y: yes, x: not a yes/no question, n: no).
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
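        # The merge concatenates the passage encoding, its question-aware summary, and their
        # elementwise interaction terms (a BiDAF-style merged passage representation).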
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
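        # Zero out the diagonal so that a passage position cannot attend to itself in the
        # self-attention term below.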
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # add a select for the right span to compute loss
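            # span_yesno_logits has shape (total_qa_count, passage_length, 3); after view(-1),
            # the three logits for QA pair i at passage position j start at flat index
            # i * passage_length * 3 + j * 3, which is what the offsets below compute.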
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and prepare the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 against each leave-one-out set of N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict
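The multi-reference scoring above follows a leave-one-out scheme: each human answer is held out in turn and the prediction is scored against the remaining references, then the scores are averaged. A minimal sketch of that averaging with a stand-in token-overlap F1 (the real code uses squad_eval.metric_max_over_ground_truths and squad_eval.f1_score):

def token_f1(prediction, reference):
    # Simplified bag-of-tokens F1, only for illustration.
    pred_tokens, ref_tokens = prediction.split(), reference.split()
    common = sum(min(pred_tokens.count(t), ref_tokens.count(t)) for t in set(pred_tokens))
    if common == 0:
        return 0.0
    precision = common / len(pred_tokens)
    recall = common / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

def leave_one_out_f1(prediction, answer_texts):
    scores = []
    for held_out in range(len(answer_texts)):
        refs = [ref for idx, ref in enumerate(answer_texts) if idx != held_out]
        # Take the best match among the remaining references, as metric_max_over_ground_truths does.
        scores.append(max(token_f1(prediction, ref) for ref in refs))
    return sum(scores) / len(scores)

print(leave_one_out_f1("the red cat", ["the red cat", "a red cat", "red cat"]))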
Example n. 47
    def forward(self, s1, s2):
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        s1 : Dict[str, torch.LongTensor]
            From a ``TextField``.
        s2 : Dict[str, torch.LongTensor]
            From a ``TextField``.  The second sequence of the pair; it is aligned against s1 with
            BiDAF-style attention, and the two resulting encodings are pooled into the pair
            representation returned by this method.

        Returns
        -------
        pair_rep : torch.FloatTensor
            Tensor representing the final output of the BiDAF model
            to be plugged into the next module

        """
        s1_embs = self._highway_layer(self._text_field_embedder(s1))
        s2_embs = self._highway_layer(self._text_field_embedder(s2))
        if self._elmo is not None:
            s1_elmo_embs = self._elmo(s1['elmo'])
            s2_elmo_embs = self._elmo(s2['elmo'])
            if "words" in s1:
                s1_embs = torch.cat([s1_embs, s1_elmo_embs['elmo_representations'][0]], dim=-1)
                s2_embs = torch.cat([s2_embs, s2_elmo_embs['elmo_representations'][0]], dim=-1)
            else:
                s1_embs = s1_elmo_embs['elmo_representations'][0]
                s2_embs = s2_elmo_embs['elmo_representations'][0]
        if self._cove is not None:
            s1_lens = torch.ne(s1['words'], self.pad_idx).long().sum(dim=-1).data
            s2_lens = torch.ne(s2['words'], self.pad_idx).long().sum(dim=-1).data
            s1_cove_embs = self._cove(s1['words'], s1_lens)
            s1_embs = torch.cat([s1_embs, s1_cove_embs], dim=-1)
            s2_cove_embs = self._cove(s2['words'], s2_lens)
            s2_embs = torch.cat([s2_embs, s2_cove_embs], dim=-1)
        s1_embs = self._dropout(s1_embs)
        s2_embs = self._dropout(s2_embs)

        if self._mask_lstms:
            s1_mask = s1_lstm_mask = util.get_text_field_mask(s1).float()
            s2_mask = s2_lstm_mask = util.get_text_field_mask(s2).float()
            s1_mask_2 = util.get_text_field_mask(s1).float()
            s2_mask_2 = util.get_text_field_mask(s2).float()
        else:
            # All mask variants used below must also be defined on the unmasked path.
            s1_lstm_mask, s2_lstm_mask = None, None
            s1_mask, s2_mask, s1_mask_2, s2_mask_2 = None, None, None, None

        s1_enc = self._phrase_layer(s1_embs, s1_lstm_mask)
        s2_enc = self._phrase_layer(s2_embs, s2_lstm_mask)

        # Similarity matrix
        # Shape: (batch_size, s2_length, s1_length)
        similarity_mat = self._matrix_attention(s2_enc, s1_enc)

        # s2 representation
        # Shape: (batch_size, s2_length, s1_length)
        s2_s1_attention = util.last_dim_softmax(similarity_mat, s1_mask)
        # Shape: (batch_size, s2_length, encoding_dim)
        s2_s1_vectors = util.weighted_sum(s1_enc, s2_s1_attention)
        # Shape: (batch_size, s2_length, encoding_dim * 2)
        s2_w_context = torch.cat([s2_enc, s2_s1_vectors], 2)
        # s1 representation, using same attn method as for the s2 representation
        s1_s2_attention = util.last_dim_softmax(similarity_mat.transpose(1, 2).contiguous(), s2_mask)
        # Shape: (batch_size, s1_length, encoding_dim)
        s1_s2_vectors = util.weighted_sum(s2_enc, s1_s2_attention)
        s1_w_context = torch.cat([s1_enc, s1_s2_vectors], 2)
        if self._elmo is not None and self._deep_elmo:
            s1_w_context = torch.cat([s1_w_context, s1_elmo_embs['elmo_representations'][1]], dim=-1)
            s2_w_context = torch.cat([s2_w_context, s2_elmo_embs['elmo_representations'][1]], dim=-1)
        s1_w_context = self._dropout(s1_w_context)
        s2_w_context = self._dropout(s2_w_context)

        modeled_s2 = self._dropout(self._modeling_layer(s2_w_context, s2_lstm_mask))
        s2_mask_2 = s2_mask_2.unsqueeze(dim=-1)
        modeled_s2.data.masked_fill_(1 - s2_mask_2.byte().data, -float('inf'))
        s2_enc_attn = modeled_s2.max(dim=1)[0]
        modeled_s1 = self._dropout(self._modeling_layer(s1_w_context, s1_lstm_mask))
        s1_mask_2 = s1_mask_2.unsqueeze(dim=-1)
        modeled_s1.data.masked_fill_(1 - s1_mask_2.byte().data, -float('inf'))
        s1_enc_attn = modeled_s1.max(dim=1)[0]

        return torch.cat([s1_enc_attn, s2_enc_attn, torch.abs(s1_enc_attn - s2_enc_attn),
                          s1_enc_attn * s2_enc_attn], 1)
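The returned tensor is the common symmetric pair-feature combination [u, v, |u - v|, u * v] over the two max-pooled sentence encodings. A minimal sketch with hypothetical dimensions:

import torch

batch_size, enc_dim = 3, 128
u = torch.randn(batch_size, enc_dim)  # pooled representation of s1 (hypothetical)
v = torch.randn(batch_size, enc_dim)  # pooled representation of s2 (hypothetical)
pair_rep = torch.cat([u, v, torch.abs(u - v), u * v], dim=1)
assert pair_rep.shape == (batch_size, 4 * enc_dim)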