Example no. 1
    def forward(self, s1, s2):
        # pylint: disable=arguments-differ
        """ """
        # Embeddings
        s1_embs = self._highway_layer(self._text_field_embedder(s1))
        s2_embs = self._highway_layer(self._text_field_embedder(s2))
        if self._elmo is not None:
            s1_elmo_embs = self._elmo(s1['elmo'])
            s2_elmo_embs = self._elmo(s2['elmo'])
            if "words" in s1:
                s1_embs = torch.cat([s1_embs, s1_elmo_embs['elmo_representations'][0]], dim=-1)
                s2_embs = torch.cat([s2_embs, s2_elmo_embs['elmo_representations'][0]], dim=-1)
            else:
                s1_embs = s1_elmo_embs['elmo_representations'][0]
                s2_embs = s2_elmo_embs['elmo_representations'][0]

        if self._cove is not None:
            s1_lens = torch.ne(s1['words'], self.pad_idx).long().sum(dim=-1).data
            s2_lens = torch.ne(s2['words'], self.pad_idx).long().sum(dim=-1).data
            s1_cove_embs = self._cove(s1['words'], s1_lens)
            s1_embs = torch.cat([s1_embs, s1_cove_embs], dim=-1)
            s2_cove_embs = self._cove(s2['words'], s2_lens)
            s2_embs = torch.cat([s2_embs, s2_cove_embs], dim=-1)
        s1_embs = self._dropout(s1_embs)
        s2_embs = self._dropout(s2_embs)

        # Set up masks
        s1_mask = util.get_text_field_mask(s1)
        s2_mask = util.get_text_field_mask(s2)
        s1_lstm_mask = s1_mask.float() if self._mask_lstms else None
        s2_lstm_mask = s2_mask.float() if self._mask_lstms else None

        # Sentence encodings with LSTMs
        s1_enc = self._phrase_layer(s1_embs, s1_lstm_mask)
        s2_enc = self._phrase_layer(s2_embs, s2_lstm_mask)
        if self._elmo is not None and len(s1_elmo_embs['elmo_representations']) > 1:
            s1_enc = torch.cat([s1_enc, s1_elmo_embs['elmo_representations'][1]], dim=-1)
            s2_enc = torch.cat([s2_enc, s2_elmo_embs['elmo_representations'][1]], dim=-1)
        s1_enc = self._dropout(s1_enc)
        s2_enc = self._dropout(s2_enc)

        # Max pooling
        s1_mask = s1_mask.unsqueeze(dim=-1)
        s2_mask = s2_mask.unsqueeze(dim=-1)
        s1_enc.data.masked_fill_(1 - s1_mask.byte().data, -float('inf'))
        s2_enc.data.masked_fill_(1 - s2_mask.byte().data, -float('inf'))
        s1_enc, _ = s1_enc.max(dim=1)
        s2_enc, _ = s2_enc.max(dim=1)

        return torch.cat([s1_enc, s2_enc, torch.abs(s1_enc - s2_enc), s1_enc * s2_enc], 1)
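
A standalone sketch of the masked max-pooling trick used above (toy names and sizes, written with the non-in-place modern PyTorch idiom rather than the original `.data.masked_fill_` form): padded positions are filled with -inf so they can never win the max over the time dimension.

import torch

enc = torch.randn(2, 4, 3)                     # (batch, seq_len, dim)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])            # 1 = real token, 0 = padding
enc = enc.masked_fill(~mask.bool().unsqueeze(-1), -float('inf'))
pooled, _ = enc.max(dim=1)                     # (batch, dim), padding ignored
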
Example no. 2
 def test_get_text_field_mask_returns_mask_key(self):
     text_field_tensors = {
             "tokens": torch.LongTensor([[3, 4, 5, 0, 0], [1, 2, 0, 0, 0]]),
             "mask": torch.LongTensor([[0, 0, 1]])
             }
     assert_almost_equal(util.get_text_field_mask(text_field_tensors).numpy(),
                         [[0, 0, 1]])
Example no. 3
    def forward(self, sent):
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sent : Dict[str, torch.LongTensor]
            From a ``TextField``.

        Returns
        -------
        torch.FloatTensor
            The max-pooled sentence encoding, of shape ``(batch_size, encoding_dim)``.
        """
        sent_embs = self._highway_layer(self._text_field_embedder(sent))
        if self._cove is not None:
            sent_lens = torch.ne(sent['words'], self.pad_idx).long().sum(dim=-1).data
            sent_cove_embs = self._cove(sent['words'], sent_lens)
            sent_embs = torch.cat([sent_embs, sent_cove_embs], dim=-1)
        if self._elmo is not None:
            elmo_embs = self._elmo(sent['elmo'])
            if "words" in sent:
                sent_embs = torch.cat([sent_embs, elmo_embs['elmo_representations'][0]], dim=-1)
            else:
                sent_embs = elmo_embs['elmo_representations'][0]
        sent_embs = self._dropout(sent_embs)

        sent_mask = util.get_text_field_mask(sent).float()
        sent_lstm_mask = sent_mask if self._mask_lstms else None

        sent_enc = self._phrase_layer(sent_embs, sent_lstm_mask)
        if self._elmo is not None and len(elmo_embs['elmo_representations']) > 1:
            sent_enc = torch.cat([sent_enc, elmo_embs['elmo_representations'][1]], dim=-1)
        sent_enc = self._dropout(sent_enc)

        sent_mask = sent_mask.unsqueeze(dim=-1)
        sent_enc.data.masked_fill_(1 - sent_mask.byte().data, -float('inf'))
        return sent_enc.max(dim=1)[0]
Example no. 4
 def test_get_text_field_mask_returns_a_correct_mask_character_only_input(self):
     text_field_tensors = {
             "token_characters": torch.LongTensor([[[1, 2, 3], [3, 0, 1], [2, 1, 0], [0, 0, 0]],
                                                   [[5, 5, 5], [4, 6, 0], [0, 0, 0], [0, 0, 0]]])
             }
     assert_almost_equal(util.get_text_field_mask(text_field_tensors).numpy(),
                         [[1, 1, 1, 0], [1, 1, 0, 0]])
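
Roughly, the character-only case reduces the (batch, num_tokens, num_characters) tensor to a token-level mask by checking whether any character id in a token is non-zero; a minimal sketch that reproduces the expected mask above (an illustration, not AllenNLP's exact implementation):

import torch

token_characters = torch.LongTensor([[[1, 2, 3], [3, 0, 1], [2, 1, 0], [0, 0, 0]],
                                     [[5, 5, 5], [4, 6, 0], [0, 0, 0], [0, 0, 0]]])
# A token counts as real if it has at least one non-padding character id.
mask = (token_characters > 0).any(dim=-1).long()
# tensor([[1, 1, 1, 0],
#         [1, 1, 0, 0]])
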
    def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]):
        embedded_input = self._sentence_embedder(sentence)
        # (batch_size, sentence_length)
        sentence_mask = util.get_text_field_mask(sentence).float()

        batch_size = embedded_input.size(0)

        # (batch_size, sentence_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_input, sentence_mask))

        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             sentence_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        attended_sentence, _ = self._decoder_step.attend_on_question(final_encoder_output,
                                                                     encoder_outputs, sentence_mask)
        encoder_outputs_list = [encoder_outputs[i] for i in range(batch_size)]
        sentence_mask_list = [sentence_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 attended_sentence[i],
                                                 encoder_outputs_list,
                                                 sentence_mask_list))
        return initial_rnn_state
Example no. 6
 def test_get_text_field_mask_returns_a_correct_mask_list_field(self):
     text_field_tensors = {
             "list_tokens": torch.LongTensor([[[1, 2], [3, 0], [2, 0], [0, 0], [0, 0]],
                                              [[5, 0], [4, 6], [0, 0], [0, 0], [0, 0]]])
             }
     actual_mask = util.get_text_field_mask(text_field_tensors, num_wrapping_dims=1).numpy()
     expected_mask = (text_field_tensors['list_tokens'].numpy() > 0).astype('int32')
     assert_almost_equal(actual_mask, expected_mask)
Example no. 7
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.text_field_embedder(tokens)
        batch_size, sequence_length, _ = embedded_text_input.size()
        mask = get_text_field_mask(tokens)
        encoded_text = self.encoder(embedded_text_input, mask)

        logits = self.tag_projection_layer(encoded_text)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view([batch_size,
                                                                          sequence_length,
                                                                          self.num_classes])

        output_dict = {"logits": logits, "class_probabilities": class_probabilities}

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(logits, tags, mask)
            for metric in self.metrics.values():
                metric(logits, tags, mask.float())
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]
        return output_dict
Example no. 8
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : ``torch.LongTensor``, optional (default = ``None``)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.

        Returns
        -------
        An output dictionary consisting of:

        logits : ``torch.FloatTensor``
            The logits that are the output of the ``tag_projection_layer``
        mask : ``torch.LongTensor``
            The text field mask for the input tokens
        tags : ``List[List[str]]``
            The predicted tags using the Viterbi algorithm.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        mask = util.get_text_field_mask(tokens)
        encoded_text = self.encoder(embedded_text_input, mask)

        logits = self.tag_projection_layer(encoded_text)
        predicted_tags = self.crf.viterbi_tags(logits, mask)

        output = {"logits": logits, "mask": mask, "tags": predicted_tags}

        if tags is not None:
            # Add negative log-likelihood as loss
            log_likelihood = self.crf(logits, tags, mask)
            output["loss"] = -log_likelihood

            # Represent viterbi tags as "class probabilities" that we can
            # feed into the `span_metric`
            class_probabilities = logits * 0.
            for i, instance_tags in enumerate(predicted_tags):
                for j, tag_id in enumerate(instance_tags):
                    class_probabilities[i, j, tag_id] = 1

            self.span_metric(class_probabilities, tags, mask)

        return output
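
An equivalent, vectorized way to build the one-hot "class probabilities" from the Viterbi tags, assuming the tags have already been padded into a (batch, seq_len) tensor (a sketch, not the original code, which loops over a list of per-instance tag lists):

import torch

logits = torch.randn(2, 4, 5)                          # (batch, seq_len, num_tags)
padded_tags = torch.LongTensor([[1, 3, 0, 2],
                                [4, 4, 1, 0]])          # (batch, seq_len)
class_probabilities = torch.zeros_like(logits)
class_probabilities.scatter_(-1, padded_tags.unsqueeze(-1), 1.0)
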
Example no. 9
    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> torch.Tensor:
        mask = get_text_field_mask(sentence)
        embeddings = self.word_embeddings(sentence)
        encoder_out = self.encoder(embeddings, mask)
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output
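
A minimal sketch of what a masked sequence loss like the call above has to do (an illustration only, not AllenNLP's sequence_cross_entropy_with_logits): padded positions contribute nothing to the loss.

import torch
import torch.nn.functional as F

def masked_sequence_cross_entropy(logits, targets, mask):
    # logits: (batch, seq_len, num_classes); targets and mask: (batch, seq_len)
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    nll = nll * mask.float()
    # Average over real tokens per sequence, then over the batch.
    per_example = nll.sum(dim=-1) / mask.float().sum(dim=-1).clamp(min=1.0)
    return per_example.mean()
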
Example no. 10
 def _update_recall(self,
                    all_top_k_predictions: torch.Tensor,
                    target_tokens: Dict[str, torch.LongTensor],
                    target_recall: UnigramRecall) -> None:
     targets = target_tokens["tokens"]
     target_mask = get_text_field_mask(target_tokens)
     # See comment in _get_loss.
     # TODO(brendanr): Do we need contiguous here?
     relevant_targets = targets[:, 1:].contiguous()
     relevant_mask = target_mask[:, 1:].contiguous()
     target_recall(
             all_top_k_predictions,
             relevant_targets,
             relevant_mask,
             self._end_index
     )
Example no. 11
 def get_sample_encoded_output(self):
     """
     Returns the encoded vector for a sample event.
     """
     for instance in self.dataset:
         cur_text_field = instance.fields["source"]
         text = [token.text for token in cur_text_field.tokens]
         if text == ["@start@", "personx", "calls", "personx", "'s", "brother", "@end@"]:
             sample_text_field = cur_text_field
             break
     source = sample_text_field.as_tensor(sample_text_field.get_padding_lengths())
     source['tokens'] = source['tokens'].unsqueeze(0)
     embedded_input = self.trained_model._embedding_dropout(
             self.trained_model._source_embedder(source)
     )
     source_mask = get_text_field_mask(source)
     return self.trained_model._encoder(embedded_input, source_mask)
    def forward(self,  # type: ignore
                inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        inputs: ``torch.Tensor``
            Shape ``(batch_size, timesteps, ...)`` of token ids representing the current batch.
            These must have been produced using the same indexer the LM was trained on.

        Returns
        -------
        The bidirectional language model representations for the input sequence, shape
        ``(batch_size, timesteps, embedding_dim)``
        """
        # pylint: disable=arguments-differ
        if self._bos_indices is not None:
            mask = get_text_field_mask({"": inputs})
            inputs, mask = add_sentence_boundary_token_ids(
                    inputs, mask, self._bos_indices, self._eos_indices
            )

        source = {self._token_name: inputs}
        result_dict = self._lm(source)

        # shape (batch_size, timesteps, embedding_size)
        noncontextual_token_embeddings = result_dict["noncontextual_token_embeddings"]
        contextual_embeddings = result_dict["lm_embeddings"]

        # Typically the non-contextual embeddings are smaller than the contextualized embeddings.
        # Since we're averaging all the layers we need to make their dimensions match. Simply
        # repeating the non-contextual embeddings is a crude, but effective, way to do this.
        duplicated_character_embeddings = torch.cat(
                [noncontextual_token_embeddings] * self._character_embedding_duplication_count, -1
        )
        averaged_embeddings = self._scalar_mix(
                [duplicated_character_embeddings] + contextual_embeddings
        )

        # Add dropout
        averaged_embeddings = self._dropout(averaged_embeddings)
        if self._remove_bos_eos:
            averaged_embeddings, _ = remove_sentence_boundaries(
                    averaged_embeddings, result_dict["mask"]
            )

        return averaged_embeddings
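
A toy illustration of the dimension-matching trick described in the comments above, with made-up sizes: the smaller non-contextual embeddings are simply repeated along the last dimension so they can be averaged with the contextual layers.

import torch

noncontextual = torch.randn(2, 7, 256)                         # (batch, timesteps, 256)
contextual_dim = 512
duplication_count = contextual_dim // noncontextual.size(-1)   # 2
duplicated = torch.cat([noncontextual] * duplication_count, dim=-1)
print(duplicated.shape)                                        # torch.Size([2, 7, 512])
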
Example no. 13
    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> torch.Tensor:
        #### AllenNLP is designed to operate on batched inputs, but different input sequences have different lengths. Behind the scenes AllenNLP is padding the shorter inputs so that the batch has uniform shape, which means our computations need to use a mask to exclude the padding. Here we just use the utility function <code>get_text_field_mask</code>, which returns a tensor of 0s and 1s corresponding to the padded and unpadded locations.
        mask = get_text_field_mask(sentence)
        #### We start by passing the <code>sentence</code> tensor (each sentence a sequence of token ids) to the <code>word_embeddings</code> module, which converts each sentence into a sequence of embedded tensors.
        embeddings = self.word_embeddings(sentence)
        #### We next pass the embedded tensors (and the mask) to the LSTM, which produces a sequence of encoded outputs.
        encoder_out = self.encoder(embeddings, mask)
        #### Finally, we pass each encoded output tensor to the feedforward layer to produce logits corresponding to the various tags.
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}

        #### As before, the labels were optional, as we might want to run this model to make predictions on unlabeled data. If we do have labels, then we use them to update our accuracy metric and compute the "loss" that goes in our output.
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output
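
As the inline comments explain, the mask simply marks real tokens versus padding; a toy illustration, assuming the same flat TextField tensor format used by the tests earlier in this collection:

import torch
from allennlp.nn.util import get_text_field_mask

# Two sentences of lengths 3 and 2, padded to length 4 with token id 0.
sentence = {"tokens": torch.LongTensor([[2, 5, 7, 0],
                                        [4, 9, 0, 0]])}
mask = get_text_field_mask(sentence)
# tensor([[1, 1, 1, 0],
#         [1, 1, 0, 0]])
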
Example no. 14
    def greedy_search(self,
                      final_encoder_output: torch.LongTensor,
                      target_tokens: Dict[str, torch.LongTensor],
                      target_embedder: Embedding,
                      decoder_cell: GRUCell,
                      output_projection_layer: Linear) -> torch.FloatTensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the cross entropy between this sequence and ``target_tokens``.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()`` applied on some target ``TextField``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._get_num_decoding_steps(target_tokens)
        targets = target_tokens["tokens"]
        decoder_hidden = final_encoder_output
        step_logits = []
        for timestep in range(num_decoding_steps):
            # See https://github.com/allenai/allennlp/issues/1134.
            input_choices = targets[:, timestep]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
        # (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        target_mask = get_text_field_mask(target_tokens)
        return self._get_loss(logits, targets, target_mask)
Example no. 15
    def forward(self, question):
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.

        Returns
        -------
        question_rep : torch.FloatTensor
            Tensor representing the mean of the embedded question tokens,
            to be plugged into the next module.

        """
        word_char_embs = self._text_field_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        return word_char_embs.mean(1)  # TODO: should average over non-padding tokens only (via question_mask)
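
The trailing comment notes that a plain .mean(1) also averages over padding; a masked mean that only counts real tokens would look roughly like this (a self-contained sketch, not the original code):

import torch

def masked_mean(embeddings, mask):
    # embeddings: (batch, num_tokens, dim); mask: (batch, num_tokens) of 0/1
    summed = (embeddings * mask.unsqueeze(-1)).sum(dim=1)
    counts = mask.sum(dim=1, keepdim=True).clamp(min=1.0)
    return summed / counts

embs = torch.randn(2, 5, 8)
mask = torch.tensor([[1., 1., 1., 0., 0.],
                     [1., 1., 0., 0., 0.]])
pooled = masked_mean(embs, mask)                 # (2, 8)
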
    def _get_initial_state_and_scores(self,
                                      question: Dict[str, torch.LongTensor],
                                      table: Dict[str, torch.LongTensor],
                                      world: List[WikiTablesWorld],
                                      actions: List[List[ProductionRuleArray]],
                                      example_lisp_string: List[str] = None,
                                      add_world_to_initial_state: bool = False,
                                      checklist_states: List[ChecklistState] = None) -> Dict:
        """
        Does initial preparation and creates an initial state for both the semantic parsers. Note
        that the checklist state is optional, and the ``WikiTablesMmlParser`` is not expected to
        pass it.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._type_params(entity_types.float())
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.nn.functional.tanh(entity_type_embeddings + projected_neighbor_embeddings)


        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        initial_score = embedded_question.data.new_zeros(batch_size)

        action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(actions)

        _, num_entities, num_question_tokens = linking_scores.size()
        flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                     world,
                                                                                     actions)
        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnState(final_encoder_output[i],
                                              memory_cell[i],
                                              self._first_action_embedding,
                                              self._first_attended_question,
                                              encoder_output_list,
                                              question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                                 for i in range(batch_size)]
        initial_state_world = world if add_world_to_initial_state else None
        initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                               action_history=[[] for _ in range(batch_size)],
                                               score=initial_score_list,
                                               rnn_state=initial_rnn_state,
                                               grammar_state=initial_grammar_state,
                                               action_embeddings=action_embeddings,
                                               output_action_embeddings=output_action_embeddings,
                                               action_biases=action_biases,
                                               action_indices=action_indices,
                                               possible_actions=actions,
                                               flattened_linking_scores=flattened_linking_scores,
                                               actions_to_entities=actions_to_entities,
                                               entity_types=entity_type_dict,
                                               world=initial_state_world,
                                               example_lisp_string=example_lisp_string,
                                               checklist_state=checklist_states,
                                               debug_info=None)
        return {"initial_state": initial_state,
                "linking_scores": linking_scores,
                "feature_scores": feature_scores,
                "similarity_scores": question_entity_similarity_max_score}
Example no. 17
    def forward(self, s1, s2):
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        s1 : Dict[str, torch.LongTensor]
            From a ``TextField``.
        s2 : Dict[str, torch.LongTensor]
            From a ``TextField``.

        Returns
        -------
        pair_rep : torch.FloatTensor
            Tensor representing the final output of this BiDAF-style pair encoder,
            to be plugged into the next module.

        """
        s1_embs = self._highway_layer(self._text_field_embedder(s1))
        s2_embs = self._highway_layer(self._text_field_embedder(s2))
        if self._elmo is not None:
            s1_elmo_embs = self._elmo(s1['elmo'])
            s2_elmo_embs = self._elmo(s2['elmo'])
            if "words" in s1:
                s1_embs = torch.cat([s1_embs, s1_elmo_embs['elmo_representations'][0]], dim=-1)
                s2_embs = torch.cat([s2_embs, s2_elmo_embs['elmo_representations'][0]], dim=-1)
            else:
                s1_embs = s1_elmo_embs['elmo_representations'][0]
                s2_embs = s2_elmo_embs['elmo_representations'][0]
        if self._cove is not None:
            s1_lens = torch.ne(s1['words'], self.pad_idx).long().sum(dim=-1).data
            s2_lens = torch.ne(s2['words'], self.pad_idx).long().sum(dim=-1).data
            s1_cove_embs = self._cove(s1['words'], s1_lens)
            s1_embs = torch.cat([s1_embs, s1_cove_embs], dim=-1)
            s2_cove_embs = self._cove(s2['words'], s2_lens)
            s2_embs = torch.cat([s2_embs, s2_cove_embs], dim=-1)
        s1_embs = self._dropout(s1_embs)
        s2_embs = self._dropout(s2_embs)

        # The attention and pooling steps below always need the masks, even when
        # the LSTMs themselves are run unmasked.
        s1_mask = util.get_text_field_mask(s1).float()
        s2_mask = util.get_text_field_mask(s2).float()
        s1_mask_2 = util.get_text_field_mask(s1).float()
        s2_mask_2 = util.get_text_field_mask(s2).float()
        s1_lstm_mask = s1_mask if self._mask_lstms else None
        s2_lstm_mask = s2_mask if self._mask_lstms else None

        s1_enc = self._phrase_layer(s1_embs, s1_lstm_mask)
        s2_enc = self._phrase_layer(s2_embs, s2_lstm_mask)

        # Similarity matrix
        # Shape: (batch_size, s2_length, s1_length)
        similarity_mat = self._matrix_attention(s2_enc, s1_enc)

        # s2 representation
        # Shape: (batch_size, s2_length, s1_length)
        s2_s1_attention = util.last_dim_softmax(similarity_mat, s1_mask)
        # Shape: (batch_size, s2_length, encoding_dim)
        s2_s1_vectors = util.weighted_sum(s1_enc, s2_s1_attention)
        # batch_size, seq_len, 4*enc_dim
        s2_w_context = torch.cat([s2_enc, s2_s1_vectors], 2)
        # s1 representation, using same attn method as for the s2 representation
        s1_s2_attention = util.last_dim_softmax(similarity_mat.transpose(1, 2).contiguous(), s2_mask)
        # Shape: (batch_size, s1_length, encoding_dim)
        s1_s2_vectors = util.weighted_sum(s2_enc, s1_s2_attention)
        s1_w_context = torch.cat([s1_enc, s1_s2_vectors], 2)
        if self._elmo is not None and self._deep_elmo:
            s1_w_context = torch.cat([s1_w_context, s1_elmo_embs['elmo_representations'][1]], dim=-1)
            s2_w_context = torch.cat([s2_w_context, s2_elmo_embs['elmo_representations'][1]], dim=-1)
        s1_w_context = self._dropout(s1_w_context)
        s2_w_context = self._dropout(s2_w_context)

        modeled_s2 = self._dropout(self._modeling_layer(s2_w_context, s2_lstm_mask))
        s2_mask_2 = s2_mask_2.unsqueeze(dim=-1)
        modeled_s2.data.masked_fill_(1 - s2_mask_2.byte().data, -float('inf'))
        s2_enc_attn = modeled_s2.max(dim=1)[0]
        modeled_s1 = self._dropout(self._modeling_layer(s1_w_context, s1_lstm_mask))
        s1_mask_2 = s1_mask_2.unsqueeze(dim=-1)
        modeled_s1.data.masked_fill_(1 - s1_mask_2.byte().data, -float('inf'))
        s1_enc_attn = modeled_s1.max(dim=1)[0]

        return torch.cat([s1_enc_attn, s2_enc_attn, torch.abs(s1_enc_attn - s2_enc_attn),
                          s1_enc_attn * s2_enc_attn], 1)
    def forward(
        self,  # type: ignore
        words: TextFieldTensors,
        pos_tags: torch.LongTensor,
        metadata: List[Dict[str, Any]],
        head_tags: torch.LongTensor = None,
        head_indices: torch.LongTensor = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        words : TextFieldTensors, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. This output is a dictionary mapping keys to `TokenIndexer`
            tensors.  At its most basic, using a `SingleIdTokenIndexer` this is : `{"tokens":
            Tensor(batch_size, sequence_length)}`. This dictionary will have the same keys as were used
            for the `TokenIndexers` when you created the `TextField` representing your
            sequence.  The dictionary is designed to be passed directly to a `TextFieldEmbedder`,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : `torch.LongTensor`, required
            The output of a `SequenceLabelField` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        metadata : List[Dict[str, Any]], optional (default=None)
            A dictionary of metadata for each batch element which has keys:
                words : `List[str]`, required.
                    The tokens in the original sentence.
                pos : `List[str]`, required.
                    The POS tags for each word in the dependency parse.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels for the arcs
            in the dependency parse. Has shape `(batch_size, sequence_length)`.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape `(batch_size, sequence_length)`.

        # Returns

        An output dictionary consisting of:
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        arc_loss : `torch.FloatTensor`
            The loss contribution from the unlabeled arcs.
        tag_loss : `torch.FloatTensor`, optional
            The loss contribution from predicting the dependency
            tags for the gold arcs.
        heads : `torch.FloatTensor`
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_types : `torch.FloatTensor`
            The predicted head types for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : `torch.BoolTensor`
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat(
                [embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError(
                "Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)

        predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll = self._parse(
            embedded_text_input, mask, head_tags, head_indices)

        loss = arc_nll + tag_nll

        if head_indices is not None and head_tags is not None:
            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(
                predicted_heads[:, 1:],
                predicted_head_tags[:, 1:],
                head_indices,
                head_tags,
                evaluation_mask,
            )

        output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            "pos": [meta["pos"] for meta in metadata],
        }

        return output_dict
    def _get_initial_state(self,
                           utterance: Dict[str, torch.LongTensor],
                           worlds: List[AtisWorld],
                           actions: List[List[ProductionRule]],
                           linking_scores: torch.Tensor) -> GrammarBasedState:
        embedded_utterance = self._utterance_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance).float()

        batch_size = embedded_utterance.size(0)
        num_entities = max([len(world.entities) for world in worlds])

        # entity_types: tensor with shape (batch_size, num_entities)
        entity_types, _ = self._get_type_vector(worlds, num_entities, embedded_utterance)

        # (batch_size, num_utterance_tokens, embedding_dim)
        encoder_input = embedded_utterance

        # (batch_size, utterance_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, utterance_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             utterance_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
        initial_score = embedded_utterance.data.new_zeros(batch_size)

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, utterance_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(utterance_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        initial_score_list = [initial_score[i] for i in range(batch_size)]
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        utterance_mask_list = [utterance_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            if self._decoder_num_layers > 1:
                initial_rnn_state.append(RnnStatelet(final_encoder_output[i].repeat(self._decoder_num_layers, 1),
                                                     memory_cell[i].repeat(self._decoder_num_layers, 1),
                                                     self._first_action_embedding,
                                                     self._first_attended_utterance,
                                                     encoder_output_list,
                                                     utterance_mask_list))
            else:
                initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                     memory_cell[i],
                                                     self._first_action_embedding,
                                                     self._first_attended_utterance,
                                                     encoder_output_list,
                                                     utterance_mask_list))


        initial_grammar_state = [self._create_grammar_state(worlds[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            entity_types[i])
                                 for i in range(batch_size)]

        initial_state = GrammarBasedState(batch_indices=list(range(batch_size)),
                                          action_history=[[] for _ in range(batch_size)],
                                          score=initial_score_list,
                                          rnn_state=initial_rnn_state,
                                          grammar_state=initial_grammar_state,
                                          possible_actions=actions,
                                          debug_info=None)
        return initial_state
Example no. 20
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))

        # # v5:
        # # remember to set token embeddings in the CONFIG JSON
        # encoded_question = self._dropout(embedded_question)

        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length) -- SIMILARITY MATRIX
        similarity_matrix = self._matrix_attention(encoded_passage,
                                                   encoded_question)

        # Shape: (batch_size, passage_length, question_length) -- CONTEXT2QUERY
        passage_question_attention = util.last_dim_softmax(
            similarity_matrix, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Our custom query2context
        q2c_attention = util.masked_softmax(similarity_matrix,
                                            question_mask,
                                            dim=1).transpose(-1, -2)
        q2c_vecs = util.weighted_sum(encoded_passage, q2c_attention)

        # Now we try the various variants
        # v1:
        # tiled_question_passage_vector = util.weighted_sum(q2c_vecs, passage_question_attention)

        # v2:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], encoded_passage.shape[1]))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).transpose(-1, -2)

        # # v3:
        # q2c_compressor = TimeDistributed(torch.nn.Linear(q2c_vecs.shape[1], 1))
        # tiled_question_passage_vector = q2c_compressor(q2c_vecs.transpose(-1, -2)).squeeze().unsqueeze(1).expand(batch_size, passage_length, encoding_dim)

        # v4:
        # Re-application of query2context attention
        # new_similarity_matrix = self._matrix_attention(encoded_passage, q2c_vecs)
        # masked_similarity = util.replace_masked_values(new_similarity_matrix,
        #                                                question_mask.unsqueeze(1),
        #                                                -1e7)
        # # Shape: (batch_size, passage_length)
        # question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # # Shape: (batch_size, passage_length)
        # question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # # Shape: (batch_size, encoding_dim)
        # question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # # Shape: (batch_size, passage_length, encoding_dim)
        # tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
        #                                                                             passage_length,
        #                                                                             encoding_dim)
        # #
        # ------- Original variant
        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            similarity_matrix, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # ------- END

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        # original beta combination function
        # final_merged_passage = torch.cat([encoded_passage,
        #                                   passage_question_vectors,
        #                                   encoded_passage * passage_question_vectors,
        #                                   encoded_passage * tiled_question_passage_vector],
        #                                  dim=-1)

        # # v6:
        # final_merged_passage = torch.cat([tiled_question_passage_vector],
        #                                  dim=-1)
        #
        # # v7:
        # final_merged_passage = torch.cat([passage_question_vectors],
        #                                  dim=-1)
        #
        # # v8:
        # final_merged_passage = torch.cat([passage_question_vectors,
        #                                   tiled_question_passage_vector],
        #                                  dim=-1)

        # # v9:
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 3 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
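Note: self.get_best_span above turns the masked start/end logits into one (start, end) prediction per instance. The snippet below is only an illustrative, self-contained sketch of such a search (a brute-force version, not the model's own helper), assuming start <= end:

import torch

def naive_best_span(span_start_logits, span_end_logits):
    # span_*_logits: (batch_size, passage_length), padding already set to -1e7
    batch_size, passage_length = span_start_logits.size()
    # Score every (start, end) pair as the sum of its two logits.
    pair_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Disallow pairs with end < start.
    valid = torch.triu(torch.ones(passage_length, passage_length,
                                  dtype=torch.bool,
                                  device=span_start_logits.device))
    pair_scores = pair_scores.masked_fill(~valid, float("-inf"))
    best_flat = pair_scores.view(batch_size, -1).argmax(dim=-1)
    best_start = torch.div(best_flat, passage_length, rounding_mode="floor")
    best_end = best_flat % passage_length
    return torch.stack([best_start, best_end], dim=-1)  # (batch_size, 2)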
Esempio n. 21
0
    def forward(
        self,  # type: ignore
        tokens: TextFieldTensors,
        label: torch.LongTensor = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : TextFieldTensors, required
            The output of `TextField.as_array()`.
        label : torch.LongTensor, optional (default = None)
            A variable representing the label for each instance in the batch.

        # Returns

        An output dictionary consisting of:
             - `class_probabilities` (`torch.FloatTensor`) :
                 A tensor of shape `(batch_size, num_classes)` representing a
                 distribution over the label classes for each instance.
             - `loss` (`torch.FloatTensor`, optional) :
                 A scalar loss to be optimised.
        """
        text_mask = util.get_text_field_mask(tokens)
        # Pop elmo tokens, since elmo embedder should not be present.
        elmo_tokens = tokens.pop("elmo", None)
        if tokens:
            embedded_text = self._text_field_embedder(tokens)
        else:
            # only using "elmo" for input
            embedded_text = None

        # Add the "elmo" key back to "tokens" if not None, since the tests and the
        # subsequent training epochs rely on "tokens" not being modified during forward()
        if elmo_tokens is not None:
            tokens["elmo"] = elmo_tokens

        # Create ELMo embeddings if applicable
        if self._elmo:
            if elmo_tokens is not None:
                elmo_representations = self._elmo(
                    elmo_tokens["elmo_tokens"])["elmo_representations"]
                # Popping from the end of a list is more performant
                if self._use_integrator_output_elmo:
                    integrator_output_elmo = elmo_representations.pop()
                if self._use_input_elmo:
                    input_elmo = elmo_representations.pop()
                assert not elmo_representations
            else:
                raise ConfigurationError(
                    "Model was built to use Elmo, but input text is not tokenized for Elmo."
                )

        if self._use_input_elmo:
            if embedded_text is not None:
                embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)
            else:
                embedded_text = input_elmo

        dropped_embedded_text = self._embedding_dropout(embedded_text)
        pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
        encoded_tokens = self._encoder(pre_encoded_text, text_mask)

        # Compute biattention. This is a special case since the inputs are the same.
        attention_logits = encoded_tokens.bmm(
            encoded_tokens.permute(0, 2, 1).contiguous())
        attention_weights = util.masked_softmax(attention_logits, text_mask)
        encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

        # Build the input to the integrator
        integrator_input = torch.cat([
            encoded_tokens, encoded_tokens - encoded_text,
            encoded_tokens * encoded_text
        ], 2)
        integrated_encodings = self._integrator(integrator_input, text_mask)

        # Concatenate ELMo representations to integrated_encodings if specified
        if self._use_integrator_output_elmo:
            integrated_encodings = torch.cat(
                [integrated_encodings, integrator_output_elmo], dim=-1)

        # Simple Pooling layers
        max_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings,
            text_mask.unsqueeze(2),
            util.min_value_of_dtype(integrated_encodings.dtype),
        )
        max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
        min_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings,
            text_mask.unsqueeze(2),
            util.max_value_of_dtype(integrated_encodings.dtype),
        )
        min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
        mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(
            text_mask, 1, keepdim=True)

        # Self-attentive pooling layer
        # Run through linear projection. Shape: (batch_size, sequence length, 1)
        # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
        self_attentive_logits = self._self_attentive_pooling_projection(
            integrated_encodings).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits, text_mask)
        self_attentive_pool = util.weighted_sum(integrated_encodings,
                                                self_weights)

        pooled_representations = torch.cat(
            [max_pool, min_pool, mean_pool, self_attentive_pool], 1)
        pooled_representations_dropped = self._integrator_dropout(
            pooled_representations)

        logits = self._output_layer(pooled_representations_dropped)
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        if label is not None:
            loss = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            output_dict["loss"] = loss

        return output_dict
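Note: the masked max/min/mean pooling above relies on the allennlp.nn.util helpers. As a rough, stand-alone sketch (boolean mask assumed, True = real token; this is not the library code itself), the same pooling can be written with plain PyTorch ops:

import torch

def masked_pools(encodings, mask):
    # encodings: (batch, seq_len, dim); mask: (batch, seq_len), boolean
    mask3 = mask.unsqueeze(-1)
    neg_fill = torch.finfo(encodings.dtype).min
    pos_fill = torch.finfo(encodings.dtype).max
    max_pool = encodings.masked_fill(~mask3, neg_fill).max(dim=1)[0]
    min_pool = encodings.masked_fill(~mask3, pos_fill).min(dim=1)[0]
    mean_pool = (encodings * mask3.float()).sum(dim=1) \
        / mask3.float().sum(dim=1).clamp(min=1.0)
    return max_pool, min_pool, mean_pool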
Esempio n. 22
0
    def forward(self, tags, history, next_sym, source_tokens, his_symptoms,
                target_tokens, **args):
        bs = len(tags)
        # self.flatten_parameters()
        # Get the embeddings of the history utterances
        embeddings = self._source_embedder(history)
        mask = get_text_field_mask(history, num_wrapping_dims=1)  # num_wrapping_dims adds an utterance dimension
        sz = list(embeddings.size())
        embeddings = embeddings.view(sz[0] * sz[1], sz[2], sz[3])
        mask = mask.view(sz[0] * sz[1], sz[2])

        # Get each utterance's hidden state: bs * sen_num * hidden
        utter_hidden = self._vecoder(embeddings, mask)
        utter_hidden = utter_hidden.view(sz[0], sz[1],
                                         -1)  # bs * sen_num * hidden

        dialog_mask = get_text_field_mask(history)
        dialog_hidden = self._sen_encoder(utter_hidden, dialog_mask)  # HRED-style hierarchical encoding
        # print("dialog_hidden: ",dialog_hidden.size())

        # Initialize every graph node
        symp_state = torch.zeros(
            bs, self.symp_size,
            self.outfeature).cuda()  # bs * symp_size * hidden
        symp_state += self.symp_state  # every node starts from the same embedding; is that a problem?

        # Add edges between sentences (what if CUDA is not used?)
        sym_mat = torch.zeros(bs, self.symp_size, self.symp_size)

        for i in range(bs):
            dic = {}
            for j in range(len(tags[i])):  # this could be changed: here edges are added to all related earlier sentences
                symp_state[i][self.topic + j] = utter_hidden[i][j]
                dic[j] = set(list(tags[i][j]))
                for k in range(j):
                    for aa in dic[j]:
                        if aa in dic[k] and aa != -1:
                            # sym_mat[i][self.topic+j][self.topic+k] += 1
                            sym_mat[i][self.topic + k][self.topic + j] += 1

        for i in range(bs):
            for j in range(len(tags[i])):
                symp_state[i][self.topic + j] = utter_hidden[i][j]
                last = min(j + self.last_sen, len(tags[i]))
                sym_mat[i][j][j:last] = torch.ones(1)  # connect to the next few sentences

        last_h = self.attn_one(symp_state, sym_mat)
        sym_mat = torch.zeros(bs, self.symp_size, self.symp_size)
        for i in range(bs):
            for j in range(len(tags[i])):
                for tt in tags[i][j]:
                    if tt != -1:
                        sym_mat[i][self.topic + j][tt] += 1
        #
        last_h = self.attn_two(last_h, sym_mat)
        #
        # Add edges between topics
        sym_mat = torch.zeros(bs, self.symp_size, self.symp_size)

        for symp_i in his_symptoms:
            for symp_j in his_symptoms:
                self.evovl_mat[symp_i][symp_j] = 1
        temp_mat = (torch.nn.functional.relu(self.symp_mat) +
                    self.evovl_mat).cpu()
        with open('visulize_graph.txt', 'a') as fout:
            fout.write('evovl_mat is: \n')
            for i in self.evovl_mat.detach().cpu().numpy():
                fout.write(str(i) + '\n')
            fout.write('temp_mat is: \n')
            for i in temp_mat.detach().cpu().numpy():
                fout.write(str(i) + '\n')
        # print('[info] temp_mat is:{}'.format(temp_mat))
        sym_mat[:, :self.topic, :self.topic] += temp_mat

        # sym_mat[:, :self.topic, :self.topic] += self.symp_mat
        # last_h = self.attn_two(symp_state, sym_mat)
        last_h = self.attn_three(last_h, sym_mat)

        topic_pre = torch.sum(self.predict_layer * last_h,
                              dim=-1) + self.predict_bias
        topic_probs = torch.sigmoid(topic_pre)
        topics_weight = torch.ones_like(topic_probs) + 5 * next_sym.float()
        topic_loss = torch.nn.functional.binary_cross_entropy(
            topic_probs, next_sym.float(), weight=topics_weight)

        ans = (topic_probs > 0.5).long()

        # his_symptoms bs * sym_size?
        # his_mask = torch.where(his_symptoms > 0, torch.full_like(his_symptoms, 0), torch.full_like(his_symptoms,1)).long()

        # Hide (mask out) the sentence nodes
        # his_mask
        his_sentence_mask = torch.zeros(bs, self.sen_num).long()
        total_mask = torch.cat(
            (torch.ones(bs, self.topic).long(), his_sentence_mask), -1)

        if self.training:
            aa = next_sym.long()
        else:
            aa = ans
        # total_mask = torch.ones(bs, self.symp_size).cuda()
        # total_mask = total_mask.long() & his_mask.long()
        topic_embedding = aa.float().matmul(self.symp_state)
        topic_hidden = last_h

        # Compute topic F1, accuracy and recall
        pre_total = torch.sum(ans).item()
        true_total = torch.sum(next_sym).item()
        pre_right = torch.sum((ans == next_sym).long() * next_sym).item()
        # print(pre_total,pre_right)
        self.topic_acc(pre_right, pre_total)
        self.topic_rec(pre_right, true_total)
        acc = self.topic_acc.get_metric(False)
        rec = self.topic_rec.get_metric(False)
        f1 = 0.
        if acc + rec > 0:
            f1 = acc * rec * 2 / (acc + rec)
        self.topic_f1(f1)

        # Encoding source_tokens
        # embedded_input = self._source_embedder(source_tokens)
        source_mask = util.get_text_field_mask(source_tokens)
        # encoder_outputs = self._encoder(embedded_input, source_mask)
        # final_encoder_output = util.get_final_encoder_states(encoder_outputs, source_mask, self._encoder.is_bidirectional())
        state = {
            "source_mask":
            source_mask,
            "encoder_outputs":
            topic_embedding,  # bs * seq_len * dim
            "decoder_hidden":
            torch.cat((topic_embedding, dialog_hidden),
                      -1),  # bs * dim, the HRED output
            # "decoder_hidden": torch.sum(last_h * total_mask.float(), 1),
            "decoder_context":
            topic_embedding.new_zeros(bs, self._decoder_output_dim),
            # "decoder_context": topic_embedding,
            "topic_embedding":
            topic_embedding
        }
        # state[''] = topic_embedding
        # Run the decoder once
        output_dict = self._forward_loop(state, topic_hidden,
                                         total_mask.cuda(), target_tokens)
        best_predictions = output_dict["predictions"]

        # output something
        references, hypothesis = [], []
        for i in range(bs):
            cut_hypo = best_predictions[i][:]
            if self._end_index in list(best_predictions[i]):
                cut_hypo = best_predictions[i][:list(best_predictions[i]).
                                               index(self._end_index)]
            hypothesis.append([
                self.vocab.get_token_from_index(idx.item()) for idx in cut_hypo
            ])

        flag = 1
        for i in range(bs):
            cut_ref = target_tokens['tokens'][i][1:]
            if self._end_index in list(target_tokens['tokens'][i]):
                cut_ref = target_tokens['tokens'][i][
                    1:list(target_tokens['tokens'][i]).index(self._end_index)]
            references.append([
                self.vocab.get_token_from_index(idx.item()) for idx in cut_ref
            ])
            if random.random() <= 0.001 and not self.training and flag == 1:
                flag = 0
                for jj in range(i):
                    print('___hypo___', ''.join(hypothesis[jj]), end=' ## ')
                    print(''.join(references[jj]))
                    print("")

        self.bleu_aver(references, hypothesis)
        self.bleu1(references, hypothesis)
        self.bleu2(references, hypothesis)
        self.bleu4(references, hypothesis)
        self.kd_metric(references, hypothesis)
        self.dink1(hypothesis)
        self.dink2(hypothesis)
        if self.training:
            output_dict['loss'] = output_dict['loss'] + 8 * topic_loss
        else:
            output_dict['loss'] = topic_loss
        return output_dict
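Note: the topic loss above is a binary cross-entropy in which positive labels are up-weighted (weight 1 + 5 = 6 versus 1 for negatives). A minimal stand-alone illustration with made-up numbers:

import torch
import torch.nn.functional as F

topic_probs = torch.tensor([0.9, 0.2, 0.7])      # predicted probabilities
next_sym = torch.tensor([1.0, 0.0, 1.0])         # gold multi-hot topic labels
topics_weight = torch.ones_like(topic_probs) + 5 * next_sym   # positives weighted 6x
topic_loss = F.binary_cross_entropy(topic_probs, next_sym, weight=topics_weight)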
Esempio n. 23
0
    def forward(
        self,  # type: ignore
        source: TextFieldTensors,
        **target_tokens: Dict[str, TextFieldTensors],
    ) -> Dict[str, torch.Tensor]:
        """
        Decoder logic for producing the target sequences.

        # Parameters

        source : `TextFieldTensors`
            The output of `TextField.as_array()` applied on the source
            `TextField`. This will be passed through a `TextFieldEmbedder`
            and then through an encoder.
        target_tokens : `Dict[str, TextFieldTensors]`:
            Dictionary from name to output of `Textfield.as_array()` applied
            on target `TextField`. We assume that the target tokens are also
            represented as a `TextField`.
        """
        # (batch_size, input_sequence_length, embedding_dim)
        embedded_input = self._embedding_dropout(self._source_embedder(source))
        source_mask = get_text_field_mask(source)
        # (batch_size, encoder_output_dim)
        final_encoder_output = self._encoder(embedded_input, source_mask)
        output_dict = {}

        # Perform greedy search so we can get the loss.
        if target_tokens:
            if target_tokens.keys() != self._states.keys():
                target_only = target_tokens.keys() - self._states.keys()
                states_only = self._states.keys() - target_tokens.keys()
                raise Exception(
                    "Mismatch between target_tokens and self._states. Keys in "
                    +
                    f"targets only: {target_only} Keys in states only: {states_only}"
                )
            total_loss = 0
            for name, state in self._states.items():
                loss = self.greedy_search(
                    final_encoder_output=final_encoder_output,
                    target_tokens=target_tokens[name],
                    target_embedder=state.embedder,
                    decoder_cell=state.decoder_cell,
                    output_projection_layer=state.output_projection_layer,
                )
                total_loss += loss
                output_dict[f"{name}_loss"] = loss

            # Use mean loss (instead of the sum of the losses) to be comparable to the paper.
            output_dict["loss"] = total_loss / len(self._states)

        # Perform beam search to obtain the predictions.
        if not self.training:
            batch_size = final_encoder_output.size()[0]
            for name, state in self._states.items():
                start_predictions = final_encoder_output.new_full(
                    (batch_size, ),
                    fill_value=self._start_index,
                    dtype=torch.long)
                start_state = {"decoder_hidden": final_encoder_output}

                # (batch_size, 10, num_decoding_steps)
                all_top_k_predictions, log_probabilities = self._beam_search.search(
                    start_predictions, start_state, state.take_step)

                if target_tokens:
                    self._update_recall(all_top_k_predictions,
                                        target_tokens[name], state.recall)
                output_dict[
                    f"{name}_top_k_predictions"] = all_top_k_predictions
                output_dict[
                    f"{name}_top_k_log_probabilities"] = log_probabilities

        return output_dict
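Note: greedy_search above is the model's own method and is not shown here. As a very rough sketch of the teacher-forced step it implies (hypothetical shapes, a GRUCell standing in for state.decoder_cell, and a Linear layer for state.output_projection_layer):

import torch
import torch.nn.functional as F

batch_size, target_len, emb_dim, hidden_dim, vocab_size = 2, 5, 8, 16, 100
target_embeddings = torch.randn(batch_size, target_len, emb_dim)   # embedded gold targets
decoder_cell = torch.nn.GRUCell(emb_dim, hidden_dim)
output_projection = torch.nn.Linear(hidden_dim, vocab_size)
decoder_hidden = torch.randn(batch_size, hidden_dim)               # final encoder output
targets = torch.randint(0, vocab_size, (batch_size, target_len))

loss = torch.tensor(0.0)
for t in range(target_len - 1):
    decoder_hidden = decoder_cell(target_embeddings[:, t, :], decoder_hidden)
    logits = output_projection(decoder_hidden)
    loss = loss + F.cross_entropy(logits, targets[:, t + 1])       # predict the next token
loss = loss / (target_len - 1)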
Esempio n. 24
0
    def forward(
        self,  # type: ignore
        question_field: Dict[str, torch.LongTensor],
        visual_feat: torch.Tensor,
        pos: torch.Tensor,
        image_id: List[str],
        gold_question_attentions: torch.Tensor = None,
        identifier: List[str] = None,
        logical_form: List[str] = None,
        actions: List[List[ProductionRule]] = None,
        target_action_sequence: torch.LongTensor = None,
        gold_object_choices: torch.Tensor = None,
        denotation: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        batch_size, obj_num, feat_size = visual_feat.size()
        assert obj_num == 36 and feat_size == 2048
        text_masks = util.get_text_field_mask(question_field)
        (l_orig, v_orig, text, vis_only), x_orig = self._encoder(
            question_field[self._tokens_namespace], text_masks, visual_feat,
            pos)

        text_masks = text_masks.float()
        # NOTE: Taking the lxmert output before cross modality layer (which is the same for both images)
        # Can also try concatenating (dim=-1) the two encodings
        encoded_sentence = text

        initial_state = self._get_initial_state(encoded_sentence, text_masks,
                                                actions)
        initial_state.debug_info = [[] for _ in range(batch_size)]

        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        losses = []
        if (self.training or self._use_gold_program_for_eval
            ) and target_action_sequence is not None:
            # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we
            # unsqueeze it for the MML trainer.
            search = ConstrainedBeamSearch(
                beam_size=None,
                allowed_sequences=target_action_sequence.unsqueeze(1),
                allowed_sequence_mask=target_mask.unsqueeze(1),
            )
            final_states = search.search(initial_state,
                                         self._transition_function)
            if self._training_batches_so_far < self._num_parse_only_batches:
                for batch_index in range(batch_size):
                    if not final_states[batch_index]:
                        logger.error(
                            f"No pogram found for batch index {batch_index}")
                        continue
                    losses.append(-final_states[batch_index][0].score[0])
        else:
            final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False,
            )

        action_mapping = {}
        for action_index, action in enumerate(actions[0]):
            action_mapping[action_index] = action[0]

        outputs: Dict[str, Any] = {"action_mapping": action_mapping}
        outputs["best_action_sequence"] = []
        outputs["debug_info"] = []

        if self._nmn_settings["mask_non_attention"]:
            zero_one_mult = torch.zeros_like(gold_question_attentions)
            zero_one_mult.copy_(gold_question_attentions)
            zero_one_mult[:, :, 0] = 1.0
            # sep_indices = text_masks.argmax(1).long()
            sep_indices = (
                (text_masks.long() *
                 (1 + torch.arange(text_masks.shape[1]).unsqueeze(0).repeat(
                     batch_size, 1).to(text_masks.device))).argmax(1).long())
            sep_indices = (sep_indices.unsqueeze(1).repeat(
                1, gold_question_attentions.shape[2]).unsqueeze(1).repeat(
                    1, gold_question_attentions.shape[1], 1))
            indices_dim2 = (torch.arange(
                gold_question_attentions.shape[2]).unsqueeze(0).repeat(
                    gold_question_attentions.shape[0],
                    gold_question_attentions.shape[1],
                    1,
                ).to(sep_indices.device).long())
            zero_one_mult = torch.where(
                sep_indices == indices_dim2,
                torch.ones_like(zero_one_mult),
                zero_one_mult,
            ).float()
            reshaped_questions = (
                question_field[self._tokens_namespace].unsqueeze(1).repeat(
                    1, gold_question_attentions.shape[1],
                    1).view(-1, gold_question_attentions.shape[-1]))
            reshaped_visual_feat = (visual_feat.unsqueeze(1).repeat(
                1, gold_question_attentions.shape[1], 1,
                1).view(-1, obj_num, visual_feat.shape[-1]))
            reshaped_pos = (pos.unsqueeze(1).repeat(
                1, gold_question_attentions.shape[1], 1,
                1).view(-1, obj_num, pos.shape[-1]))
            zero_one_mult = zero_one_mult.view(
                -1, gold_question_attentions.shape[-1])
            q_att_filter = zero_one_mult.sum(1) > 2
            (l_relevant, v_relevant, _, _), x_relevant = self._encoder(
                reshaped_questions[q_att_filter, :],
                zero_one_mult[q_att_filter, :],
                reshaped_visual_feat[q_att_filter, :, :],
                reshaped_pos[q_att_filter, :, :],
            )
            l = [{} for _ in range(batch_size)]
            v = [{} for _ in range(batch_size)]
            x = [{} for _ in range(batch_size)]
            count = 0
            batch_index = -1
            for i in range(zero_one_mult.shape[0]):
                module_num = i % target_action_sequence.shape[1]
                if module_num == 0:
                    batch_index += 1
                if q_att_filter[i].item():
                    l[batch_index][module_num] = l_relevant[count]
                    v[batch_index][module_num] = v_relevant[count]
                    x[batch_index][module_num] = x_relevant[count]
                    count += 1
        else:
            l = l_orig
            v = v_orig
            x = x_orig

        for batch_index in range(batch_size):
            if (self.training and self._training_batches_so_far <
                    self._num_parse_only_batches):
                continue
            if not final_states[batch_index]:
                logger.error(f"No pogram found for batch index {batch_index}")
                outputs["best_action_sequence"].append([])
                outputs["debug_info"].append([])
                continue
            world = VisualReasoningGqaLanguage(
                l[batch_index],
                v[batch_index],
                x[batch_index],
                # initial_state.rnn_state[batch_index].encoder_outputs[batch_index],
                self._language_parameters,
                pos[batch_index],
                nmn_settings=self._nmn_settings,
            )

            denotation_log_prob_list = []
            # TODO(mattg): maybe we want to limit the number of states we evaluate (programs we
            # execute) at test time, just for efficiency.
            for state_index, state in enumerate(final_states[batch_index]):
                action_indices = state.action_history[0]
                action_strings = [
                    action_mapping[action_index]
                    for action_index in action_indices
                ]
                # Shape: (num_denotations,)
                assert len(action_strings) == len(state.debug_info[0])
                # Plug in gold question attentions
                for i in range(len(state.debug_info[0])):
                    if gold_question_attentions[batch_index, i, :].sum() > 0:
                        state.debug_info[0][i]["question_attention"] = (
                            gold_question_attentions[batch_index,
                                                     i, :].float() /
                            gold_question_attentions[batch_index, i, :].sum())
                    elif self._nmn_settings["mask_non_attention"] and (
                            action_strings[i][-4:] == "find"
                            or action_strings[i][-6:] == "filter"
                            or action_strings[i][-13:] == "with_relation"):
                        state.debug_info[0][i]["question_attention"] = (
                            torch.ones_like(
                                gold_question_attentions[batch_index,
                                                         i, :]).float() /
                            gold_question_attentions[batch_index,
                                                     i, :].numel())
                        l[batch_index][i] = l_orig[batch_index]
                        v[batch_index][i] = v_orig[batch_index]
                        x[batch_index][i] = x_orig[batch_index]
                        world = VisualReasoningGqaLanguage(
                            l[batch_index],
                            v[batch_index],
                            x[batch_index],
                            # initial_state.rnn_state[batch_index].encoder_outputs[batch_index],
                            self._language_parameters,
                            pos[batch_index],
                            nmn_settings=self._nmn_settings,
                        )
                # print(action_strings)
                state_denotation_log_probs = world.execute_action_sequence(
                    action_strings, state.debug_info[0])
                # prob2 = world.execute_action_sequence(action_strings, state.debug_info[0])

                # P(denotation | parse) * P(parse | question)
                denotation_log_prob_list.append(state_denotation_log_probs)

                if not self._use_gold_program_for_eval:
                    denotation_log_prob_list[-1] += state.score[0]
                if state_index == 0:
                    outputs["best_action_sequence"].append(action_strings)
                    outputs["debug_info"].append(state.debug_info[0])
                    if target_action_sequence is not None:
                        targets = target_action_sequence[batch_index].data
                        program_correct = self._action_history_match(
                            action_indices, targets)
                        self._program_accuracy(program_correct)

            # P(denotation | parse) * P(parse | question) for the all programs on the beam.
            # Shape: (beam_size, num_denotations)
            denotation_log_probs = torch.stack(denotation_log_prob_list)
            # \Sum_parse P(denotation | parse) * P(parse | question) = P(denotation | question)
            # Shape: (num_denotations,)
            marginalized_denotation_log_probs = util.logsumexp(
                denotation_log_probs, dim=0)
            if denotation is not None:
                loss = (self.loss(
                    state_denotation_log_probs.unsqueeze(0),
                    denotation[batch_index].unsqueeze(0).float(),
                ).view(1) * self._denotation_loss_multiplier)
                losses.append(loss)
                self._denotation_accuracy(
                    torch.tensor([
                        1 - state_denotation_log_probs,
                        state_denotation_log_probs
                    ]).to(denotation.device),
                    denotation[batch_index],
                )
                if gold_object_choices is not None:
                    gold_objects = gold_object_choices[batch_index, :, :]
                    predicted_objects = torch.zeros_like(gold_objects)
                    for index in world.object_scores:
                        predicted_objects[
                            index, :] = world.object_scores[index]
                    obj_exists = gold_objects.max(1)[0] > 0
                    # Only look at modules where at least one of the proposals has the object of interest
                    predicted_objects = predicted_objects[obj_exists, :]
                    gold_objects = gold_objects[obj_exists, :]
                    gold_objects = gold_objects.view(-1)
                    predicted_objects = predicted_objects.view(-1)
                    if gold_objects.numel() > 0:
                        loss += self._obj_loss_multiplier * self.loss(
                            predicted_objects, (gold_objects.float() + 1) / 2)
                        self._proposal_accuracy(
                            torch.cat(
                                (
                                    1.0 - predicted_objects.view(-1, 1),
                                    predicted_objects.view(-1, 1),
                                ),
                                dim=-1,
                            ),
                            (gold_objects + 1) / 2,
                        )
        if losses:
            outputs["loss"] = torch.stack(losses).mean()
        if self.training:
            self._training_batches_so_far += 1
        return outputs
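Note: the marginalisation above (util.logsumexp over the beam) computes log P(denotation | question) = log of the sum over parses of P(denotation | parse) * P(parse | question). A tiny numeric illustration with made-up probabilities:

import torch

parse_log_probs = torch.log(torch.tensor([[0.9], [0.1]]))        # (beam_size, 1)
denotation_given_parse = torch.log(torch.tensor([[0.7, 0.3],
                                                 [0.4, 0.6]]))    # (beam_size, num_denotations)
joint = denotation_given_parse + parse_log_probs                  # log of the product
marginal = torch.logsumexp(joint, dim=0)                          # log P(denotation | question)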
Esempio n. 25
0
    def forward(
        self,
        query: Dict[str, torch.LongTensor],
        docs: Dict[str, torch.LongTensor],
        dataset: List[str] = [],
        labels: Optional[Dict[str, torch.LongTensor]] = None,
        scores: Optional[Dict[str, torch.Tensor]] = None,
        relevant_ignored: Optional[torch.Tensor] = None,
        irrelevant_ignored: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        # label masks
        ls_mask = get_text_field_mask(docs)
        # (batch_size, num_docs, doc_length)
        ds_mask = get_text_field_mask(docs, num_wrapping_dims=1)
        # (batch_size, num_docs, doc_length, embedding_dim)
        ds_embedded = self.embedder(docs)
        # (batch_size, num_docs, doc_length, transform_dim)
        batch_size, num_docs, doc_length, embedding_dim = ds_embedded.shape
        # (batch_size * num_docs, doc_length, transform_dim)
        ds_embedded = ds_embedded.view(batch_size * num_docs, doc_length,
                                       embedding_dim)
        ds_mask = ds_mask.view(batch_size * num_docs, doc_length)

        if self.idf_embedder is not None:
            ds_idfs = self.idf_embedder(docs)
            ds_idfs = ds_idfs.view(batch_size * num_docs, doc_length,
                                   1).repeat(1, 1, embedding_dim)
            ds_embedded = ds_embedded * ds_idfs

        # (batch_size, query_length)
        qs_mask = get_text_field_mask(query)
        _, query_length = qs_mask.shape
        qs_mask = qs_mask.unsqueeze(1).repeat(1, num_docs,
                                              1).view(batch_size * num_docs,
                                                      -1)
        # (batch_size, query_length, embedding_dim)
        qs_embedded = self.embedder(query)
        # (batch_size, num_docs, query_length, embedding_dim)
        qs_embedded = qs_embedded.unsqueeze(1).repeat(1, num_docs, 1, 1)
        # (batch_size * num_docs, query_length, embedding_dim)
        qs_embedded = qs_embedded.view(batch_size * num_docs, query_length,
                                       embedding_dim)

        if self.idf_embedder is not None:
            qs_idfs = self.idf_embedder(query).unsqueeze(1).repeat(
                1, num_docs, 1, 1)
            qs_idfs = qs_idfs.view(batch_size * num_docs, query_length,
                                   1).repeat(1, 1, embedding_dim)
            qs_embedded = qs_embedded * qs_idfs

        logits = self.scorer(qs_embedded, ds_embedded, qs_mask, ds_mask)
        #logits = F.log_softmax(logits,dim=1)

        scores = torch.exp(scores / self.temperature).view(
            batch_size * num_docs, -1)
        logits = torch.cat([logits, scores], dim=1)

        logits = self.final_scorer(logits).view(batch_size, num_docs)

        output_dict = {'logits': logits}

        if labels is not None:
            # filter out to only the metrics we care about
            # if self.training:
            #     if self.ranking_loss:
            #         loss = self.loss(logits[:, 0], logits[:, 1], labels.float()*-2.+1.)
            #     else:
            #         loss = self.loss(logits, labels.squeeze(-1).long())

            #self.metrics['accuracy'](logits, labels.squeeze(-1))
            # else:
            #    # at validation time, we can't compute a proper loss
            #    loss = torch.Tensor([0.])
            #    for metric in self.training_metrics[False]:
            #        self.metrics[metric](logits, labels.squeeze(-1).long(), ls_mask, relevant_ignored, irrelevant_ignored)
            #output_dict['loss'] = self.loss(logits[:, 0], logits[:, 1], labels.float()*2.+1.)
            output_dict['loss'] = self.loss(logits, labels)

        if labels is not None:
            sfl = F.log_softmax(logits, dim=1)
            #print(sfl)
            output_dict['accuracy'] = self.metrics['accuracy'](sfl, labels)

        return output_dict
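Note: the idf_embedder above produces one scalar per token, which is repeated across the embedding dimension and used to rescale the token embeddings. A minimal sketch with made-up shapes (the explicit repeat can also be left to broadcasting):

import torch

embedded = torch.randn(4, 10, 16)   # (batch * num_docs, doc_length, embedding_dim)
idf = torch.rand(4, 10, 1)          # one IDF weight per token
weighted = embedded * idf.repeat(1, 1, 16)   # same result as embedded * idf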
Esempio n. 26
0
    def forward(
            self,  # type: ignore
            user_utterance: Dict[str, torch.LongTensor],
            prev_user_utterance: Dict[str, torch.LongTensor],
            prev_sys_utterance: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        user_utterance : Dict[str, Variable], required
            The output of ``TextField.as_array()``.
        prev_user_utterance : Dict[str, Variable], required
            The output of ``TextField.as_array()``.
        prev_sys_utterance : Dict[str, Variable], required
            The output of ``TextField.as_array()``.
        label : Variable, optional (default = None)
            A variable representing the intent label for each instance in the batch.
        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_classes)`` representing a distribution over the
            label classes for each instance.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_user_utterance = self.text_field_embedder(user_utterance)
        user_utterance_mask = util.get_text_field_mask(user_utterance)
        encoded_user_utterance = self.user_utterance_encoder(
            embedded_user_utterance, user_utterance_mask)

        embedded_prev_user_utterance = self.text_field_embedder(
            prev_user_utterance)
        prev_user_utterance_mask = util.get_text_field_mask(
            prev_user_utterance)
        encoded_prev_user_utterance = self.prev_user_utterance_encoder(
            embedded_prev_user_utterance, prev_user_utterance_mask)

        embedded_prev_sys_utterance = self.text_field_embedder(
            prev_sys_utterance)
        prev_sys_utterance_mask = util.get_text_field_mask(prev_sys_utterance)
        encoded_prev_sys_utterance = self.prev_sys_utterance_encoder(
            embedded_prev_sys_utterance, prev_sys_utterance_mask)

        logits = self.classifier_feedforward(
            torch.cat([
                encoded_user_utterance, encoded_prev_user_utterance,
                encoded_prev_sys_utterance
            ],
                      dim=-1))
        class_probs = F.softmax(logits, dim=1)
        output_dict = {'logits': logits}
        if label is not None:
            loss = self.loss(logits, label)
            output_dict["loss"] = loss

            # compute F1 per label
            for i in range(self.num_classes):
                metric = self.label_f1_metrics[self.vocab.get_token_from_index(
                    index=i, namespace="labels")]
                metric(class_probs, label)
            self.label_accuracy(logits, label)

        return output_dict
Esempio n. 27
0
    def forward(self, text, spans, ner_labels, coref_labels, relation_labels,
                trigger_labels, argument_labels, metadata):
        """
        TODO(dwadden) change this.
        """
        # For co-training on Ontonotes, need to change the loss weights depending on the data coming
        # in. This is a hack but it will do for now.
        if self._co_train:
            if self.training:
                dataset = [entry["dataset"] for entry in metadata]
                assert len(set(dataset)) == 1
                dataset = dataset[0]
                assert dataset in ["ace", "ontonotes"]
                if dataset == "ontonotes":
                    self._loss_weights = dict(coref=1,
                                              ner=0,
                                              relation=0,
                                              events=0)
                else:
                    self._loss_weights = self._permanent_loss_weights
            # This assumes that there won't be any co-training data in the dev and test sets, and that
            # coref propagation will still happen even when the coref weight is set to 0.
            else:
                self._loss_weights = self._permanent_loss_weights

        # In AllenNLP, AdjacencyFields are passed in as floats. This fixes it.
        relation_labels = relation_labels.long()
        argument_labels = argument_labels.long()

        # If we're doing Bert, get the sentence class token as part of the text embedding. This will
        # break if we use Bert together with other embeddings, but that won't happen much.
        if "bert-offsets" in text:
            # NOTE(dwadden) This operation mutates the text. We clone it so that the input isn't
            # mutated; otherwise, successive `forward` calls on the same data variable would give
            # different results because the data got mutated silently.
            new_text = {}
            for k, v in text.items():
                new_text[k] = v.clone()
            text = new_text
            offsets = text["bert-offsets"]
            sent_ix = torch.zeros(offsets.size(0),
                                  device=offsets.device,
                                  dtype=torch.long).unsqueeze(1)
            padded_offsets = torch.cat([sent_ix, offsets], dim=1)
            text["bert-offsets"] = padded_offsets
            padded_embeddings = self._text_field_embedder(text)
            cls_embeddings = padded_embeddings[:, 0, :]
            text_embeddings = padded_embeddings[:, 1:, :]
        else:
            text_embeddings = self._text_field_embedder(text)
            cls_embeddings = torch.zeros(
                [text_embeddings.size(0),
                 text_embeddings.size(2)],
                device=text_embeddings.device)

        text_embeddings = self._lexical_dropout(text_embeddings)

        # Shape: (batch_size, max_sentence_length)
        text_mask = util.get_text_field_mask(text).float()
        sentence_group_lengths = text_mask.sum(dim=1).long()

        sentence_lengths = 0 * text_mask.sum(dim=1).long()
        for i in range(len(metadata)):
            sentence_lengths[
                i] = metadata[i]["end_ix"] - metadata[i]["start_ix"]
            for k in range(sentence_lengths[i], sentence_group_lengths[i]):
                text_mask[i][k] = 0

        max_sentence_length = sentence_lengths.max().item()

        # TODO(Ulme) Speed this up by tensorizing
        new_text_embeddings = torch.zeros([
            text_embeddings.shape[0], max_sentence_length,
            text_embeddings.shape[2]
        ],
                                          device=text_embeddings.device)
        for i in range(len(new_text_embeddings)):
            new_text_embeddings[
                i][0:metadata[i]["end_ix"] -
                   metadata[i]["start_ix"]] = text_embeddings[i][
                       metadata[i]["start_ix"]:metadata[i]["end_ix"]]

        #max_sent_len = max(sentence_lengths)
        #the_list = [list(k+metadata[i]["start_ix"] if k < max_sent_len else 0 for k in range(text_embeddings.shape[1])) for i in range(len(metadata))]
        #import ipdb; ipdb.set_trace()
        #text_embeddings = torch.gather(text_embeddings, 1, torch.tensor(the_list, device=text_embeddings.device).unsqueeze(2).repeat(1, 1, text_embeddings.shape[2]))
        text_embeddings = new_text_embeddings

        # Only keep the text embeddings that correspond to actual tokens.
        # text_embeddings = text_embeddings[:, :max_sentence_length, :].contiguous()
        text_mask = text_mask[:, :max_sentence_length].contiguous()

        # Shape: (batch_size, max_sentence_length, encoding_dim)
        contextualized_embeddings = self._lstm_dropout(
            self._context_layer(text_embeddings, text_mask))
        assert spans.max() < contextualized_embeddings.shape[1]

        if self._attentive_span_extractor is not None:
            # Shape: (batch_size, num_spans, embedding_size)
            attended_span_embeddings = self._attentive_span_extractor(
                text_embeddings, spans)

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)

        if self._attentive_span_extractor is not None:
            # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
            span_embeddings = torch.cat(
                [endpoint_span_embeddings, attended_span_embeddings], -1)
        else:
            span_embeddings = endpoint_span_embeddings

        # TODO(Ulme) try normalizing span embeddings
        #span_embeddings = span_embeddings.abs().sum(dim=-1).unsqueeze(-1)

        # Make calls out to the modules to get results.
        output_coref = {'loss': 0}
        output_ner = {'loss': 0}
        output_relation = {'loss': 0}
        output_events = {'loss': 0}

        # Prune and compute span representations for coreference module
        if self._loss_weights["coref"] > 0 or self._coref.coref_prop > 0:
            output_coref, coref_indices = self._coref.compute_representations(
                spans, span_mask, span_embeddings, sentence_lengths,
                coref_labels, metadata)

        # Prune and compute span representations for relation module
        if self._loss_weights["relation"] > 0 or self._relation.rel_prop > 0:
            output_relation = self._relation.compute_representations(
                spans, span_mask, span_embeddings, sentence_lengths,
                relation_labels, metadata)

        # Propagation of global information to enhance the span embeddings
        if self._coref.coref_prop > 0:
            # TODO(Ulme) Implement Coref Propagation
            output_coref = self._coref.coref_propagation(output_coref)
            span_embeddings = self._coref.update_spans(output_coref,
                                                       span_embeddings,
                                                       coref_indices)

        if self._relation.rel_prop > 0:
            output_relation = self._relation.relation_propagation(
                output_relation)
            span_embeddings = self.update_span_embeddings(
                span_embeddings, span_mask,
                output_relation["top_span_embeddings"],
                output_relation["top_span_mask"],
                output_relation["top_span_indices"])

        # Make predictions and compute losses for each module
        if self._loss_weights['ner'] > 0:
            output_ner = self._ner(spans, span_mask, span_embeddings,
                                   sentence_lengths, ner_labels, metadata)

        if self._loss_weights['coref'] > 0:
            output_coref = self._coref.predict_labels(output_coref, metadata)

        if self._loss_weights['relation'] > 0:
            output_relation = self._relation.predict_labels(
                relation_labels, output_relation, metadata)

        if self._loss_weights['events'] > 0:
            # Make the trigger embeddings the same size as the argument embeddings to make
            # propagation easier.
            if self._events._span_prop._n_span_prop > 0:
                trigger_embeddings = contextualized_embeddings.repeat(1, 1, 2)
                trigger_widths = torch.zeros(
                    [trigger_embeddings.size(0),
                     trigger_embeddings.size(1)],
                    device=trigger_embeddings.device,
                    dtype=torch.long)
                trigger_width_embs = self._endpoint_span_extractor._span_width_embedding(
                    trigger_widths)
                trigger_width_embs = trigger_width_embs.detach()
                trigger_embeddings = torch.cat(
                    [trigger_embeddings, trigger_width_embs], dim=-1)
            else:
                trigger_embeddings = contextualized_embeddings

            output_events = self._events(text_mask, trigger_embeddings, spans,
                                         span_mask, span_embeddings,
                                         cls_embeddings, sentence_lengths,
                                         output_ner, trigger_labels,
                                         argument_labels, ner_labels, metadata)

        if "loss" not in output_coref:
            output_coref["loss"] = 0
        if "loss" not in output_relation:
            output_relation["loss"] = 0

        # TODO(dwadden) just did this part.
        loss = (self._loss_weights['coref'] * output_coref['loss'] +
                self._loss_weights['ner'] * output_ner['loss'] +
                self._loss_weights['relation'] * output_relation['loss'] +
                self._loss_weights['events'] * output_events['loss'])

        output_dict = dict(coref=output_coref,
                           relation=output_relation,
                           ner=output_ner,
                           events=output_events)
        output_dict['loss'] = loss

        # Check to see if event predictions are globally compatible (argument labels are compatible
        # with NER tags and trigger tags).
        # if self._loss_weights["ner"] > 0 and self._loss_weights["events"] > 0:
        #     decoded_ner = self._ner.decode(output_dict["ner"])
        #     decoded_events = self._events.decode(output_dict["events"])
        #     self._joint_metrics(decoded_ner, decoded_events)

        return output_dict
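Note: the example above marks padded spans with (spans[:, :, 0] >= 0) and then clamps the -1 padding via F.relu so the indices stay non-negative for the span extractors. A tiny stand-alone illustration:

import torch
import torch.nn.functional as F

spans = torch.tensor([[[0, 2], [3, 5], [-1, -1]]])   # last span is SpanField padding
span_mask = (spans[:, :, 0] >= 0).float()            # -> tensor([[1., 1., 0.]])
spans = F.relu(spans.float()).long()                  # (-1, -1) padding becomes (0, 0)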
Esempio n. 28
0
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``.
        label : torch.LongTensor, optional (default = None)
            A variable representing the label for each instance in the batch.
        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_classes)`` representing a
            distribution over the label classes for each instance.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        text_mask = util.get_text_field_mask(tokens).float()
        # Pop elmo tokens, since elmo embedder should not be present.
        elmo_tokens = tokens.pop("elmo", None)
        embedded_text = self._text_field_embedder(tokens)

        # Add the "elmo" key back to "tokens" if not None, since the tests and the
        # subsequent training epochs rely on "tokens" not being modified during forward()
        if elmo_tokens is not None:
            tokens["elmo"] = elmo_tokens

        # Create ELMo embeddings if applicable
        if self._elmo:
            if elmo_tokens is not None:
                elmo_representations = self._elmo(elmo_tokens)["elmo_representations"]
                # Popping from the end of a list is more performant
                if self._use_integrator_output_elmo:
                    integrator_output_elmo = elmo_representations.pop()
                if self._use_input_elmo:
                    input_elmo = elmo_representations.pop()
                assert not elmo_representations
            else:
                raise ConfigurationError(
                        "Model was built to use Elmo, but input text is not tokenized for Elmo.")

        if self._use_input_elmo:
            embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)

        dropped_embedded_text = self._embedding_dropout(embedded_text)
        pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
        encoded_tokens = self._encoder(pre_encoded_text, text_mask)

        # Compute biattention. This is a special case since the inputs are the same.
        attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
        attention_weights = util.last_dim_softmax(attention_logits, text_mask)
        encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

        # Build the input to the integrator
        integrator_input = torch.cat([encoded_tokens,
                                      encoded_tokens - encoded_text,
                                      encoded_tokens * encoded_text], 2)
        integrated_encodings = self._integrator(integrator_input, text_mask)

        # Concatenate ELMo representations to integrated_encodings if specified
        if self._use_integrator_output_elmo:
            integrated_encodings = torch.cat([integrated_encodings,
                                              integrator_output_elmo], dim=-1)

        # Simple Pooling layers
        max_masked_integrated_encodings = util.replace_masked_values(
                integrated_encodings, text_mask.unsqueeze(2), -1e7)
        max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
        min_masked_integrated_encodings = util.replace_masked_values(
                integrated_encodings, text_mask.unsqueeze(2), +1e7)
        min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
        mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(text_mask, 1, keepdim=True)

        # Self-attentive pooling layer
        # Run through linear projection. Shape: (batch_size, sequence length, 1)
        # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
        self_attentive_logits = self._self_attentive_pooling_projection(
                integrated_encodings).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits, text_mask)
        self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)

        pooled_representations = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
        pooled_representations_dropped = self._integrator_dropout(pooled_representations)

        logits = self._output_layer(pooled_representations_dropped)
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {'logits': logits, 'class_probabilities': class_probabilities}
        if label is not None:
            loss = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            output_dict["loss"] = loss

        return output_dict
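Note: the "biattention" step above is a self-attention over the same encoded sequence: a batched matmul gives token-token similarities, a masked softmax normalises over the key positions, and a weighted sum produces the attended text. A stand-alone sketch with random tensors (plain softmax plus masking instead of the allennlp helpers):

import torch
import torch.nn.functional as F

encoded = torch.randn(2, 5, 8)                                    # (batch, seq_len, dim)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.bool)
logits = encoded.bmm(encoded.transpose(1, 2))                     # token-token similarities
logits = logits.masked_fill(~mask.unsqueeze(1), float("-inf"))    # mask padded key positions
weights = F.softmax(logits, dim=-1)                               # (batch, seq_len, seq_len)
attended = weights.bmm(encoded)                                   # weighted sum of the values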
Esempio n. 29
0
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                valid_actions: List[List[ProductionRule]],
                action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : torch.Tensor, optional (default=None)
            The gold action sequence, where each action is an index into the list of possible
            actions.  This tensor has shape ``(batch_size, sequence_length, 1)``; we remove the
            trailing dimension.
        """
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        batch_size = embedded_utterance.size(0)

        # (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
        initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            target_mask = action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # After unsqueezing, the action sequence has shape (batch_size, 1, target_sequence_length),
            # which is the form the MML trainer expects.
            loss_output = self._decoder_trainer.decode(initial_state,
                                                       self._transition_function,
                                                       (action_sequence.unsqueeze(1),
                                                        target_mask.unsqueeze(1)))
            outputs.update(loss_output)

        if not self.training:
            action_mapping = []
            for batch_actions in valid_actions:
                batch_action_mapping = {}
                for action_index, action in enumerate(batch_actions):
                    batch_action_mapping[action_index] = action[0]
                action_mapping.append(batch_action_mapping)

            outputs['action_mapping'] = action_mapping
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=True)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['predicted_sql_query'] = []
            outputs['sql_queries'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [action_mapping[i][action_index]
                                  for action_index in best_action_indices]

                predicted_sql_query = action_sequence_to_sql(action_strings)
                if action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = action_sequence[i].data
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                outputs['best_action_sequence'].append(action_strings)
                outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs
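
# --- Illustrative sketch (standalone, not part of the model above) -----------
# The metrics above compare predicted action indices against the gold ones;
# difflib.SequenceMatcher.ratio() returns a similarity in [0, 1].  A tiny
# self-contained illustration with made-up indices:
import difflib

predicted = [3, 7, 7, 2, 9]
gold = [3, 7, 2, 9]
print(predicted == gold)                                        # exact match: False
print(difflib.SequenceMatcher(None, predicted, gold).ratio())   # 2 * 4 / 9 ≈ 0.89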
    def forward(
            self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise,
                                                     premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(
                embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise,
                                                   projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(
            similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat(
            [embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat(
            [embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(
            hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(
            -1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis],
                                    dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {
            "label_logits": label_logits,
            "label_probs": label_probs
        }

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict
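
# --- Illustrative sketch (standalone, not part of the model above) -----------
# The core "attend" step above is a softmax over the similarity matrix, masked
# along the hypothesis dimension, followed by a weighted sum.  A minimal
# plain-PyTorch version with made-up shapes:
import torch

similarity = torch.randn(1, 3, 4)                   # (batch, premise_len, hypothesis_len)
hypothesis_mask = torch.tensor([[1., 1., 1., 0.]])  # last hypothesis token is padding
embedded_hypothesis = torch.randn(1, 4, 5)          # (batch, hypothesis_len, dim)

masked = similarity.masked_fill(hypothesis_mask[:, None, :] == 0, -1e9)
p2h_attention = torch.softmax(masked, dim=-1)                  # (1, 3, 4)
attended_hypothesis = p2h_attention.bmm(embedded_hypothesis)   # (1, 3, 5)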
    def forward(
            self,  # type: ignore
            sentence1: Dict[str, torch.LongTensor],
            span_field1: torch.LongTensor,
            span1_text: Dict[str, torch.LongTensor],
            sentence2: Dict[str, torch.LongTensor],
            span_field2: torch.LongTensor,
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        sentence1 : Dict[str, Variable], required
            The output of ``TextField.as_array()``.
        span_field1 : torch.LongTensor, required
            The span field for sentence 1
        span1_text : Dict[str, Variable], required
            The output of ``TextField.as_array()``.
        sentence2 : Dict[str, Variable], required
            The output of ``TextField.as_array()``.
        span_field2 : torch.LongTensor, required
            The span field for sentence 2
        label : Variable, optional (default = None)
            A variable representing the label for each instance in the batch.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : torch.FloatTensor
            A tensor of shape ``(batch_size, num_classes)`` representing a distribution over the
            label classes for each instance.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_sentence1 = self.text_field_embedder(sentence1)
        embedded_sentence2 = self.text_field_embedder(sentence2)

        # Encode the sequence
        if self.seq2seq_encoder:
            sentence1_mask = util.get_text_field_mask(sentence1)
            encoded_sentence1 = self.seq2seq_encoder(embedded_sentence1,
                                                     sentence1_mask).data
            sentence2_mask = util.get_text_field_mask(sentence2)
            encoded_sentence2 = self.seq2seq_encoder(embedded_sentence2,
                                                     sentence2_mask).data

        # Otherwise, use the embedder output directly (one vector per token)
        # and extract the span from it.
        else:
            encoded_sentence1 = embedded_sentence1
            encoded_sentence2 = embedded_sentence2

        # Extract the span
        span1 = self.span_extractor(encoded_sentence1, span_field1)
        span2 = self.span_extractor(encoded_sentence2, span_field2)

        # The endpoint extractor should return [2, batch_size, emb_dim].
        # For a single-word span, concatenate a zero vector to keep the lengths equal.
        if span2.size() != span1.size():
            span2 = torch.cat(
                [encoded_sentence2,
                 torch.zeros_like(encoded_sentence2)],
                dim=-1)

        classifier_input = torch.cat([span1, span2], dim=-1).squeeze(0)
        logits = self.classifier_feedforward(classifier_input)
        output_dict = {'logits': logits}
        if label is not None:
            loss = self.loss(logits, label)
            for metric in self.metrics.values():
                metric(logits, label)
            output_dict["loss"] = loss

        return output_dict
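
# --- Illustrative sketch (standalone, not part of the model above) -----------
# An "endpoint" span representation concatenates the encoder output at a span's
# start and end indices.  Stripped of the AllenNLP machinery, with made-up
# shapes and span indices:
import torch

encoded = torch.randn(1, 6, 4)                # (batch, seq_len, dim)
spans = torch.tensor([[[1, 3], [4, 4]]])      # (batch, num_spans, 2) inclusive indices

starts, ends = spans[..., 0], spans[..., 1]   # (batch, num_spans)
start_embs = torch.gather(encoded, 1, starts.unsqueeze(-1).expand(-1, -1, 4))
end_embs = torch.gather(encoded, 1, ends.unsqueeze(-1).expand(-1, -1, 4))
span_reps = torch.cat([start_embs, end_embs], dim=-1)   # (1, 2, 8)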
    def forward(self,  # type: ignore
                sentences: torch.LongTensor,
                labels: torch.IntTensor = None,
                confidences: torch.Tensor = None,
                additional_features: torch.Tensor = None,
                ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        TODO: add description

        Returns
        -------
        An output dictionary consisting of:
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # ===========================================================================================================
        # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
        # Input: sentences
        # Output: embedded_sentences

        # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
        embedded_sentences = self.text_field_embedder(sentences)
        mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
        batch_size, num_sentences, _, _ = embedded_sentences.size()

        if self.use_sep:
            # The following code collects the vectors of the [SEP] tokens from all the examples in
            # the batch and arranges them in one list. It does the same for the labels and confidences.
            # TODO: replace 103 with '[SEP]'
            sentences_mask = sentences['bert'] == 103  # mask for all the SEP tokens in the batch
            embedded_sentences = embedded_sentences[sentences_mask]  # given batch_size x num_sentences_per_example x sent_len x vector_len
                                                                        # returns num_sentences_per_batch x vector_len
            assert embedded_sentences.dim() == 2
            num_sentences = embedded_sentences.shape[0]
            # For the rest of the model, treat the data as a single example that contains all of
            # these sentences, i.e. a batch of size 1.
            batch_size = 1
            embedded_sentences = embedded_sentences.unsqueeze(dim=0)
            embedded_sentences = self.dropout(embedded_sentences)

            if labels is not None:
                if self.labels_are_scores:
                    labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
                else:
                    labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

                labels = labels[labels_mask]  # given batch_size x num_sentences_per_example return num_sentences_per_batch
                assert labels.dim() == 1
                if confidences is not None:
                    confidences = confidences[labels_mask]
                    assert confidences.dim() == 1
                if additional_features is not None:
                    additional_features = additional_features[labels_mask]
                    assert additional_features.dim() == 2

                num_labels = labels.shape[0]
                if num_labels != num_sentences:  # BERT truncates long sentences, so some of the SEP tokens might be gone
                    assert num_labels > num_sentences  # there can only be more labels than kept sentences
                    logger.warning(f'Found {num_labels} labels but {num_sentences} sentences')
                    # Ignore the extra labels. This is acceptable for training but wrong for testing;
                    # we are ignoring the problem for now. TODO: fix, at least for testing.
                    labels = labels[:num_sentences]

                # do the same for `confidences`
                if confidences is not None:
                    num_confidences = confidences.shape[0]
                    if num_confidences != num_sentences:
                        assert num_confidences > num_sentences
                        confidences = confidences[:num_sentences]

                # and for `additional_features`
                if additional_features is not None:
                    num_additional_features = additional_features.shape[0]
                    if num_additional_features != num_sentences:
                        assert num_additional_features > num_sentences
                        additional_features = additional_features[:num_sentences]

                # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
                labels = labels.unsqueeze(dim=0)
                if confidences is not None:
                    confidences = confidences.unsqueeze(dim=0)
                if additional_features is not None:
                    additional_features = additional_features.unsqueeze(dim=0)
        else:
            # Use the [CLS] token vector (first wordpiece) of each sentence.
            embedded_sentences = embedded_sentences[:, :, 0, :]
            embedded_sentences = self.dropout(embedded_sentences)
            batch_size, num_sentences, _ = embedded_sentences.size()
            sent_mask = (mask.sum(dim=2) != 0)
            embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

        if additional_features is not None:
            embedded_sentences = torch.cat((embedded_sentences, additional_features), dim=-1)

        label_logits = self.time_distributed_aggregate_feedforward(embedded_sentences)
        # label_logits: batch_size, num_sentences, num_labels

        if self.labels_are_scores:
            label_probs = label_logits
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {"action_probs": label_probs}

        # =====================================================================

        if self.with_crf:
            # Layer 4 = CRF layer across labels of sentences in an abstract
            mask_sentences = (labels != -1)
            best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
            # Just get the tags and ignore the score.
            predicted_labels = [x for x, y in best_paths]

            label_loss = 0.0
        if labels is not None:
            # Compute cross entropy loss
            flattened_logits = label_logits.view((batch_size * num_sentences), self.num_labels)
            flattened_gold = labels.contiguous().view(-1)

            if not self.with_crf:
                label_loss = self.loss(flattened_logits.squeeze(), flattened_gold)
                if confidences is not None:
                    label_loss = label_loss * confidences.type_as(label_loss).view(-1)
                label_loss = label_loss.mean()
                flattened_probs = torch.softmax(flattened_logits, dim=-1)
            else:
                clamped_labels = torch.clamp(labels, min=0)
                log_likelihood = self.crf(label_logits, clamped_labels, mask_sentences)
                label_loss = -log_likelihood
                # compute categorical accuracy
                crf_label_probs = label_logits * 0.
                for i, instance_labels in enumerate(predicted_labels):
                    for j, label_id in enumerate(instance_labels):
                        crf_label_probs[i, j, label_id] = 1
                flattened_probs = crf_label_probs.view((batch_size * num_sentences), self.num_labels)

            if not self.labels_are_scores:
                evaluation_mask = (flattened_gold != -1)
                self.label_accuracy(flattened_probs.float().contiguous(), flattened_gold.squeeze(-1), mask=evaluation_mask)

                # compute F1 per label
                for label_index in range(self.num_labels):
                    label_name = self.vocab.get_token_from_index(namespace='labels', index=label_index)
                    metric = self.label_f1_metrics[label_name]
                    metric(flattened_probs, flattened_gold, mask=evaluation_mask)
        
        if labels is not None:
            output_dict["loss"] = label_loss
        output_dict['action_logits'] = label_logits
        return output_dict
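
# --- Illustrative sketch (standalone, not part of the model above) -----------
# Selecting one vector per sentence with a boolean mask, as done for the [SEP]
# tokens above, collapses the batch and sentence dimensions into one.  Tiny
# example with made-up values (103 standing in for the separator id, as in the
# snippet above):
import torch

token_ids = torch.tensor([[5, 9, 103, 7, 103, 0]])   # (batch=1, seq_len=6)
embeddings = torch.randn(1, 6, 4)                    # (batch, seq_len, dim)

sep_mask = token_ids == 103                          # (1, 6) boolean
sep_vectors = embeddings[sep_mask]                   # (num_sep_tokens, dim)
print(sep_vectors.shape)                             # torch.Size([2, 4])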
Esempio n. 33
0
    def forward(
            self,  # type: ignore
            question,
            passage,
            span_start=None,
            span_end=None,
            metadata=None):
        ######## 1/2. Embedding Layer ########
        # 2. Embed the question and passage, then apply dropout to the embedded vectors.
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))

        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        ######## 3. Contextual Embedding Layer ########
        # Encode input vectors into new representation H and U by using a BiLSTM.
        # Shape: (batch_size, question_length, encoding_dim)
        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        # Get the per-token encoding dim from the last dimension.
        encoding_dim = encoded_question.size(-1)

        ######## 4. Attention Flow Layer ########
        # Calculate similarity matrix for attention layer.
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)

        # Context-to-query (C2Q) attention 'a': for each passage token, attend over the question tokens.
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Weighted sum of question vectors by the C2Q attentions: \hat{U}_{:t} = \sum_j a_{tj} U_{:j}
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Create weighted vectors attended by query-to-context attention.
        # Replace masked values so they do not affect the result.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        # Calculate Q2C(query to context) attention, 'b'
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0]
        # Pass to the softmax layer.
        # Shape: (batch_size, passage_length)
        question_passage_attention = \
            util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Merge attention vectors
        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        # Add purple "linear ReLU layer"
        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))
        # Bi-GRU in the paper
        residual_layer = self._dropout(
            self._residual_encoder(self._dropout(final_merged_passage),
                                   passage_mask))
        self_atten_matrix = self._self_atten(residual_layer, residual_layer)

        # Expand mask for self-attention
        mask = (passage_mask.resize(batch_size, passage_length, 1) *
                passage_mask.resize(batch_size, 1, passage_length))

        # Mask should have zeros on the diagonal.
        # torch.eye does not have a gpu implementation, so we are forced to use
        # the cpu one and .cuda(). Not sure if this matters for performance.
        eye = torch.eye(passage_length, passage_length)
        if mask.is_cuda:
            eye = eye.cuda()
        self_mask = Variable(eye).resize(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_atten_probs = util.last_dim_softmax(self_atten_matrix, mask)

        # Batch matrix multiplication:
        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_atten_vecs = torch.matmul(self_atten_probs, residual_layer)

        # (extended_batch_size, passage_length, embedding_dim * 3)
        concatenated = torch.cat([
            self_atten_vecs, residual_layer, residual_layer * self_atten_vecs
        ],
                                 dim=-1)
        # _merge_self_atten => (extended_batch_size, passage_length,
        # embedding_dim)
        residual_layer = F.relu(self._merge_self_atten(concatenated))

        # print("residual", residual_layer.size())

        final_merged_passage += residual_layer
        final_merged_passage = self._dropout(final_merged_passage)

        # Bi-GRU in paper
        start_rep = self._span_start_encoder(final_merged_passage,
                                             passage_lstm_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        end_rep = self._span_end_encoder(
            torch.cat([final_merged_passage, start_rep], dim=-1),
            passage_lstm_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if span_start is not None:
            # Calculate the loss.
            # The training loss is the sum of the negative log probabilities of
            # the true start and end indices under the predicted distributions,
            # averaged over all examples.
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            # Stack the gold start and end indices so they can be compared with `best_span`.
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
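
# --- Illustrative sketch (standalone, not part of the model above) -----------
# Turning a predicted (start, end) token span back into a string relies on the
# character offsets stored in the metadata.  A made-up example:
passage = "The cat sat on the mat"
token_offsets = [(0, 3), (4, 7), (8, 11), (12, 14), (15, 18), (19, 22)]
predicted_span = (4, 5)                       # inclusive token indices

start_char = token_offsets[predicted_span[0]][0]
end_char = token_offsets[predicted_span[1]][1]
print(passage[start_char:end_char])           # "the mat"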
Esempio n. 34
0
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            span_starts: torch.IntTensor,
            span_ends: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        span_starts : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 1), representing the start indices of
            candidate spans for mentions. Comes from a ``ListField[IndexField]`` of indices into
            the text of the document.
        span_ends : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 1), representing the end indices of
            candidate spans for mentions. Comes from a ``ListField[IndexField]`` of indices into
            the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = span_starts.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans, 1)
        span_mask = (span_starts >= 0).float()
        # IndexFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        span_starts = F.relu(span_starts.float()).long()
        span_ends = F.relu(span_ends.float()).long()

        # Shape: (batch_size, num_spans, embedding_size)
        span_embeddings = self._compute_span_representations(
            text_embeddings, text_mask, span_starts, span_ends)
        # Compute a score for whether each span is a mention,
        # making sure that masked spans have very low scores.
        # Shape: (batch_size, num_spans, 1)
        mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings))
        mention_scores += span_mask.log()

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))

        # Shape: (batch_size, num_spans_to_keep)
        # These are indices (with values between 0 and num_spans) into
        # the span_embeddings tensor.
        top_span_indices = self._prune_and_sort_spans(mention_scores,
                                                      num_spans_to_keep)

        # Now that we've decided which spans are actually mentions the next
        # few steps are reformatting all of our variables to be in terms of
        # num_spans_to_keep instead of num_spans, so we don't waste computation
        # on spans that we've already discarded.

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Select the span embeddings corresponding to the
        # top spans based on the mention scorer.
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = util.batched_index_select(
            span_embeddings, top_span_indices, flat_top_span_indices)
        # Shape: (batch_size, num_spans_to_keep, 1)
        # TODO(Mark): If we parameterised the mention scorer to score things in (0, inf)
        # I think we could get rid of the need for this mask entirely.
        top_span_mask = util.batched_index_select(span_mask, top_span_indices,
                                                  flat_top_span_indices)
        top_span_mention_scores = util.batched_index_select(
            mention_scores, top_span_indices, flat_top_span_indices)
        top_span_starts = util.batched_index_select(span_starts,
                                                    top_span_indices,
                                                    flat_top_span_indices)
        top_span_ends = util.batched_index_select(span_ends, top_span_indices,
                                                  flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, text_mask.is_cuda)
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)

        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)
        # Compute final predictions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = torch.cat([top_span_starts, top_span_ends], -1)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": predicted_antecedents
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.last_dim_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood
        return output_dict
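
# --- Illustrative sketch (standalone, not part of the model above) -----------
# Each span may only take *earlier* spans as antecedents, up to max_antecedents
# of them.  A rough re-implementation (not the library code) of the index /
# offset / log-mask triple used above:
import torch

num_spans_to_keep, max_antecedents = 4, 2
target_indices = torch.arange(num_spans_to_keep).unsqueeze(1)   # (4, 1)
offsets = torch.arange(1, max_antecedents + 1).unsqueeze(0)     # (1, 2)
raw_indices = target_indices - offsets                          # (4, 2)
log_mask = (raw_indices >= 0).float().log()                     # 0 where valid, -inf where not
antecedent_indices = raw_indices.clamp(min=0)
# antecedent_indices: [[0, 0], [0, 0], [1, 0], [2, 1]]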
Esempio n. 35
0
    def _forward_loop(
        self,
        state: Dict[str, torch.Tensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # With probability `_scheduled_sampling_ratio` during training, feed back the
                # model's own previous predictions (scheduled sampling) instead of gold tokens.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(
                input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict
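
# --- Illustrative sketch (standalone, not part of the model above) -----------
# Scheduled sampling, as used in the loop above: at each step during training,
# with probability `scheduled_sampling_ratio` the model is fed its own previous
# prediction rather than the gold token.  A stripped-down version of that choice
# (names made up):
import torch

def choose_input(gold_token, last_prediction, ratio, training, have_targets):
    # Use the model's own prediction with probability `ratio` during training;
    # without gold targets we must use the prediction; otherwise teacher-force.
    if training and torch.rand(1).item() < ratio:
        return last_prediction
    if not have_targets:
        return last_prediction
    return gold_token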
Esempio n. 36
0
    def forward(self,
                text,
                spans,
                metadata,
                ner_labels=None,
                coref_labels=None,
                relation_labels=None,
                trigger_labels=None,
                argument_labels=None):
        """
        TODO(dwadden) change this.
        """
        # In AllenNLP, AdjacencyFields are passed in as floats. This fixes it.
        if relation_labels is not None:
            relation_labels = relation_labels.long()
        if argument_labels is not None:
            argument_labels = argument_labels.long()

        # TODO(dwadden) Multi-document minibatching isn't supported yet. For now, get rid of the
        # extra dimension in the input tensors. Will return to this once the model runs.
        if len(metadata) > 1:
            raise NotImplementedError("Multi-document minibatching not yet supported.")

        metadata = metadata[0]
        spans = self._debatch(spans)  # (n_sents, max_n_spans, 2)
        ner_labels = self._debatch(ner_labels)  # (n_sents, max_n_spans)
        coref_labels = self._debatch(coref_labels)  #  (n_sents, max_n_spans)
        relation_labels = self._debatch(relation_labels)  # (n_sents, max_n_spans, max_n_spans)
        trigger_labels = self._debatch(trigger_labels)  # TODO(dwadden)
        argument_labels = self._debatch(argument_labels)  # TODO(dwadden)

        # Encode using BERT, then debatch.
        # Since the data are batched, we use `num_wrapping_dims=1` to unwrap the document dimension.
        # (1, n_sents, max_sentence_length, embedding_dim)

        # TODO(dwadden) Deal with the case where the input is longer than 512.
        text_embeddings = self._embedder(text, num_wrapping_dims=1)
        # (n_sents, max_n_wordpieces, embedding_dim)
        text_embeddings = self._debatch(text_embeddings)

        # (n_sents, max_sentence_length)
        text_mask = self._debatch(util.get_text_field_mask(text, num_wrapping_dims=1).float())
        sentence_lengths = text_mask.sum(dim=1).long()  # (n_sents)

        span_mask = (spans[:, :, 0] >= 0).float()  # (n_sents, max_n_spans)
        # SpanFields return -1 when they are used as padding. As we do some comparisons based on
        # span widths when we attend over the span representations that we generate from these
        # indices, we need them to be >= 0. This is only relevant in edge cases where the number of
        # spans we consider after the pruning stage is >= the total number of spans, because in this
        # case, it is possible we might consider a masked span.
        spans = F.relu(spans.float()).long()  # (n_sents, max_n_spans, 2)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        span_embeddings = self._endpoint_span_extractor(text_embeddings, spans)

        # Make calls out to the modules to get results.
        output_coref = {'loss': 0}
        output_ner = {'loss': 0}
        output_relation = {'loss': 0}
        output_events = {'loss': 0}

        # Prune and compute span representations for coreference module
        if self._loss_weights["coref"] > 0 or self._coref.coref_prop > 0:
            output_coref, coref_indices = self._coref.compute_representations(
                spans, span_mask, span_embeddings, sentence_lengths, coref_labels, metadata)

        # Propagation of global information to enhance the span embeddings
        if self._coref.coref_prop > 0:
            output_coref = self._coref.coref_propagation(output_coref)
            span_embeddings = self._coref.update_spans(
                output_coref, span_embeddings, coref_indices)

        # Make predictions and compute losses for each module
        if self._loss_weights['ner'] > 0:
            output_ner = self._ner(
                spans, span_mask, span_embeddings, sentence_lengths, ner_labels, metadata)

        if self._loss_weights['coref'] > 0:
            output_coref = self._coref.predict_labels(output_coref, metadata)

        if self._loss_weights['relation'] > 0:
            output_relation = self._relation(
                spans, span_mask, span_embeddings, sentence_lengths, relation_labels, metadata)

        if self._loss_weights['events'] > 0:
            # The `text_embeddings` serve as representations for event triggers.
            output_events = self._events(
                text_mask, text_embeddings, spans, span_mask, span_embeddings,
                sentence_lengths, trigger_labels, argument_labels,
                ner_labels, metadata)

        # Use `get` since there are some cases where the output dict won't have a loss - for
        # instance, when doing prediction.
        loss = (self._loss_weights['coref'] * output_coref.get("loss", 0) +
                self._loss_weights['ner'] * output_ner.get("loss", 0) +
                self._loss_weights['relation'] * output_relation.get("loss", 0) +
                self._loss_weights['events'] * output_events.get("loss", 0))

        # Multiply the loss by the weight multiplier for this document.
        weight = metadata.weight if metadata.weight is not None else 1.0
        loss *= torch.tensor(weight)

        output_dict = dict(coref=output_coref,
                           relation=output_relation,
                           ner=output_ner,
                           events=output_events)
        output_dict['loss'] = loss

        output_dict["metadata"] = metadata

        return output_dict
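
# --- Illustrative sketch (standalone, not part of the model above) -----------
# The total loss above is a weighted sum of the per-task losses, so tasks with a
# zero weight are effectively switched off.  Made-up numbers:
loss_weights = {"coref": 1.0, "ner": 0.5, "relation": 1.0, "events": 0.0}
task_losses = {"coref": 2.3, "ner": 0.7, "relation": 1.1, "events": 4.0}

total_loss = sum(loss_weights[task] * task_losses.get(task, 0)
                 for task in loss_weights)
print(total_loss)   # 2.3 + 0.35 + 1.1 + 0.0 = 3.75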
Esempio n. 37
0
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the source ``TextField``. This will be
           passed through a ``TextFieldEmbedder`` and then through an encoder.
        target_tokens : Dict[str, torch.LongTensor], optional (default = None)
           Output of ``Textfield.as_array()`` applied on target ``TextField``. We assume that the
           target tokens are also represented as a ``TextField``.
        """
        # (batch_size, input_sequence_length, encoder_output_dim)
        embedded_input = self._source_embedder(source_tokens)
        batch_size, _, _ = embedded_input.size()
        source_mask = get_text_field_mask(source_tokens)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        final_encoder_output = encoder_outputs[:, -1]  # (batch_size, encoder_output_dim)
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end symbol. Either way, we
            # don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        decoder_context = encoder_outputs.new_zeros(batch_size, self._decoder_output_dim)
        last_predictions = None
        step_logits = []
        step_probabilities = []
        step_predictions = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() >= self._scheduled_sampling_ratio:
                input_choices = targets[:, timestep]
            else:
                if timestep == 0:
                    # For the first timestep, when we do not have targets, we input start symbols.
                    # (batch_size,)
                    input_choices = source_mask.new_full((batch_size,), fill_value=self._start_index)
                else:
                    input_choices = last_predictions
            decoder_input = self._prepare_decode_step_input(input_choices, decoder_hidden,
                                                            encoder_outputs, source_mask)
            decoder_hidden, decoder_context = self._decoder_cell(decoder_input,
                                                                 (decoder_hidden, decoder_context))
            # (batch_size, num_classes)
            output_projections = self._output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            step_probabilities.append(class_probabilities.unsqueeze(1))
            last_predictions = predicted_classes
            # (batch_size, 1)
            step_predictions.append(last_predictions.unsqueeze(1))
        # step_logits is a list containing tensors of shape (batch_size, 1, num_classes)
        # This is (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        class_probabilities = torch.cat(step_probabilities, 1)
        all_predictions = torch.cat(step_predictions, 1)
        output_dict = {"logits": logits,
                       "class_probabilities": class_probabilities,
                       "predictions": all_predictions}
        if target_tokens:
            target_mask = get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss
            # TODO: Define metrics
        return output_dict
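
# --- Illustrative sketch (standalone, not part of the model above) -----------
# The per-step logits collected above feed a masked sequence cross-entropy; a
# minimal version of that loss with made-up shapes (details such as shifting the
# targets relative to the logits are omitted here):
import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, 5)                    # (batch, num_steps, num_classes)
targets = torch.tensor([[1, 4, 0], [2, 0, 0]])   # 0 used as padding here
target_mask = (targets != 0).float()

per_token = F.cross_entropy(logits.reshape(-1, 5), targets.reshape(-1),
                            reduction="none").reshape(2, 3)
loss = (per_token * target_mask).sum() / target_mask.sum()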
Esempio n. 38
0
    def forward(self,
                question,
                passage,
                span_start=None,
                span_end=None,
                metadata=None):
        ######## 1/2. Embedding Layer ########
        # 2. After embedding, pass the embedded vectors through a Highway network.
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))

        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        ######## 3. Contextual Embedding Layer ########
        # Encode input vectors into new representation H and U by using a BiLSTM.
        # Shape: (batch_size, question_length, encoding_dim)
        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        # Get the per-token encoding dim from the last dimension.
        encoding_dim = encoded_question.size(-1)

        ######## 4. Attention Flow Layer ########
        # Calculate similarity matrix for attention layer.
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)

        # Context-to-query (C2Q) attention 'a': for each passage token, attend over the question tokens.
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Weighted sum of question vectors by the C2Q attentions: \hat{U}_{:t} = \sum_j a_{tj} U_{:j}
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Create weighted vectors attended by query-to-context attention.
        # Replace masked values so they do not affect the result.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)

        # Calculate Q2C(query to context) attention, 'b'
        # Shape: (batch_size, passage_length)
        question_passage_similarity = \
            masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Pass to the softmax layer.
        # Shape: (batch_size, passage_length)
        question_passage_attention = \
            util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Merge attention vectors
        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        ######## 5. Modeling Layer ########
        # Model query-aware context vector by BiLSTM.
        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        ######## 6. Output Layer ########
        # Obtain the probability distribution of the start index.
        # Concat G(from attention flow layer) and M(from modeling layer)
        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Calculate the logits.
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Calculate Softmax (eq. 3)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Obtain the probability distribution of the end index.
        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)
        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Obtain a new modeling representation conditioned on the start probabilities.
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Concat attention flow output G and new modeling representation M2
        # Shape: (batch_size, passage_length, encoding_dim * 4 +
        # span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        # Shape: (batch_size, passage_length)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

        # Replace the masked values with a very negative number so they do not
        # influence the final results.
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)

        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if span_start is not None:
            # Calculate the loss.
            # The training loss is the sum of the negative log probabilities of
            # the true start and end indices under the predicted distributions,
            # averaged over all examples.
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            # `torch.stack` pairs the gold start and end indices so they line up
            # with the (start, end) pairs in `best_span`.
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].data.cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
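`self.get_best_span` is referenced above but not defined in this snippet. As an illustration only (not necessarily the author's implementation), a constrained search that maximises start_logit + end_logit subject to end >= start could look like this:

import torch

def get_best_span(span_start_logits: torch.Tensor,
                  span_end_logits: torch.Tensor) -> torch.Tensor:
    # Both inputs have shape (batch_size, passage_length); returns (batch_size, 2).
    batch_size, passage_length = span_start_logits.size()
    best_spans = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long)
    for b in range(batch_size):
        best_score = float('-inf')
        best_start = 0
        for end in range(passage_length):
            # Track the best start position seen so far (i.e. any index <= end).
            if span_start_logits[b, end] > span_start_logits[b, best_start]:
                best_start = end
            score = span_start_logits[b, best_start] + span_end_logits[b, end]
            if score > best_score:
                best_score = score
                best_spans[b, 0] = best_start
                best_spans[b, 1] = end
    return best_spans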
Esempio n. 39
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(self._text_field_embedder(question))
        embedded_passage = self._highway_layer(self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                    passage_length,
                                                                                    encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([encoded_passage,
                                          passage_question_vectors,
                                          encoded_passage * passage_question_vectors,
                                          encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                   passage_length,
                                                                                   modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([final_merged_passage,
                                             modeled_passage,
                                             tiled_start_representation,
                                             modeled_passage * tiled_start_representation],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                                passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
        best_span = self.get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
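The attention-flow step in both BiDAF snippets reduces to a masked softmax over question positions followed by an attention-weighted sum. A toy plain-PyTorch sketch of what `util.masked_softmax` and `util.weighted_sum` compute here (illustrative, not the library source):

import torch

# similarity: (batch, passage_length, question_length); question_mask: (batch, question_length)
similarity = torch.randn(2, 5, 4)
question_mask = torch.tensor([[1., 1., 1., 0.],
                              [1., 1., 0., 0.]])
encoded_question = torch.randn(2, 4, 8)

# Masked softmax over the question dimension: padded question positions get (near-)zero weight.
masked_logits = similarity.masked_fill(question_mask.unsqueeze(1) == 0, -1e7)
attention = torch.softmax(masked_logits, dim=-1)        # (batch, passage_length, question_length)

# Attention-weighted sum of question encodings for every passage position.
passage_question_vectors = torch.bmm(attention, encoded_question)  # (batch, passage_length, 8)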
Esempio n. 40
0
    def forward(self, tokens: Dict[str, torch.Tensor],
                id: Any, answerers: Any, date: Any, accept_usr: Any) -> Dict[str, torch.Tensor]:

        # n -- batch number
        # m -- author number
        # d -- hidden dimension
        # k -- skill number
        # l -- text length
        # p -- pos/neg author number in one batch
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)  # (n, l, d)
        token_hidden = self.encoder(embeddings, mask).transpose(-1, -2)  # (n, l, d)

        token_embed = torch.mean(token_hidden, 1).squeeze(1)  # (n, d)
        # token_embed = token_hidden[:, :, -1]
        author_ctx_embed = self.ctx_attention(token_embed, self.author_embeddings, self.author_embeddings)  # (n, m, d)

        # add layer norm for author context embedding
        author_ctx_embed = self.ctx_layer_norm(author_ctx_embed)  # (n, m, d)

        # transfer the date into time embedding
        # TODO: use answer date for time embedding
        time_embed = gen_time_encoding(self.time_encoder, answerers, date, embeddings.size(2), self.num_authors, train_mode=self.training)
        # time_embed = [self.time_encoder.get_time_encoding(i) for i in date]
        # time_embed = torch.stack(time_embed, dim=0)  # (n, d)
        # time_embed = time_embed.unsqueeze(1).expand(-1, self.num_authors, -1)  # (n, m, d)

        author_ctx_embed_te = author_ctx_embed + time_embed
        author_tctx_embed = self.temp_ctx_attention(time_embed, author_ctx_embed, author_ctx_embed)  # (n, m, d)
        # author_tctx_embed = self.temp_ctx_attention(author_ctx_embed_te, author_ctx_embed_te, author_ctx_embed_te)  # (n, m, d)

        # get horizontal temporal time embeddings
        # htemp_embeds = []
        # truth = [[j[0] for j in i] for i in answerers]
        # for i, d in enumerate(date):
        #     pos_labels = br_utils.to_cuda(torch.tensor(truth[i]))
        #     post_time_embeds = self.time_encoder.get_post_encodings(d)  # (t, d)
        #     post_time_embeds = post_time_embeds.unsqueeze(1).expand(-1, pos_labels.size(0), -1)  # (t, pos, d)
        #
        #     pos_embed = author_ctx_embed[i, pos_labels, :]  # (pos, d)
        #     pos_embed = pos_embed.unsqueeze(0).expand(post_time_embeds.size(0), -1, -1)  # (t, pos, d)
        #     author_post_ctx_embed_te = pos_embed + post_time_embeds
        #     # author_post_ctx_embed = self.temp_ctx_attention(author_post_ctx_embed_te, author_post_ctx_embed_te, author_post_ctx_embed_te)  # (t, pos, d)
        #     author_post_ctx_embed = self.temp_ctx_attention(post_time_embeds, pos_embed, pos_embed)  # (t, pos, d)
        #     htemp_embeds.append(author_post_ctx_embed)
        # htemp_loss = self.htemp_loss(token_embed, htemp_embeds)

        # generate loss
        # loss, coherence = self.rank_loss(token_embed, author_tctx_embed, answerers)
        loss, coherence = self.rank_loss(token_embed, author_tctx_embed, answerers, accept_usr)
        # loss += 0.5 * htemp_loss

        # coherence = self.coherence_func(token_embed, None, author_tctx_embed)
        output = {"loss": loss, "coherence": coherence}

        predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
        truth = [[j[0] for j in i] for i in answerers]

        # self.rank_recall(predict, truth)
        # self.mrr(predict, truth)
        self.mrr(predict, accept_usr)


        return output
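`self.mrr` above is a custom metric. A hypothetical mean-reciprocal-rank computation over the ranked author predictions and the accepted users might look like the following sketch (function and argument names are assumptions, not the snippet's actual metric):

import numpy as np

def mean_reciprocal_rank(predict: np.ndarray, truth) -> float:
    """predict: (n, m) author indices sorted best-first; truth: per-example gold author id(s)."""
    reciprocal_ranks = []
    for ranked, gold in zip(predict, truth):
        gold = set(gold) if hasattr(gold, '__iter__') else {gold}
        rank = next((r + 1 for r, author in enumerate(ranked) if author in gold), None)
        reciprocal_ranks.append(1.0 / rank if rank is not None else 0.0)
    return float(np.mean(reciprocal_ranks))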
Esempio n. 41
0
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
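`self._generate_valid_antecedents` is not included in this snippet. A sketch of the idea, assuming the shapes given in the comments above (each top span may only attend to up to `max_antecedents` spans that precede it in document order):

import torch

def generate_valid_antecedents(num_spans_to_keep: int, max_antecedents: int):
    # (num_spans_to_keep, 1): index of each kept span, in document order.
    target_indices = torch.arange(num_spans_to_keep).unsqueeze(1)
    # (1, max_antecedents): how far back each candidate antecedent sits.
    valid_antecedent_offsets = torch.arange(1, max_antecedents + 1).unsqueeze(0)
    # (num_spans_to_keep, max_antecedents): raw indices; negative means "before the first span".
    raw_antecedent_indices = target_indices - valid_antecedent_offsets
    # (1, num_spans_to_keep, max_antecedents): log-space mask removing invalid antecedents.
    valid_antecedent_log_mask = (raw_antecedent_indices >= 0).float().unsqueeze(0).log()
    # Clamp negatives to 0 so they remain usable indices; the log mask zeroes out their scores.
    valid_antecedent_indices = torch.clamp(raw_antecedent_indices, min=0)
    return valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask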
Esempio n. 42
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                p1_answer_marker: torch.IntTensor = None,
                p2_answer_marker: torch.IntTensor = None,
                p3_answer_marker: torch.IntTensor = None,
                yesno_list: torch.IntTensor = None,
                followup_list: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        p1_answer_marker : ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 0.
            This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
            Most passage tokens will be assigned 'O', except the passage tokens belonging to the previous answer
            in the dialog, which will be assigned labels such as <1_start>, <1_in>, <1_end>.
            For more details, look into dataset_readers/util/make_reading_comprehension_instance_quac
        p2_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 1.
            It is similar to p1_answer_marker, but marks the answer from two dialog turns back in the passage.
        p3_answer_marker :  ``torch.IntTensor``, optional
            This is one of the inputs, but only when num_context_answers > 2.
            It is similar to p1_answer_marker, but marks the answer from three dialog turns back in the passage.
        yesno_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (the yes/no/not a yes no question).
        followup_list :  ``torch.IntTensor``, optional
            This is one of the outputs that we are trying to predict.
            Three way classification (followup / maybe followup / don't followup).
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question ID, original passage text, and token
            offsets into the passage for each instance in the batch.  We use this for computing
            official metrics using the official SQuAD evaluation script.  The length of this list
            should be the batch size, and each dictionary should have the keys ``id``,
            ``original_passage``, and ``token_offsets``.  If you only want the best span string and
            don't care about official metrics, you can omit the ``id`` key.

        Returns
        -------
        An output dictionary consisting of the following entries.
        Each entry is a nested list: the outer list iterates over dialogs, the inner over the questions in each dialog.

        qid : List[List[str]]
            A list of lists containing the question ids.
        followup : List[List[int]]
            A list of lists containing the continuation marker prediction indices
            (y: yes, m: maybe follow up, n: don't follow up).
        yesno : List[List[int]]
            A list of lists containing the affirmation marker prediction indices
            (y: yes, x: not a yes/no question, n: no).
        best_span_str : List[List[str]]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
        total_qa_count = batch_size * max_qa_count
        qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
        embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
        embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                      self._text_field_embedder.get_output_dim())
        embedded_question = self._variational_dropout(embedded_question)
        embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
        passage_length = embedded_passage.size(1)

        question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
        question_mask = question_mask.reshape(total_qa_count, max_q_len)
        passage_mask = util.get_text_field_mask(passage).float()

        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
        repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

        if self._num_context_answers > 0:
            # Encode question turn number inside the dialog into question embedding.
            question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
            question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
            question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
            question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
            question_num_marker_emb = self._question_num_marker(question_num_ind)
            embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

            # Encode the previous answers in passage embedding.
            repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
                view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
            # batch_size * max_qa_count, passage_length, word_embed_dim
            p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
            p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
            if self._num_context_answers > 1:
                p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
                p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb], dim=-1)
                if self._num_context_answers > 2:
                    p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                    p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                    repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                          dim=-1)

            repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                    repeated_passage_mask))
        else:
            encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
            repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
            repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                     passage_length,
                                                                     self._encoding_dim)

        encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
        # Shape: (batch_size * max_qa_count, passage_length, question_length)
        passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                       question_mask.unsqueeze(1),
                                                       -1e7)

        question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
        question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
        # Shape: (batch_size * max_qa_count, encoding_dim)
        question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                    passage_length,
                                                                                    self._encoding_dim)

        # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([repeated_encoded_passage,
                                          passage_question_vectors,
                                          repeated_encoded_passage * passage_question_vectors,
                                          repeated_encoded_passage * tiled_question_passage_vector],
                                         dim=-1)

        final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

        residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                          repeated_passage_mask))
        self_attention_matrix = self._self_attention(residual_layer, residual_layer)

        mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
               * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
        self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
        self_mask = self_mask.reshape(1, passage_length, passage_length)
        mask = mask * (1 - self_mask)

        self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

        # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
        self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
        self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                         residual_layer * self_attention_vecs],
                                        dim=-1)
        residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

        final_merged_passage = final_merged_passage + residual_layer
        # batch_size * maxqa_pair_len * max_passage_len * 200
        final_merged_passage = self._variational_dropout(final_merged_passage)
        start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
        span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

        end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                         repeated_passage_mask)
        span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

        span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
        span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        # batch_size * maxqa_len_pair, max_document_len
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

        best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                       span_yesno_logits, span_followup_logits,
                                                       self._max_span_length)

        output_dict: Dict[str, Any] = {}

        # Compute the loss.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask), span_start.view(-1),
                            ignore_index=-1)
            self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
            loss += nll_loss(util.masked_log_softmax(span_end_logits,
                                                     repeated_passage_mask), span_end.view(-1), ignore_index=-1)
            self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
            self._span_accuracy(best_span[:, 0:2],
                                torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                                mask=qa_mask.unsqueeze(1).expand(-1, 2).long())
            # Select the yes/no and follow-up logits at the gold span-end locations to compute their losses.
            gold_span_end_loc = []
            span_end = span_end.view(total_qa_count).squeeze().data.cpu().numpy()
            for i in range(0, total_qa_count):
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 1, 0))
                gold_span_end_loc.append(max(span_end[i] * 3 + i * passage_length * 3 + 2, 0))
            gold_span_end_loc = span_start.new(gold_span_end_loc)

            pred_span_end_loc = []
            for i in range(0, total_qa_count):
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 1, 0))
                pred_span_end_loc.append(max(best_span[i][1] * 3 + i * passage_length * 3 + 2, 0))
            predicted_end = span_start.new(pred_span_end_loc)

            _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
            loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
            loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

            _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
            self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
            self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
            output_dict["loss"] = loss

        # Compute F1 and prepare the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []
        output_dict['followup'] = []
        output_dict['yesno'] = []
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            f1_score = 0.0
            per_dialog_best_span_list = []
            per_dialog_yesno_list = []
            per_dialog_followup_list = []
            per_dialog_query_id_list = []
            for per_dialog_query_index, (iid, answer_texts) in enumerate(
                    zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
                predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]

                yesno_pred = predicted_span[2]
                followup_pred = predicted_span[3]
                per_dialog_yesno_list.append(yesno_pred)
                per_dialog_followup_list.append(followup_pred)
                per_dialog_query_id_list.append(iid)

                best_span_string = passage_str[start_offset:end_offset]
                per_dialog_best_span_list.append(best_span_string)
                if answer_texts:
                    if len(answer_texts) > 1:
                        t_f1 = []
                        # Compute F1 over N-1 human references and average the scores.
                        for answer_index in range(len(answer_texts)):
                            idxes = list(range(len(answer_texts)))
                            idxes.pop(answer_index)
                            refs = [answer_texts[z] for z in idxes]
                            t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                                 best_span_string,
                                                                                 refs))
                        f1_score = 1.0 * sum(t_f1) / len(t_f1)
                    else:
                        f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                            best_span_string,
                                                                            answer_texts)
                self._official_f1(100 * f1_score)
            output_dict['qid'].append(per_dialog_query_id_list)
            output_dict['best_span_str'].append(per_dialog_best_span_list)
            output_dict['yesno'].append(per_dialog_yesno_list)
            output_dict['followup'].append(per_dialog_followup_list)
        return output_dict
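Several of the snippets above call `util.replace_masked_values(tensor, mask, -1e7)` before a max or a log-softmax. In plain PyTorch the behaviour amounts to the following sketch (not the library source):

import torch

def replace_masked_values(tensor: torch.Tensor, mask: torch.Tensor, replace_with: float) -> torch.Tensor:
    """Keep values where mask == 1, set everything else to `replace_with` (mask is broadcast)."""
    return tensor.masked_fill(mask == 0, replace_with)

logits = torch.tensor([[2.0, 0.5, -1.0, 3.0]])
mask = torch.tensor([[1, 1, 1, 0]])
print(replace_masked_values(logits, mask, -1e7))  # the padded position can no longer win a max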
Esempio n. 43
0
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.LongTensor,
        tags: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : TextFieldTensors, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. This output is a dictionary mapping keys to `TokenIndexer`
            tensors.  At its most basic, using a `SingleIdTokenIndexer` this is : `{"tokens":
            Tensor(batch_size, num_tokens)}`. This dictionary will have the same keys as were used
            for the `TokenIndexers` when you created the `TextField` representing your
            sequence.  The dictionary is designed to be passed directly to a `TextFieldEmbedder`,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_indicator: torch.LongTensor, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = None)
            metadata containing the original words in the sentence and the verb to compute the
            frame for, under 'words' and 'verb' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.embedding_dropout(
            self.text_field_embedder(tokens))
        mask = get_text_field_mask(tokens)
        embedded_verb_indicator = self.binary_feature_embedding(
            verb_indicator.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat(
            [embedded_text_input, embedded_verb_indicator], -1)
        batch_size, sequence_length, _ = embedded_text_with_verb_indicator.size(
        )

        encoded_text = self.encoder(embedded_text_with_verb_indicator, mask)

        logits = self.tag_projection_layer(encoded_text)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.decode.
        output_dict["mask"] = mask

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from decode()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of decode() to a separate function.
                batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss

        if metadata is not None:
            words, verbs = zip(*[(x["words"], x["verb"]) for x in metadata])
            output_dict["words"] = list(words)
            output_dict["verb"] = list(verbs)
        return output_dict
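`self.decode(output_dict)` is used above to recover BIO tag strings but is not shown here. Ignoring the constrained Viterbi search the real model may apply, a minimal greedy per-token decode over `class_probabilities` (an illustrative assumption; `index_to_label` would come from the model's vocabulary) would be:

import torch

def greedy_decode(class_probabilities: torch.Tensor, mask: torch.Tensor, index_to_label) -> list:
    """class_probabilities: (batch, seq_len, num_tags); mask: (batch, seq_len); returns tag strings."""
    predictions = class_probabilities.argmax(dim=-1)       # (batch, seq_len)
    lengths = mask.sum(dim=-1).long().tolist()             # true (unpadded) length of each sequence
    return [[index_to_label[index.item()] for index in row[:length]]
            for row, length in zip(predictions, lengths)]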
Esempio n. 44
0
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                valid_actions: List[List[ProductionRule]],
                action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        We set up the initial state for the decoder, and pass that state off to either a DecoderTrainer,
        if we're training, or a BeamSearch for inference, if we're not.

        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
            be passed through a ``TextFieldEmbedder`` and then through an encoder.
        valid_actions : ``List[List[ProductionRule]]``
            A list of all possible actions for each ``World`` in the batch, indexed into a
            ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
            and use the embeddings to determine which action to take at each timestep in the
            decoder.
        action_sequence : ``torch.LongTensor``, optional (default=None)
            The gold action sequence, where each action is an index into the list
            of possible actions.  This tensor has shape ``(batch_size, sequence_length, 1)``; we remove
            the trailing dimension.
        """
        embedded_utterance = self._utterance_embedder(tokens)
        mask = util.get_text_field_mask(tokens).float()
        batch_size = embedded_utterance.size(0)

        # (batch_size, num_tokens, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
        initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

        if action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            action_sequence = action_sequence.squeeze(-1)
            target_mask = action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, Any] = {}
        if action_sequence is not None:
            # action_sequence is unsqueezed to shape (batch_size, 1, sequence_length)
            # below, because the MML trainer expects a group dimension.
            loss_output = self._decoder_trainer.decode(initial_state,
                                                       self._transition_function,
                                                       (action_sequence.unsqueeze(1),
                                                        target_mask.unsqueeze(1)))
            outputs.update(loss_output)

        if not self.training:
            action_mapping = []
            for batch_actions in valid_actions:
                batch_action_mapping = {}
                for action_index, action in enumerate(batch_actions):
                    batch_action_mapping[action_index] = action[0]
                action_mapping.append(batch_action_mapping)

            outputs['action_mapping'] = action_mapping
            # This tells the state to start keeping track of debug info, which we'll pass along in
            # our output dictionary.
            initial_state.debug_info = [[] for _ in range(batch_size)]
            best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                         initial_state,
                                                         self._transition_function,
                                                         keep_final_unfinished_states=True)
            outputs['best_action_sequence'] = []
            outputs['debug_info'] = []
            outputs['predicted_sql_query'] = []
            outputs['sql_queries'] = []
            for i in range(batch_size):
                # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
                # isn't long enough (or if the model is not trained enough and gets into an
                # infinite action loop).
                if i not in best_final_states:
                    self._exact_match(0)
                    self._denotation_accuracy(0)
                    self._valid_sql_query(0)
                    self._action_similarity(0)
                    outputs['predicted_sql_query'].append('')
                    continue

                best_action_indices = best_final_states[i][0].action_history[0]

                action_strings = [action_mapping[i][action_index]
                                  for action_index in best_action_indices]

                predicted_sql_query = action_sequence_to_sql(action_strings)
                if action_sequence is not None:
                    # Use a Tensor, not a Variable, to avoid a memory leak.
                    targets = action_sequence[i].data
                    sequence_in_targets = self._action_history_match(best_action_indices, targets)
                    self._exact_match(sequence_in_targets)

                    similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                    self._action_similarity(similarity.ratio())

                outputs['best_action_sequence'].append(action_strings)
                outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
                outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
        return outputs
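The action-similarity metric above is just ``difflib.SequenceMatcher.ratio()`` over the predicted and gold action indices; a tiny standalone example of what that ratio measures (2*M/T, where M is the number of matching elements and T the combined length of both sequences):

import difflib

predicted_actions = [3, 7, 7, 12, 5]
gold_actions = [3, 7, 12, 5]
ratio = difflib.SequenceMatcher(None, predicted_actions, gold_actions).ratio()
print(round(ratio, 3))  # 0.889 == 2 * 4 / (5 + 4)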
    def _get_initial_rnn_and_grammar_state(self,
                                           question: Dict[str, torch.LongTensor],
                                           table: Dict[str, torch.LongTensor],
                                           world: List[WikiTablesWorld],
                                           actions: List[List[ProductionRule]],
                                           outputs: Dict[str, Any]) -> Tuple[List[RnnStatelet],
                                                                             List[LambdaGrammarStatelet]]:
        """
        Encodes the question and table, computes a linking between the two, and constructs an
        initial RnnStatelet and LambdaGrammarStatelet for each batch instance to pass to the
        decoder.

        We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
        visualize in a demo.
        """
        table_text = table['text']
        # (batch_size, question_length, embedding_dim)
        embedded_question = self._question_embedder(question)
        question_mask = util.get_text_field_mask(question).float()
        # (batch_size, num_entities, num_entity_tokens, embedding_dim)
        embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

        batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
        num_question_tokens = embedded_question.size(1)

        # (batch_size, num_entities, embedding_dim)
        encoded_table = self._entity_encoder(embedded_table, table_mask)
        # (batch_size, num_entities, num_neighbors)
        neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

        # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
        # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
        # be added for the mask since that method expects 0 for padding.
        # (batch_size, num_entities, num_neighbors, embedding_dim)
        embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

        neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                                 num_wrapping_dims=1).float()

        # Encoder initialized to easily obtain a masked average.
        neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
        # (batch_size, num_entities, embedding_dim)
        embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

        # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
        # entity's type id.
        # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
        # These encode the same information, but for efficiency reasons later it's nice
        # to have one version as a tensor and one that's accessible on the cpu.
        entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

        entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)
        projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
        # (batch_size, num_entities, embedding_dim)
        entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)


        # Compute entity and question word similarity.  We tried using cosine distance here, but
        # because this similarity is the main mechanism that the model can use to push apart logit
        # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
        # output range than [-1, 1].
        question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                                   num_entities * num_entity_tokens,
                                                                   self._embedding_dim),
                                               torch.transpose(embedded_question, 1, 2))

        question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                     num_entities,
                                                                     num_entity_tokens,
                                                                     num_question_tokens)

        # (batch_size, num_entities, num_question_tokens)
        question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

        # (batch_size, num_entities, num_question_tokens, num_features)
        linking_features = table['linking']

        linking_scores = question_entity_similarity_max_score

        if self._use_neighbor_similarity_for_linking:
            # The linking score is computed as a linear projection of two terms. The first is the
            # maximum similarity score over the entity's words and the question token. The second
            # is the maximum similarity over the words in the entity's neighbors and the question
            # token.
            #
            # The second term, projected_question_neighbor_similarity, is useful when a column
            # needs to be selected. For example, the question token might have no similarity with
            # the column name, but is similar with the cells in the column.
            #
            # Note that projected_question_neighbor_similarity is intended to capture the same
            # information as the related_column feature.
            #
            # Also note that this block needs to be _before_ the `linking_params` block, because
            # we're overwriting `linking_scores`, not adding to it.

            # (batch_size, num_entities, num_neighbors, num_question_tokens)
            question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                     torch.abs(neighbor_indices))
            # (batch_size, num_entities, num_question_tokens)
            question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
            projected_question_entity_similarity = self._question_entity_params(
                    question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            projected_question_neighbor_similarity = self._question_neighbor_params(
                    question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
            linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

        feature_scores = None
        if self._linking_params is not None:
            feature_scores = self._linking_params(linking_features).squeeze(3)
            linking_scores = linking_scores + feature_scores

        # (batch_size, num_question_tokens, num_entities)
        linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                                question_mask, entity_type_dict)

        # (batch_size, num_question_tokens, embedding_dim)
        link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
        encoder_input = torch.cat([link_embedding, embedded_question], 2)

        # (batch_size, question_length, encoder_output_dim)
        encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

        # This will be our initial hidden state and memory cell for the decoder LSTM.
        final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                             question_mask,
                                                             self._encoder.is_bidirectional())
        memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

        # To make grouping states together in the decoder easier, we convert the batch dimension in
        # all of our tensors into an outer list.  For instance, the encoder outputs have shape
        # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
        # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
        # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
        encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
        question_mask_list = [question_mask[i] for i in range(batch_size)]
        initial_rnn_state = []
        for i in range(batch_size):
            initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                                 memory_cell[i],
                                                 self._first_action_embedding,
                                                 self._first_attended_question,
                                                 encoder_output_list,
                                                 question_mask_list))
        initial_grammar_state = [self._create_grammar_state(world[i],
                                                            actions[i],
                                                            linking_scores[i],
                                                            entity_types[i])
                                 for i in range(batch_size)]
        if not self.training:
            # We add a few things to the outputs that will be returned from `forward` at evaluation
            # time, for visualization in a demo.
            outputs['linking_scores'] = linking_scores
            if feature_scores is not None:
                outputs['feature_scores'] = feature_scores
            outputs['similarity_scores'] = question_entity_similarity_max_score
        return initial_rnn_state, initial_grammar_state
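The neighbour handling above relies on a small indexing trick: padded neighbour slots are -1, so the indices are made safe with ``torch.abs`` and the mask is built from ``indices + 1`` (0 meaning padding). A toy, self-contained illustration:

import torch

neighbor_indices = torch.tensor([[1, 2, -1],    # entity 0 has neighbours 1 and 2
                                 [0, -1, -1]])  # entity 1 has neighbour 0
neighbor_mask = (neighbor_indices + 1).clamp(max=1)  # 1 for real neighbours, 0 for padding
safe_indices = torch.abs(neighbor_indices)            # valid (but masked) indices for gather
print(neighbor_mask)   # tensor([[1, 1, 0], [1, 0, 0]])
print(safe_indices)    # tensor([[1, 2, 1], [0, 1, 1]])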
Example no. 46
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
               ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.

        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        # apply dropout for LSTM
        if self.rnn_input_dropout:
            embedded_premise = self.rnn_input_dropout(embedded_premise)
            embedded_hypothesis = self.rnn_input_dropout(embedded_hypothesis)

        # encode premise and hypothesis
        encoded_premise = self._encoder(embedded_premise, premise_mask)
        encoded_hypothesis = self._encoder(embedded_hypothesis, hypothesis_mask)

        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(encoded_premise, encoded_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = last_dim_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(encoded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = last_dim_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(encoded_premise, h2p_attention)

        # the "enhancement" layer
        premise_enhanced = torch.cat(
                [encoded_premise, attended_hypothesis,
                 encoded_premise - attended_hypothesis,
                 encoded_premise * attended_hypothesis],
                dim=-1
        )
        hypothesis_enhanced = torch.cat(
                [encoded_hypothesis, attended_premise,
                 encoded_hypothesis - attended_premise,
                 encoded_hypothesis * attended_premise],
                dim=-1
        )

        # The projection layer down to the model dimension.  Dropout is not applied before
        # projection.
        projected_enhanced_premise = self._projection_feedforward(premise_enhanced)
        projected_enhanced_hypothesis = self._projection_feedforward(hypothesis_enhanced)

        # Run the inference layer
        if self.rnn_input_dropout:
            projected_enhanced_premise = self.rnn_input_dropout(projected_enhanced_premise)
            projected_enhanced_hypothesis = self.rnn_input_dropout(projected_enhanced_hypothesis)
        v_ai = self._inference_encoder(projected_enhanced_premise, premise_mask)
        v_bi = self._inference_encoder(projected_enhanced_hypothesis, hypothesis_mask)

        # The pooling layer -- max and avg pooling.
        # (batch_size, model_dim)
        v_a_max, _ = replace_masked_values(
                v_ai, premise_mask.unsqueeze(-1), -1e7
        ).max(dim=1)
        v_b_max, _ = replace_masked_values(
                v_bi, hypothesis_mask.unsqueeze(-1), -1e7
        ).max(dim=1)

        v_a_avg = torch.sum(v_ai * premise_mask.unsqueeze(-1), dim=1) / torch.sum(
                premise_mask, 1, keepdim=True
        )
        v_b_avg = torch.sum(v_bi * hypothesis_mask.unsqueeze(-1), dim=1) / torch.sum(
                hypothesis_mask, 1, keepdim=True
        )

        # Now concat
        # (batch_size, model_dim * 2 * 4)
        v_all = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        # the final MLP -- apply dropout to input, and MLP applies to output & hidden
        if self.dropout:
            v_all = self.dropout(v_all)

        output_hidden = self._output_feedforward(v_all)
        label_logits = self._output_logit(output_hidden)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits, "label_probs": label_probs}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        return output_dict
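The pooling step above is the usual masked max/mean pattern; a compact standalone sketch in plain PyTorch (the ``-1e7`` fill mirrors ``replace_masked_values``):

import torch

x = torch.randn(2, 4, 3)                       # (batch, seq_len, model_dim)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]]).float()    # (batch, seq_len)

masked_x = x.masked_fill(mask.unsqueeze(-1) == 0, -1e7)
v_max, _ = masked_x.max(dim=1)                                                # masked max pooling
v_avg = (x * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)   # masked mean pooling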
Example no. 47
    def forward(self,
                text,
                spans,
                ner_labels,
                coref_labels,
                relation_labels,
                trigger_labels,
                argument_labels,
                metadata,
                span_labels):

        # Shape: (batch_size, max_sentence_length, bert_size)
        text_embeddings = self._text_field_embedder(text)

        text_embeddings = self._lexical_dropout(text_embeddings)

        # Shape: (batch_size, max_sentence_length)
        text_mask = util.get_text_field_mask(text).float()

        # Initialise sentence_lengths as a zero LongTensor on the right device,
        # then fill in the true sentence lengths from the metadata offsets.
        sentence_lengths = 0 * text_mask.sum(dim=1).long()
        for i in range(len(metadata)):
            sentence_lengths[i] = metadata[i]["end_ix"] - metadata[i]["start_ix"]

        # Shape: (batch_size, max_sentence_length, encoding_dim)
        contextualized_embeddings = self._lstm_dropout(self._context_layer(text_embeddings, text_mask))
        assert spans.max() < contextualized_embeddings.shape[1]

        span_mask = (spans[:, :, 0] >= 0).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0, so the relu below clamps the -1 padding to 0.
        # This is only relevant in edge cases where the number of spans we
        # consider after the pruning stage is >= the total number of spans,
        # because in this case, it is possible we might consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans, text_mask)

        # Make calls out to the modules to get results.
        output_ner = {'loss': 0}
        output_span = {'loss': 0}

        # Make predictions and compute losses for each module
        if self._loss_weights['span'] > 0:
            output_span = self._span(spans, span_mask, span_embeddings, sentence_lengths, span_labels, metadata)

        if self._loss_weights['ner'] > 0:
            output_ner = self._ner(
                spans, span_mask, span_embeddings, sentence_lengths, ner_labels, metadata, output_span)

        loss = (
                self._loss_weights['ner'] * output_ner['loss'] +
                self._loss_weights['span'] * output_span['loss']
                )

        output_dict = dict(ner=output_ner, span=output_span)
        output_dict['loss'] = loss

        return output_dict
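The span-mask and clamping logic above can be seen in isolation with a toy batch: padded spans come in as (-1, -1), the mask keys off the start index, and ``F.relu`` turns the padding into a valid (but masked) index:

import torch
import torch.nn.functional as F

spans = torch.tensor([[[0, 2], [3, 3], [-1, -1]]])   # (batch=1, num_spans=3, 2)
span_mask = (spans[:, :, 0] >= 0).float()            # tensor([[1., 1., 0.]])
safe_spans = F.relu(spans.float()).long()            # the padded span becomes (0, 0)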
Example no. 48
    def forward(
            self,
            sentence: Dict[str, torch.Tensor],
            labels: torch.Tensor = None,
            labels_aspect: torch.Tensor = None,
            domain: torch.Tensor = None,
            sample_weight: torch.Tensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        pos_weight = None
        if 'label_prior' in metadata[0]:
            # prior on labels
            pos_weight = torch.from_numpy(metadata[0]['label_prior']).float()

        # Initialise outside the ``if`` so the list is always defined when
        # ``labels_aspect`` is provided below.
        pos_weight_aspect = [None for _ in range(2)]
        if 'label_prior_aspect' in metadata[0]:
            # prior on aspect labels, one entry per domain
            for dom_idx in range(2):
                try:
                    idx = int((domain == dom_idx).nonzero()[0])
                except IndexError:
                    continue
                w = metadata[idx]['label_prior_aspect']
                pos_weight_aspect[dom_idx] = torch.from_numpy(w).float()

        if labels is not None:
            labels = labels[:, :NUM_MOTIVES]  # Drop the ``NULL'' column

        embeddings = self.word_embedder(sentence)

        mask = get_text_field_mask(sentence)
        encoder_out = self.encoder(embeddings, mask)

        label_logits = self.classifier_feedforward(encoder_out)
        output = {'label_logits': label_logits}

        if labels is not None:
            label_ids = (label_logits.sign().long() + 1) / 2
            self._macro_f1(label_ids, labels)
            output['loss'] = self._loss_func(label_logits,
                                             labels.float(),
                                             weight=sample_weight,
                                             pos_weight=pos_weight)

        if labels_aspect is not None:
            components = [(self.classifier_aspect_res, self._f1_aspect_res),
                          (self.classifier_aspect_lap, self._f1_aspect_lap)]
            for dom_idx, (clf, f1) in enumerate(components):
                h = encoder_out[domain == dom_idx]  # encoded sentences
                if h.shape[0] == 0:  # no instance
                    continue
                aspect_logits = clf(h)
                aspects_pred = (aspect_logits.sign().long() + 1) / 2
                aspects = labels_aspect[domain ==
                                        dom_idx][:, :aspect_logits.shape[1]]
                f1(aspects_pred, aspects)
                output['loss'] += self._loss_func(
                    aspect_logits,
                    aspects.float(),
                    weight=sample_weight[domain == dom_idx],
                    pos_weight=pos_weight_aspect[dom_idx])

        return output
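The ``pos_weight`` passed to the loss above acts as a per-label prior. A minimal sketch of the mechanism, assuming the loss function is ``binary_cross_entropy_with_logits`` (the values below are illustrative only):

import torch
import torch.nn.functional as F

logits = torch.tensor([[0.5, -1.0, 2.0]])        # (batch, num_labels)
targets = torch.tensor([[1.0, 0.0, 1.0]])
pos_weight = torch.tensor([2.0, 1.0, 0.5])       # per-label prior, e.g. derived from metadata

# pos_weight multiplies the positive term of each label column.
loss = F.binary_cross_entropy_with_logits(logits, targets, pos_weight=pos_weight)
print(loss.item())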
Example no. 49
    def forward(
        self,
        messages: Dict[str, torch.Tensor],
        # (batch_size, n_turns, n_facts, n_words)
        facts: Dict[str, torch.Tensor],
        # (batch_size, n_turns)
        senders: torch.Tensor,
        # (batch_size, n_turns, n_acts)
        dialog_acts: torch.Tensor,
        # (batch_size, n_turns)
        dialog_acts_mask: torch.Tensor,
        # (batch_size, n_entities)
        known_entities: Dict[str, torch.Tensor],
        # (batch_size, 1)
        focus_entity: Dict[str, torch.Tensor],
        # (batch_size, n_turns, n_facts)
        fact_labels: Optional[torch.Tensor] = None,
        # (batch_size, n_turns, 2)
        likes: Optional[torch.Tensor] = None,
        metadata: Optional[Dict] = None,
    ):
        output = {}
        # Take care of the easy stuff first

        # (batch_size, n_entities)
        known_entities_mask = get_text_field_mask(known_entities)

        # (batch_size, n_turns, sender_emb_size)
        sender_emb = self._sender_emb(senders)

        known_emb = self._mention_embedder(known_entities)
        # TODO: This could instead of averaged, be attended
        known_vec = self._known_net(
            masked_mean(known_emb, known_entities_mask.unsqueeze(-1), dim=1))
        # There is always exactly one entity
        focus_emb = self._focus_net(
            self._mention_embedder(focus_entity)[:, 0, :])

        if self._use_bert:
            # (batch_size, n_turns, n_words, emb_dim)
            context, utter_mask = self._bert_encoder(messages)
            context = self._dropout(context)
        else:
            # (batch_size, n_turns)
            # This is the mask since not all dialogs have same number
            # of turns
            utter_mask = get_text_field_mask(messages)

            # (batch_size, n_turns, n_words)
            # Mask since not all utterances have same number of words
            # Wrapping dim skips over n_messages dim
            text_mask = get_text_field_mask(messages, num_wrapping_dims=1)
            # (batch_size, n_turns, n_words, emb_dim)
            embed = self._dropout(self._utter_embedder(messages))
            # (batch_size, n_turns, hidden_dim)
            context = self._dist_utter_context(embed, text_mask)

        # (batch_size, n_turns, act_emb_size)
        act_emb = self._act_embedder(dialog_acts.float())
        act_emb = self._clamp_dialog_acts(act_emb)

        # (batch_size, n_turns, hidden_dim + known_dim + focus_dim + sender_dim + act_dim)
        n_turns = context.shape[1]
        full_context = torch.cat(
            (
                context,
                sender_emb,
                act_emb,
                focus_emb[:, None, :].repeat_interleave(n_turns, 1),
                known_vec[:, None, :].repeat_interleave(n_turns, 1),
            ),
            dim=-1,
        )

        # (batch_size, n_turns, hidden_dim)
        # This assumes dialog_context does not peek into future
        dialog_context = self._dialog_context(full_context, utter_mask)

        # shift context one right, pad with zeros at front
        # This makes it so that utter_t is paired with context_t-1
        # which is what we want
        # This is useful in a few different places, so compute it here once
        shape = dialog_context.shape
        shifted_context = torch.cat(
            (
                dialog_context.new_zeros([shape[0], 1, shape[2]]),
                dialog_context[:, :-1, :],
            ),
            dim=1,
        )
        has_loss = False

        if self._disable_dialog_acts:
            da_loss = 0
            policy_loss = 0
        else:
            # Dialog act per utter loss
            has_loss = True
            da_loss = self._compute_da_loss(
                output,
                context,
                shifted_context,
                utter_mask,
                dialog_acts,
                dialog_acts_mask,
            )
            # Policy loss
            policy_loss = self._compute_policy_loss(output, shifted_context,
                                                    utter_mask, dialog_acts,
                                                    dialog_acts_mask)

        if self._disable_facts:
            # If facts are disabled, don't output anything related
            # to them
            fact_loss = 0
        else:
            if self._use_bert:
                # (batch_size, n_turns, n_words, emb_dim)
                fact_repr, fact_mask = self._bert_encoder(facts)
                fact_repr = self._dropout(fact_repr)
                fact_mask[:, ::2] = 0
            else:
                # (batch_size, n_turns, n_facts)
                # Wrapping dim skips over n_messages
                fact_mask = get_text_field_mask(facts, num_wrapping_dims=1)
                # In addition to masking padded facts, also explicitly mask
                # user turns just in case
                fact_mask[:, ::2] = 0

                # (batch_size, n_turns, n_facts, n_words)
                # Wrapping dim skips over n_turns and n_facts
                fact_text_mask = get_text_field_mask(facts,
                                                     num_wrapping_dims=2)
                # (batch_size, n_turns, n_facts, n_words, emb_dim)
                # Share encoder with utter encoder
                # Again, stupid dimensions
                fact_embed = self._dropout(self._utter_embedder(facts))
                shape = fact_embed.shape
                word_dim = shape[-2]
                emb_dim = shape[-1]
                reshaped_facts = fact_embed.view(-1, word_dim, emb_dim)
                reshaped_fact_text_mask = fact_text_mask.view(-1, word_dim)
                reshaped_fact_repr = self._utter_context(
                    reshaped_facts, reshaped_fact_text_mask)
                # No more emb dimension or word/seq dim
                fact_repr = reshaped_fact_repr.view(shape[:-2] + (-1, ))

            fact_logits = self._fact_ranker(
                shifted_context,
                fact_repr,
            )
            output["fact_logits"] = fact_logits
            if fact_labels is not None:
                has_loss = True
                fact_loss = self._compute_fact_loss(fact_logits, fact_labels,
                                                    fact_mask)
                self._fact_loss_metric(fact_loss.item())
                self._fact_mrr(fact_logits, fact_labels, mask=fact_mask)
            else:
                fact_loss = 0

        if self._disable_likes:
            like_loss = 0
        else:
            has_loss = True
            # (batch_size, n_turns, 2)
            like_logits = self._like_classifier(dialog_context)
            output["like_logits"] = like_logits

            # There are several masks here to get the loss/metrics correct
            # - utter_mask: mask out positions that do not have an utterance
            # - user_mask: mask out positions that have a user utterances
            #              since their turns are never liked
            # Using new_ones() preserves the type of the tensor
            user_mask = utter_mask.new_ones(utter_mask.shape)

            # Since the user is always even, this masks out user positions
            user_mask[:, ::2] = 0
            final_mask = utter_mask * user_mask
            if likes is not None:
                masked_likes = likes * final_mask
                has_loss = True
                like_loss = sequence_cross_entropy_with_logits(
                    like_logits, masked_likes, final_mask)
                self._like_accuracy(like_logits, masked_likes, final_mask)
                self._like_loss_metric(like_loss.item())
            else:
                like_loss = 0

        if has_loss:
            output["loss"] = (self._fact_loss_weight * fact_loss + like_loss +
                              da_loss + policy_loss)

        return output
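The shifted-context construction above is worth seeing on its own: padding one zero step at the front and dropping the last step pairs utterance ``t`` with the dialog context up to ``t - 1``. A standalone check:

import torch

dialog_context = torch.arange(2 * 3 * 4, dtype=torch.float).view(2, 3, 4)  # (batch, n_turns, dim)
shifted_context = torch.cat(
    (dialog_context.new_zeros(2, 1, 4), dialog_context[:, :-1, :]),
    dim=1,
)
assert torch.equal(shifted_context[:, 0], torch.zeros(2, 4))
assert torch.equal(shifted_context[:, 1:], dialog_context[:, :-1])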
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                spans: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                pos_tags: Dict[str, torch.LongTensor] = None,
                span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        spans : ``torch.LongTensor``, required.
            A tensor of shape ``(batch_size, num_spans, 2)`` representing the
            inclusive start and end indices of all possible spans in the sentence.
        metadata : List[Dict[str, Any]], required.
            A dictionary of metadata for each batch element which has keys:
                tokens : ``List[str]``, required.
                    The original string tokens in the sentence.
                gold_tree : ``nltk.Tree``, optional (default = None)
                    Gold NLTK trees for use in evaluation.
                pos_tags : ``List[str]``, optional.
                    The POS tags for the sentence. These can be used in the
                    model as embedded features, but they are passed here
                    in addition for use in constructing the tree.
        pos_tags : ``torch.LongTensor``, optional (default = None)
            The output of a ``SequenceLabelField`` containing POS tags.
        span_labels : ``torch.LongTensor``, optional (default = None)
            A torch tensor representing the integer gold class labels for all possible
            spans, of shape ``(batch_size, num_spans)``.

        Returns
        -------
        An output dictionary consisting of:
        class_probabilities : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_spans, span_label_vocab_size)``
            representing a distribution over the label classes per span.
        spans : ``torch.LongTensor``
            The original spans tensor.
        tokens : ``List[List[str]]``, required.
            A list of tokens in the sentence for each element in the batch.
        pos_tags : ``List[List[str]]``, required.
            A list of POS tags in the sentence for each element in the batch.
        num_spans : ``torch.LongTensor``, required.
            A tensor of shape (batch_size), representing the lengths of non-padded spans
            in ``enumerated_spans``.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        if pos_tags is not None and self.pos_tag_embedding is not None:
            embedded_pos_tags = self.pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self.pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(tokens)
        # Looking at the span start index is enough to know if
        # this is padding or not. Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
        if span_mask.dim() == 1:
            # This happens if you use batch_size 1 and encounter
            # a length 1 sentence in PTB, which do exist. -.-
            span_mask = span_mask.unsqueeze(-1)
        if span_labels is not None and span_labels.dim() == 1:
            span_labels = span_labels.unsqueeze(-1)

        num_spans = get_lengths_from_binary_sequence_mask(span_mask)

        encoded_text = self.encoder(embedded_text_input, mask)
        span_representations = self.span_extractor(encoded_text, spans, mask, span_mask)
        if self.feedforward_layer is not None:
            span_representations = self.feedforward_layer(span_representations)
        logits = self.tag_projection_layer(span_representations)
        class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

        output_dict = {
                "class_probabilities": class_probabilities,
                "spans": spans,
                "tokens": [meta["tokens"] for meta in metadata],
                "pos_tags": [meta.get("pos_tags") for meta in metadata],
                "num_spans": num_spans
        }
        if span_labels is not None:
            loss = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
            self.tag_accuracy(class_probabilities, span_labels, span_mask)
            output_dict["loss"] = loss

        # The evalb score is expensive to compute, so we only compute
        # it for the validation and test sets.
        batch_gold_trees = [meta.get("gold_tree") for meta in metadata]
        if all(batch_gold_trees) and self._evalb_score is not None and not self.training:
            gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1]
                                              for tree in batch_gold_trees]
            predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                                   spans.cpu().data,
                                                   num_spans.data,
                                                   output_dict["tokens"],
                                                   gold_pos_tags)
            self._evalb_score(predicted_trees, batch_gold_trees)

        return output_dict
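``get_lengths_from_binary_sequence_mask`` used above just counts the unmasked entries per batch element; a two-line equivalent for the span mask:

import torch

span_mask = torch.tensor([[1, 1, 1, 0, 0],
                          [1, 1, 0, 0, 0]])
num_spans = span_mask.sum(-1)   # tensor([3, 2]): number of non-padded spans per instance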
Example no. 51
    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: torch.Tensor = None,
                task_id=None) -> Dict[str, torch.Tensor]:
        if (task_id is None):
            task_id = self.get_current_taskid()
        hidden2tag = self.classification_layers[task_id]
        if torch.cuda.is_available():
            tokens = move_to_device(tokens, torch.cuda.current_device())
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)

        if self.args.position_embed:
            embeddings = self.pos_embedding(embeddings)

        # Task embedder: append the task id as an extra feature channel
        if self.task_embedder:
            bs, seq, _ = embeddings.shape
            task_em = torch.full((bs, seq, 1), float(self.get_current_taskid()))
            if torch.cuda.is_available():
                task_em = move_to_device(task_em, torch.cuda.current_device())
            embeddings = torch.cat([embeddings, task_em], dim=-1)

        # Task encoding adds the id using transformer style.
        if self.task_encoder:
            embeddings = self.task_encoder(embeddings,
                                           self.get_current_taskid())

        # Task projection adds the id using single linear layer projection from one hot encoding
        if self.task_projection:
            embeddings = self.task_projection(embeddings,
                                              self.get_current_taskid())

        # Increase temperature at embedding layer
        if (self.args.all_temp or self.args.emb_temp) and self.inv_temp:
            embeddings = self.inv_temp * embeddings

        if type(self.encoder) == HashedMemoryRNN:
            output = self.encoder(embeddings, mask, mem_tokens=tokens)
        else:
            output = self.encoder(embeddings, mask)
        if type(output) == tuple:
            encoder_out, activations = output
            activations = encoder_out
        else:
            encoder_out = output
            activations = output
        self.activations = activations
        self.labels = label

        # Increase temperature at encoder layer
        if (self.args.all_temp or self.args.enc_temp) and self.inv_temp:
            encoder_out = self.inv_temp * encoder_out

        if self.use_task_memory:
            encoder_out = self.task_memory(encoder_out, self.current_task)
        tag_logits = hidden2tag(encoder_out)

        # Increase temperature softmax layer
        if (self.args.all_temp or self.args.softmax_temp) and self.inv_temp:
            tag_logits = self.inv_temp * tag_logits

        output = {'logits': tag_logits, 'encoder_output': encoder_out}
        if label is not None:
            _, preds = tag_logits.max(dim=1)
            self.average(
                matthews_corrcoef(label.data.cpu().numpy(),
                                  preds.data.cpu().numpy()))
            self.micro_avg(
                f1_score(label.data.cpu().numpy(),
                         preds.data.cpu().numpy(),
                         average='micro'))
            self.accuracy(tag_logits, label)
            output["loss"] = self.loss_function(tag_logits, label)
            if self.use_task_memory:
                output["loss"] += self.task_memory.get_memory_loss(
                    self.current_task)
            if (self.args.ewc or self.args.oewc) and self.training:
                output[
                    "ewc_loss"] = self.args.ewc_importance * self.ewc.penalty(
                        self.get_current_taskid())
                output["loss"] += output["ewc_loss"]
                output["loss"].backward(retain_graph=True)
                if self._len_dataset:
                    self.ewc.update_penalty(self.task2id[self.current_task],
                                            self, self._len_dataset)
        return output
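The per-batch metrics above go straight into scikit-learn; a tiny standalone example of the two calls used (assuming scikit-learn is installed, with toy labels):

from sklearn.metrics import f1_score, matthews_corrcoef

y_true = [0, 1, 1, 0, 1]
y_pred = [0, 1, 0, 0, 1]
print(matthews_corrcoef(y_true, y_pred))           # ~0.67
print(f1_score(y_true, y_pred, average="micro"))   # 0.8 (equals accuracy here)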
    def forward(
            self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            logical_forms: Dict[str, torch.LongTensor],
            utterance_string: List[str],
            logical_form_strings: List[List[str]]) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------

        Returns
        -------

        """
        # (batch_size, num_utterance_tokens, utterance_embedding_dim)
        embedded_utterance = self.utterance_embedder(utterance)
        utterance_mask = util.get_text_field_mask(utterance)
        encoded_utterance = self.utterance_encoder(embedded_utterance,
                                                   utterance_mask)
        # Because we're just summing everything in the end, we can do the sum upfront to save some
        # time.
        # (batch_size, utterance_embedding_dim)
        encoded_utterance = encoded_utterance.sum(dim=1)
        # (batch_size, num_logical_forms, num_lf_tokens, lf_embedding_dim)
        embedded_logical_forms = self.logical_form_embedder(
            logical_forms, num_wrapping_dims=1)
        # (batch_size, num_logical_forms, num_lf_tokens)
        logical_form_token_mask = util.get_text_field_mask(logical_forms,
                                                           num_wrapping_dims=1)
        # (batch_size, num_logical_forms)
        logical_form_mask = logical_form_token_mask.sum(dim=-1).clamp(max=1)
        # (batch_size, num_logical_forms, lf_embedding_dim)
        encoded_logical_forms = embedded_logical_forms.sum(dim=2)
        # (batch_size, num_logical_forms, utterance_embedding_dim)
        predicted_embeddings = self.translation_layer(encoded_logical_forms)

        # (batch_size, num_logical_forms)
        similarities = torch.nn.functional.cosine_similarity(
            predicted_embeddings, encoded_utterance.unsqueeze(1), dim=2)
        # to avoid division by zero, add a 1
        logical_form_lens = 1.0 + torch.sum(
            logical_form_token_mask, dim=-1, dtype=similarities.dtype)
        if self.normalize_by_len:
            similarities = similarities / logical_form_lens
            # Make sure masked logical forms aren't included in the max.
        similarities = util.replace_masked_values(similarities,
                                                  logical_form_mask, -1e7)

        ranks = (similarities[:, 0].unsqueeze(1) < similarities)
        curr_ranks = ranks.sum(dim=-1)  # (batch_size,) rank of the gold logical form
        hits = [(curr_ranks < k).sum().cpu().data.numpy() for k in [3, 5, 10]]
        self.hits3 += hits[0]
        self.hits5 += hits[1]
        self.hits10 += hits[2]
        self.mean_ranks += curr_ranks.sum(dim=0).cpu().data.numpy()
        self.batches += ranks.shape[0]

        max_similarity, most_similar = similarities.max(dim=-1)
        loss = (1 - max_similarity).sum()

        self.accuracy += (most_similar == 0).sum().cpu().data.numpy()

        most_similar_strings = []
        for instance_most_similar, instance_logical_forms in zip(
                most_similar.tolist(), logical_form_strings):
            most_similar_strings.append(
                instance_logical_forms[instance_most_similar])
        return {
            "loss": loss,
            "most_similar": most_similar_strings,
            "utterance": utterance_string,
            "all_similarities": similarities
        }
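The rank and hits@k bookkeeping above assumes the gold logical form sits at index 0; its rank is simply the number of candidates that score strictly higher. A small standalone check:

import torch

similarities = torch.tensor([[0.90, 0.20, 0.95, 0.10],
                             [0.40, 0.30, 0.20, 0.10]])
ranks = (similarities[:, 0].unsqueeze(1) < similarities).sum(dim=-1)
print(ranks)                       # tensor([1, 0]): 0 means the gold form is ranked first
print((ranks < 3).sum().item())    # hits@3 for this toy batch: 2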
Example no. 53
    def forward(
            self,  # type: ignore
            tokens: TextFieldTensors,
            tags: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : ``TextFieldTensors``, required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is : ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : ``torch.LongTensor``, optional (default = ``None``)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence to be tagged under a 'words' key.

        # Returns

        An output dictionary consisting of:

        logits : ``torch.FloatTensor``
            The logits that are the output of the ``tag_projection_layer``
        mask : ``torch.LongTensor``
            The text field mask for the input tokens
        tags : ``List[List[int]]``
            The predicted tags using the Viterbi algorithm.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        mask = util.get_text_field_mask(tokens)

        if self.dropout:
            embedded_text_input = self.dropout(embedded_text_input)

        encoded_text = self.encoder(embedded_text_input, mask)

        if self.dropout:
            encoded_text = self.dropout(encoded_text)

        if self._feedforward is not None:
            encoded_text = self._feedforward(encoded_text)

        logits = self.tag_projection_layer(encoded_text)
        best_paths = self.crf.viterbi_tags(logits, mask, top_k=self.top_k)

        # Just get the top tags and ignore the scores.
        predicted_tags = cast(List[List[int]], [x[0][0] for x in best_paths])

        output = {"logits": logits, "mask": mask, "tags": predicted_tags}

        if self.top_k > 1:
            output["top_k_tags"] = best_paths

        if tags is not None:
            # Add negative log-likelihood as loss
            log_likelihood = self.crf(logits, tags, mask)

            output["loss"] = -log_likelihood

            # Represent viterbi tags as "class probabilities" that we can
            # feed into the metrics
            class_probabilities = logits * 0.0
            for i, instance_tags in enumerate(predicted_tags):
                for j, tag_id in enumerate(instance_tags):
                    class_probabilities[i, j, tag_id] = 1

            for metric in self.metrics.values():
                metric(class_probabilities, tags, mask.float())
            if self.calculate_span_f1:
                self._f1_metric(class_probabilities, tags, mask.float())
        if metadata is not None:
            output["words"] = [x["words"] for x in metadata]
        return output
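The loop above turns Viterbi tag ids into one-hot "class probabilities" for the metrics. When the predictions are already padded into a tensor, ``scatter_`` does the same thing without Python loops (a sketch, not a drop-in replacement for the ragged-list case):

import torch

logits = torch.zeros(2, 3, 4)                    # (batch, seq_len, num_tags)
predicted_tags = torch.tensor([[0, 2, 1],        # already padded to seq_len
                               [3, 3, 0]])
class_probabilities = torch.zeros_like(logits).scatter_(-1, predicted_tags.unsqueeze(-1), 1.0)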
Example no. 54
    def forward(self,  # type: ignore
                premise: Dict[str, torch.LongTensor],
                hypothesis: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        premise : Dict[str, torch.LongTensor]
            From a ``TextField``
        hypothesis : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional, (default = None)
            From a ``LabelField``
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original tokenization of the premise and
            hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
        Returns
        -------
        An output dictionary consisting of:

        label_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log
            probabilities of the entailment label.
        label_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the
            entailment label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_premise = self._text_field_embedder(premise)
        embedded_hypothesis = self._text_field_embedder(hypothesis)
        premise_mask = get_text_field_mask(premise).float()
        hypothesis_mask = get_text_field_mask(hypothesis).float()

        if self._premise_encoder:
            embedded_premise = self._premise_encoder(embedded_premise, premise_mask)
        if self._hypothesis_encoder:
            embedded_hypothesis = self._hypothesis_encoder(embedded_hypothesis, hypothesis_mask)

        projected_premise = self._attend_feedforward(embedded_premise)
        projected_hypothesis = self._attend_feedforward(embedded_hypothesis)
        # Shape: (batch_size, premise_length, hypothesis_length)
        similarity_matrix = self._matrix_attention(projected_premise, projected_hypothesis)

        # Shape: (batch_size, premise_length, hypothesis_length)
        p2h_attention = masked_softmax(similarity_matrix, hypothesis_mask)
        # Shape: (batch_size, premise_length, embedding_dim)
        attended_hypothesis = weighted_sum(embedded_hypothesis, p2h_attention)

        # Shape: (batch_size, hypothesis_length, premise_length)
        h2p_attention = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)
        # Shape: (batch_size, hypothesis_length, embedding_dim)
        attended_premise = weighted_sum(embedded_premise, h2p_attention)

        premise_compare_input = torch.cat([embedded_premise, attended_hypothesis], dim=-1)
        hypothesis_compare_input = torch.cat([embedded_hypothesis, attended_premise], dim=-1)

        compared_premise = self._compare_feedforward(premise_compare_input)
        compared_premise = compared_premise * premise_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_premise = compared_premise.sum(dim=1)

        compared_hypothesis = self._compare_feedforward(hypothesis_compare_input)
        compared_hypothesis = compared_hypothesis * hypothesis_mask.unsqueeze(-1)
        # Shape: (batch_size, compare_dim)
        compared_hypothesis = compared_hypothesis.sum(dim=1)

        aggregate_input = torch.cat([compared_premise, compared_hypothesis], dim=-1)
        label_logits = self._aggregate_feedforward(aggregate_input)
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        output_dict = {"label_logits": label_logits,
                       "label_probs": label_probs,
                       "h2p_attention": h2p_attention,
                       "p2h_attention": p2h_attention}

        if label is not None:
            loss = self._loss(label_logits, label.long().view(-1))
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["premise_tokens"] = [x["premise_tokens"] for x in metadata]
            output_dict["hypothesis_tokens"] = [x["hypothesis_tokens"] for x in metadata]

        return output_dict
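The attend step above is a masked softmax over one sentence followed by a weighted sum of its vectors; a self-contained sketch of the same computation in plain PyTorch:

import torch

similarity = torch.randn(2, 5, 4)                     # (batch, premise_len, hypothesis_len)
hypothesis = torch.randn(2, 4, 7)                     # (batch, hypothesis_len, embedding_dim)
hypothesis_mask = torch.tensor([[1, 1, 1, 0],
                                [1, 1, 0, 0]]).float()

masked = similarity.masked_fill(hypothesis_mask.unsqueeze(1) == 0, -1e32)
p2h_attention = torch.softmax(masked, dim=-1)               # each premise row sums to 1 over real tokens
attended_hypothesis = torch.bmm(p2h_attention, hypothesis)  # (batch, premise_len, embedding_dim)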
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            passage_sem_views_q: torch.IntTensor = None,
            passage_sem_views_k: torch.IntTensor = None,
            question_sem_views_q: torch.IntTensor = None,
            question_sem_views_k: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer with the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        passage_sem_views_q : ``torch.IntTensor``, optional
            Paragraph semantic views features for multihead attention Query (Q)
        passage_sem_views_k : ``torch.IntTensor``, optional
            Paragraph semantic views features for multihead attention Key (K)
        question_sem_views_q : ``torch.IntTensor``, optional
            Question semantic views features for multihead attention Query (Q)
        question_sem_views_k : ``torch.IntTensor``, optional
            Question semantic views features for multihead attention Key (K)

        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """

        return_output_metadata = self.return_output_metadata
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
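        # Illustrative example: for a padded batch of token ids like [[3, 4, 5, 0, 0], [1, 2, 0, 0, 0]]
        # (0 being the padding index), get_text_field_mask yields [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]].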

        if isinstance(self._phrase_layer, (QaNetSemanticFlatEncoder, QaNetSemanticFlatConcatEncoder)) \
                or isinstance(self._modeling_layer, (QaNetSemanticFlatEncoder, QaNetSemanticFlatConcatEncoder)):
            if passage_sem_views_q is not None:
                passage_sem_views_q = passage_sem_views_q.long()

            if passage_sem_views_k is not None:
                passage_sem_views_k = passage_sem_views_k.long()

            if question_sem_views_q is not None:
                question_sem_views_q = question_sem_views_q.long()

            if question_sem_views_k is not None:
                question_sem_views_k = question_sem_views_k.long()

        if torch.cuda.is_available():
            # indices
            question_mask = to_cuda(question_mask, move_to_cuda=True)
            passage_mask = to_cuda(passage_mask, move_to_cuda=True)

            question = {
                k: to_cuda(v, move_to_cuda=True)
                for k, v in question.items()
            }
            passage = {
                k: to_cuda(v, move_to_cuda=True)
                for k, v in passage.items()
            }

            # span
            if span_start is not None:
                span_start = to_cuda(span_start, move_to_cuda=True)

            if span_end is not None:
                span_end = to_cuda(span_end, move_to_cuda=True)

            # semantic views
            if passage_sem_views_q is not None:
                passage_sem_views_q = to_cuda(passage_sem_views_q,
                                              move_to_cuda=True)

            if passage_sem_views_k is not None:
                passage_sem_views_k = to_cuda(passage_sem_views_k,
                                              move_to_cuda=True)

            if question_sem_views_q is not None:
                question_sem_views_q = to_cuda(question_sem_views_q,
                                               move_to_cuda=True)

            if question_sem_views_k is not None:
                question_sem_views_k = to_cuda(question_sem_views_k,
                                               move_to_cuda=True)

        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(
            self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(
            self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(
            embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(
            embedded_passage)

        encoded_passage_output_metadata = None
        encoded_question_output_metadata = None
        if isinstance(self._phrase_layer,
                      (QaNetSemanticFlatEncoder, QaNetSemanticFlatConcatEncoder)):
            if is_output_meta_supported(self._phrase_layer):
                encoded_passage, encoded_passage_output_metadata = self._phrase_layer(
                    projected_embedded_passage, passage_sem_views_q,
                    passage_sem_views_k, passage_mask, return_output_metadata)
                encoded_passage = self._dropout(encoded_passage)

                encoded_question, encoded_question_output_metadata = self._phrase_layer(
                    projected_embedded_question, question_sem_views_q,
                    question_sem_views_k, question_mask,
                    return_output_metadata)
                encoded_question = self._dropout(encoded_question)
            else:
                encoded_passage = self._dropout(
                    self._phrase_layer(projected_embedded_passage,
                                       passage_sem_views_q,
                                       passage_sem_views_k, passage_mask))
                encoded_question = self._dropout(
                    self._phrase_layer(projected_embedded_question,
                                       question_sem_views_q,
                                       question_sem_views_k, question_mask))
        else:
            encoded_passage = self._dropout(
                self._phrase_layer(projected_embedded_passage, passage_mask))
            encoded_question = self._dropout(
                self._phrase_layer(projected_embedded_question, question_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True)

        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)

        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention,
                                             question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage,
                                                    attention_over_attention)
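        # In other words, the passage->question attention (batch, passage_len, question_len) is composed
        # with the question->passage attention (batch, question_len, passage_len) via bmm, giving a
        # passage->passage matrix that is then used to pool the encoded passage (coattention-style).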

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat([
                encoded_passage, passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * passage_passage_vectors
            ],
                      dim=-1))
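        # The concatenation follows the usual QANet/BiDAF-style fusion: the passage encoding, its
        # question-attended summary, and two element-wise interaction terms, hence encoding_dim * 4.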

        modeled_passage_list = [
            self._modeling_proj_layer(merged_passage_attention_vectors)
        ]
        modeled_passage_output_metadata_list = {}

        for modeling_layer_id in range(3):
            modeled_passage_output_metadata = None
            if isinstance(self._modeling_layer,
                          (QaNetSemanticFlatEncoder, QaNetSemanticFlatConcatEncoder)):
                if is_output_meta_supported(self._modeling_layer):
                    modeled_passage, modeled_passage_output_metadata = self._modeling_layer(
                        modeled_passage_list[-1], passage_sem_views_q,
                        passage_sem_views_k, passage_mask,
                        return_output_metadata)
                else:
                    modeled_passage = self._modeling_layer(
                        modeled_passage_list[-1], passage_sem_views_q,
                        passage_sem_views_k, passage_mask)

            else:
                modeled_passage = self._modeling_layer(
                    modeled_passage_list[-1], passage_mask)

            modeled_passage = self._dropout(modeled_passage)
            modeled_passage_list.append(modeled_passage)
            modeled_passage_output_metadata_list[
                "modeling_layer_iter_{0:03d}".format(
                    modeling_layer_id)] = modeled_passage_output_metadata
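        # The list now holds the projected input followed by three modeled passages M0, M1, M2;
        # as in QANet, the start predictor reads [M0; M1] and the end predictor reads [M0; M2].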

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_start_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)
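        # get_best_span performs a constrained argmax over the start/end logits: it picks the
        # (start, end) pair with the highest summed logits subject to start <= end.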

        output_dict = {
            #"passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            try:
                loss = nll_loss(
                    util.masked_log_softmax(span_start_logits, passage_mask),
                    span_start.squeeze(-1))
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(-1))
                loss += nll_loss(
                    util.masked_log_softmax(span_end_logits, passage_mask),
                    span_end.squeeze(-1))
                self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
                self._span_accuracy(best_span,
                                    torch.stack([span_start, span_end], -1))
                output_dict["loss"] = loss
            except Exception as e:
                logging.exception(e)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            metrics_per_item = None

            all_reference_answers_text = []
            all_best_spans = []

            return_metrics_per_item = True

            if not self.training:
                metrics_per_item = [{} for x in range(batch_size)]

            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())

                # offsets = metadata[i]['token_offsets']
                # start_offset = offsets[predicted_span[0]][0]
                # end_offset = offsets[predicted_span[1]][1]
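                # A character-offset based alternative (a sketch, assuming 'token_offsets' holds
                # (start_char, end_char) pairs as suggested by the commented lines above):
                #     start_offset = offsets[predicted_span[0]][0]
                #     end_offset = offsets[predicted_span[1]][1]
                #     best_span_string = passage_str[start_offset:end_offset]
                # Here we instead join the predicted passage tokens with spaces.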

                start_span = predicted_span[0]
                end_span = predicted_span[1]
                best_span_tokens = metadata[i]['passage_tokens'][
                    start_span:end_span + 1]
                best_span_string = " ".join(best_span_tokens)
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])

                all_best_spans.append(best_span_string)

                if answer_texts:
                    curr_item_em, curr_item_f1 = self._squad_metrics(
                        best_span_string, answer_texts, return_score=True)
                    if not self.training and return_metrics_per_item:
                        metrics_per_item[i]["em"] = curr_item_em
                        metrics_per_item[i]["f1"] = curr_item_f1

                    all_reference_answers_text.append(answer_texts)

            if return_output_metadata:
                output_dict["output_metadata"] = {
                    "encoded_passage": encoded_passage_output_metadata,
                    "encoded_question": encoded_question_output_metadata,
                    "modeling_layer": modeled_passage_output_metadata_list,
                }

            if not self.training and len(all_reference_answers_text) > 0:
                metrics_per_item_rouge = self.calculate_rouge(
                    all_best_spans,
                    all_reference_answers_text,
                    return_metrics_per_item=return_metrics_per_item)

                for i, curr_metrics in enumerate(metrics_per_item_rouge):
                    metrics_per_item[i].update(curr_metrics)

            if metrics_per_item is not None:
                output_dict['metrics'] = metrics_per_item

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens

        return output_dict
Esempio n. 56
0
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None,
                # pylint: disable=unused-argument
                **kwargs) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : ``torch.LongTensor``, optional (default = ``None``)
            A torch tensor representing the sequence of integer gold class labels of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:

        logits : ``torch.FloatTensor``
            The logits that are the output of the ``tag_projection_layer``
        mask : ``torch.LongTensor``
            The text field mask for the input tokens
        tags : ``List[List[int]]``
            The predicted tags using the Viterbi algorithm.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        mask = util.get_text_field_mask(tokens)

        if self.dropout:
            embedded_text_input = self.dropout(embedded_text_input)

        encoded_text = self.encoder(embedded_text_input, mask)

        if self.dropout:
            encoded_text = self.dropout(encoded_text)

        if self._feedforward is not None:
            encoded_text = self._feedforward(encoded_text)

        logits = self.tag_projection_layer(encoded_text)
        best_paths = self.crf.viterbi_tags(logits, mask)

        # Just get the tags and ignore the score.
        predicted_tags = [x for x, y in best_paths]

        output = {"logits": logits, "mask": mask, "tags": predicted_tags}

        if tags is not None:
            # Add negative log-likelihood as loss
            log_likelihood = self.crf(logits, tags, mask)
            output["loss"] = -log_likelihood

            # Represent viterbi tags as "class probabilities" that we can
            # feed into the metrics
            class_probabilities = logits * 0.
            for i, instance_tags in enumerate(predicted_tags):
                for j, tag_id in enumerate(instance_tags):
                    class_probabilities[i, j, tag_id] = 1
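            # A vectorized sketch of the same one-hot construction (assuming the predicted tags were
            # first right-padded into a LongTensor ``tag_tensor`` of shape (batch_size, num_tokens)):
            #     class_probabilities.scatter_(2, tag_tensor.unsqueeze(-1), 1.0)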

            for metric in self.metrics.values():
                metric(class_probabilities, tags, mask.float())
            if self.calculate_span_f1:
                self._f1_metric(class_probabilities, tags, mask.float())
        if metadata is not None:
            output["words"] = [x["words"] for x in metadata]
        return output
    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> torch.Tensor:
        # In deep NLP, when sequences of different lengths are batched together,
        # shorter sequences are padded with zeros to make them all equal length.
        # Masking is how we tell the model to ignore those extra padded positions.
        text_mask = util.get_text_field_mask(tokens).float()
        # Forward pass
        embedded_text = self.text_field_embedder(tokens)
        dropped_embedded_text = self.embedding_dropout(embedded_text)
        encoded_tokens = self.encoder(dropped_embedded_text, text_mask)

        # Compute biattention. This is a special case since the inputs are the same.
        attention_logits = encoded_tokens.bmm(
            encoded_tokens.permute(0, 2, 1).contiguous())
        attention_weights = util.masked_softmax(attention_logits, text_mask)
        encoded_text = util.weighted_sum(encoded_tokens, attention_weights)
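        # attention_logits has shape (batch_size, num_tokens, num_tokens); masked_softmax normalizes
        # over the last dimension while zeroing padded key positions, and the weighted sum gives each
        # token a summary of the whole (self-attended) sequence.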

        # Build the input to the integrator
        integrator_input = torch.cat([
            encoded_tokens, encoded_tokens - encoded_text,
            encoded_tokens * encoded_text
        ], 2)
        integrated_encodings = self.integrator(integrator_input, text_mask)

        # Simple Pooling layers
        max_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, text_mask.unsqueeze(2), -1e7)
        max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
        min_masked_integrated_encodings = util.replace_masked_values(
            integrated_encodings, text_mask.unsqueeze(2), +1e7)
        min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
        mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(
            text_mask, 1, keepdim=True)
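        # Padded positions are pushed to -1e7 / +1e7 so they can never win the max / min pooling,
        # and the mean uses the mask sum as the true sequence length.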

        # Self-attentive pooling layer
        # Run through linear projection. Shape: (batch_size, sequence length, 1)
        # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
        self_attentive_logits = self._self_attentive_pooling_projection(
            integrated_encodings).squeeze(2)
        self_weights = util.masked_softmax(self_attentive_logits, text_mask)
        self_attentive_pool = util.weighted_sum(integrated_encodings,
                                                self_weights)

        pooled_representations = torch.cat(
            [max_pool, min_pool, mean_pool, self_attentive_pool], 1)
        pooled_representations_dropped = self.integrator_dropout(
            pooled_representations)

        logits = self.output_layer(pooled_representations_dropped)

        # In AllenNLP, the output of forward() is a dictionary.
        # Your output dictionary must contain a "loss" key for your model to be trained.
        output = {"logits": logits}
        if label is not None:
            self.accuracy(logits, label)
            self.f1_measure_positive(logits, label)
            self.f1_measure_negative(logits, label)
            self.f1_measure_neutral(logits, label)
            output["loss"] = self.loss_function(logits, label)

        return output
    def forward(self,  # type: ignore
                words: Dict[str, torch.LongTensor],
                pos_tags: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        words : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        pos_tags : ``torch.LongTensor``, required.
            The output of a ``SequenceLabelField`` containing POS tags.
            POS tags are required regardless of whether they are used in the model,
            because they are used to filter the evaluation metric to only consider
            heads of words which are not punctuation.
        head_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels for the arcs
            in the dependency parse. Has shape ``(batch_size, sequence_length)``.
        head_indices : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer indices denoting the parent of every
            word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        arc_loss : ``torch.FloatTensor``
            The loss contribution from the unlabeled arcs.
        tag_loss : ``torch.FloatTensor``
            The loss contribution from predicting the dependency
            tags for the gold arcs.
        heads : ``torch.FloatTensor``
            The predicted head indices for each word. A tensor
            of shape (batch_size, sequence_length).
        head_tags : ``torch.FloatTensor``
            The predicted head tags (dependency labels) for each arc. A tensor
            of shape (batch_size, sequence_length).
        mask : ``torch.LongTensor``
            A mask denoting the padded elements in the batch.
        """
        embedded_text_input = self.text_field_embedder(words)
        if pos_tags is not None and self._pos_tag_embedding is not None:
            embedded_pos_tags = self._pos_tag_embedding(pos_tags)
            embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
        elif self._pos_tag_embedding is not None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

        mask = get_text_field_mask(words)
        embedded_text_input = self._input_dropout(embedded_text_input)
        encoded_text = self.encoder(embedded_text_input, mask)

        batch_size, _, encoding_dim = encoded_text.size()

        head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
        # Concatenate the head sentinel onto the sentence representation.
        encoded_text = torch.cat([head_sentinel, encoded_text], 1)
        mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
        if head_indices is not None:
            head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
        if head_tags is not None:
            head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
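        # The sentinel acts as an artificial ROOT token prepended at position 0, so a predicted head
        # index of 0 means "attached to the root"; gold head indices and tags get a matching dummy
        # entry to stay aligned.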
        float_mask = mask.float()
        encoded_text = self._dropout(encoded_text)

        # shape (batch_size, sequence_length, arc_representation_dim)
        head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
        child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

        # shape (batch_size, sequence_length, tag_representation_dim)
        head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
        child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
        # shape (batch_size, sequence_length, sequence_length)
        attended_arcs = self.arc_attention(head_arc_representation,
                                           child_arc_representation)

        minus_inf = -1e8
        minus_mask = (1 - float_mask) * minus_inf
        attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)
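        # Adding the mask along both dimensions drives scores involving padded rows or padded
        # columns towards -1e8, so padded tokens can neither act as heads nor as dependents.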

        if self.training or not self.use_mst_decoding_for_validation:
            predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                       child_tag_representation,
                                                                       attended_arcs,
                                                                       mask)
        else:
            predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                    child_tag_representation,
                                                                    attended_arcs,
                                                                    mask)
        if head_indices is not None and head_tags is not None:

            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=head_indices,
                                                    head_tags=head_tags,
                                                    mask=mask)
            loss = arc_nll + tag_nll

            evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
            # We calculate attachment scores for the whole sentence
            # but excluding the symbolic ROOT token at the start,
            # which is why we start from the second element in the sequence.
            self._attachment_scores(predicted_heads[:, 1:],
                                    predicted_head_tags[:, 1:],
                                    head_indices[:, 1:],
                                    head_tags[:, 1:],
                                    evaluation_mask)
        else:
            arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                    child_tag_representation=child_tag_representation,
                                                    attended_arcs=attended_arcs,
                                                    head_indices=predicted_heads.long(),
                                                    head_tags=predicted_head_tags.long(),
                                                    mask=mask)
            loss = arc_nll + tag_nll

        output_dict = {
                "heads": predicted_heads,
                "head_tags": predicted_head_tags,
                "arc_loss": arc_nll,
                "tag_loss": tag_nll,
                "loss": loss,
                "mask": mask,
                "words": [meta["words"] for meta in metadata],
                "pos": [meta["pos"] for meta in metadata]
                }

        return output_dict
Esempio n. 59
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size, num_of_passage_tokens = passage['bert'].size()

        # BERT for QA is a fully connected linear layer on top of BERT producing 2 vectors of
        # start and end spans.
        embedded_passage = self._text_field_embedder(passage)
        passage_length = embedded_passage.size(1)
        logits = self.qa_outputs(embedded_passage)
        start_logits, end_logits = logits.split(1, dim=-1)
        span_start_logits = start_logits.squeeze(-1)
        span_end_logits = end_logits.squeeze(-1)

        # Mask the logits so padded positions receive a large negative (numerically stable) value.
        passage_mask = util.get_text_field_mask(passage).float()
        repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, 1, 1)
        repeated_passage_mask = repeated_passage_mask.view(batch_size, passage_length)
        span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)


        output_dict: Dict[str, Any] = {}

        # We may have multiple instances per question, so regroup everything per question.
        instances_question_id = [insta_meta['question_id'] for insta_meta in metadata]
        question_instances_split_inds = np.cumsum(np.unique(instances_question_id, return_counts=True)[1])[:-1]
        per_question_inds = np.split(range(batch_size), question_instances_split_inds)
        metadata = np.split(metadata, question_instances_split_inds)
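        # Illustrative example (assuming instances of the same question are contiguous and their ids
        # appear in sorted order): for question ids ['q1', 'q1', 'q2'], np.unique gives counts [2, 1],
        # the cumulative sum (dropping the last entry) gives split points [2], and np.split yields
        # per_question_inds [[0, 1], [2]].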

        # Compute the loss.
        if span_start is not None and len(np.argwhere(span_start.squeeze().cpu() >= 0)) > 0:
            # In evaluation some instances may not contain the gold answer, so we compute
            # the loss only on those that do.
            inds_with_gold_answer = np.argwhere(span_start.view(-1).cpu().numpy() >= 0)
            inds_with_gold_answer = inds_with_gold_answer.squeeze() if len(inds_with_gold_answer) > 1 else inds_with_gold_answer
            if len(inds_with_gold_answer) > 0:
                loss = nll_loss(util.masked_log_softmax(span_start_logits[inds_with_gold_answer], \
                                                    repeated_passage_mask[inds_with_gold_answer]),\
                                span_start.view(-1)[inds_with_gold_answer], ignore_index=-1)
                loss += nll_loss(util.masked_log_softmax(span_end_logits[inds_with_gold_answer], \
                                                    repeated_passage_mask[inds_with_gold_answer]),\
                                span_end.view(-1)[inds_with_gold_answer], ignore_index=-1)
                output_dict["loss"] = loss

        # This is a hack for cases in which the gold answer is not provided, so no loss can be computed.
        if 'loss' not in output_dict:
            output_dict["loss"] = torch.cuda.FloatTensor([0], device=span_end_logits.device) \
                if torch.cuda.is_available() else torch.FloatTensor([0])

        # Compute F1 and prepare the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['qid'] = []

        # Get the best span prediction for each instance.
        best_span = self._get_example_predications(span_start_logits, span_end_logits, self._max_span_length)
        best_span_cpu = best_span.detach().cpu().numpy()

        span_start_logits_numpy = span_start_logits.data.cpu().numpy()
        span_end_logits_numpy = span_end_logits.data.cpu().numpy()
        # Iterating over every question (which may contain multiple instances, one per chunk)
        for question_inds, question_instances_metadata in zip(per_question_inds, metadata):
            best_span_ind = np.argmax(span_start_logits_numpy[question_inds, best_span_cpu[question_inds][:, 0]] +
                      span_end_logits_numpy[question_inds, best_span_cpu[question_inds][:, 1]])
            best_span_logit = np.max(span_start_logits_numpy[question_inds, best_span_cpu[question_inds][:, 0]] +
                                      span_end_logits_numpy[question_inds, best_span_cpu[question_inds][:, 1]])
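            # best_span_ind selects, among all chunks of this question, the chunk whose best span has
            # the highest combined start + end logit; that chunk supplies the answer string below.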

            passage_str = question_instances_metadata[best_span_ind]['original_passage']
            offsets = question_instances_metadata[best_span_ind]['token_offsets']

            predicted_span = best_span_cpu[question_inds[best_span_ind]]
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]

            # Note: this is a hack, because AllenNLP, when predicting, expects a value for each instance.
            # But we may have more than one chunk per question, and thus fewer output strings than instances.
            for i in range(len(question_inds)):
                output_dict['best_span_str'].append(best_span_string)
                output_dict['qid'].append(question_instances_metadata[best_span_ind]['question_id'])

            f1_score = 0.0
            EM_score = 0.0
            gold_answer_texts = question_instances_metadata[best_span_ind]['answer_texts_list']
            if gold_answer_texts:
                f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score, best_span_string, gold_answer_texts)
                EM_score = squad_eval.metric_max_over_ground_truths(squad_eval.exact_match_score, best_span_string, gold_answer_texts)
            self._official_f1(100 * f1_score)
            self._official_EM(100 * EM_score)

            # TODO move to predict
            if self._predictions_file is not None:
                with open(self._predictions_file, 'a') as f:
                    f.write(json.dumps({'question_id': question_instances_metadata[best_span_ind]['question_id'],
                                        'best_span_logit': float(best_span_logit),
                                        'f1': 100 * f1_score,
                                        'EM': 100 * EM_score,
                                        'best_span_string': best_span_string,
                                        'gold_answer_texts': gold_answer_texts,
                                        'qas_used_fraction': 1.0}) + '\n')

        return output_dict