def forward(self, s1, s2):
    # pylint: disable=arguments-differ
    """
    Encode two sentences independently and combine their max-pooled encodings.

    Parameters
    ----------
    s1 : Dict[str, torch.LongTensor]
        From a ``TextField``.
    s2 : Dict[str, torch.LongTensor]
        From a ``TextField``.

    Returns
    -------
    torch.Tensor
        ``[s1_enc, s2_enc, |s1_enc - s2_enc|, s1_enc * s2_enc]`` concatenated along
        dimension 1, where each ``*_enc`` is the max over dimension 1 (time) of the
        phrase-layer encoding with padded positions excluded.
    """
    # Embeddings
    s1_embs = self._highway_layer(self._text_field_embedder(s1))
    s2_embs = self._highway_layer(self._text_field_embedder(s2))
    if self._elmo is not None:
        s1_elmo_embs = self._elmo(s1['elmo'])
        s2_elmo_embs = self._elmo(s2['elmo'])
        if "words" in s1:
            s1_embs = torch.cat([s1_embs, s1_elmo_embs['elmo_representations'][0]], dim=-1)
            s2_embs = torch.cat([s2_embs, s2_elmo_embs['elmo_representations'][0]], dim=-1)
        else:
            # ELMo-only input: the ELMo representation replaces the word embeddings.
            s1_embs = s1_elmo_embs['elmo_representations'][0]
            s2_embs = s2_elmo_embs['elmo_representations'][0]
    if self._cove is not None:
        s1_lens = torch.ne(s1['words'], self.pad_idx).long().sum(dim=-1).data
        s2_lens = torch.ne(s2['words'], self.pad_idx).long().sum(dim=-1).data
        s1_cove_embs = self._cove(s1['words'], s1_lens)
        s1_embs = torch.cat([s1_embs, s1_cove_embs], dim=-1)
        s2_cove_embs = self._cove(s2['words'], s2_lens)
        s2_embs = torch.cat([s2_embs, s2_cove_embs], dim=-1)
    s1_embs = self._dropout(s1_embs)
    s2_embs = self._dropout(s2_embs)

    # Set up masks
    s1_mask = util.get_text_field_mask(s1)
    s2_mask = util.get_text_field_mask(s2)
    s1_lstm_mask = s1_mask.float() if self._mask_lstms else None
    s2_lstm_mask = s2_mask.float() if self._mask_lstms else None

    # Sentence encodings with LSTMs
    s1_enc = self._phrase_layer(s1_embs, s1_lstm_mask)
    s2_enc = self._phrase_layer(s2_embs, s2_lstm_mask)
    if self._elmo is not None and len(s1_elmo_embs['elmo_representations']) > 1:
        s1_enc = torch.cat([s1_enc, s1_elmo_embs['elmo_representations'][1]], dim=-1)
        s2_enc = torch.cat([s2_enc, s2_elmo_embs['elmo_representations'][1]], dim=-1)
    s1_enc = self._dropout(s1_enc)
    s2_enc = self._dropout(s2_enc)

    # Max pooling. Padded positions are set to -inf so they never win the max.
    # ``mask == 0`` yields a boolean mask whatever dtype get_text_field_mask returns;
    # the former ``1 - mask.byte()`` produced a uint8 mask, which masked_fill rejects
    # in current PyTorch.
    s1_mask = s1_mask.unsqueeze(dim=-1)
    s2_mask = s2_mask.unsqueeze(dim=-1)
    s1_enc = s1_enc.masked_fill(s1_mask == 0, -float('inf'))
    s2_enc = s2_enc.masked_fill(s2_mask == 0, -float('inf'))
    s1_enc, _ = s1_enc.max(dim=1)
    s2_enc, _ = s2_enc.max(dim=1)
    return torch.cat([s1_enc, s2_enc, torch.abs(s1_enc - s2_enc), s1_enc * s2_enc], 1)
def test_get_text_field_mask_returns_mask_key(self):
    # When the tensor dict already carries a "mask" entry, the result should equal
    # that entry rather than a mask derived from the "tokens" ids.
    token_ids = torch.LongTensor([[3, 4, 5, 0, 0], [1, 2, 0, 0, 0]])
    provided_mask = torch.LongTensor([[0, 0, 1]])
    tensors = {"tokens": token_ids, "mask": provided_mask}
    result = util.get_text_field_mask(tensors)
    assert_almost_equal(result.numpy(), [[0, 0, 1]])
def forward(self, sent):
    # pylint: disable=arguments-differ
    """
    Encode a single sentence and max-pool the encoding over time.

    Parameters
    ----------
    sent : Dict[str, torch.LongTensor]
        From a ``TextField``.

    Returns
    -------
    torch.Tensor
        The max over dimension 1 of the phrase-layer encoding (optionally concatenated
        with a second ELMo layer), with padded positions excluded from the max.
    """
    sent_embs = self._highway_layer(self._text_field_embedder(sent))
    if self._cove is not None:
        sent_lens = torch.ne(sent['words'], self.pad_idx).long().sum(dim=-1).data
        sent_cove_embs = self._cove(sent['words'], sent_lens)
        sent_embs = torch.cat([sent_embs, sent_cove_embs], dim=-1)
    if self._elmo is not None:
        elmo_embs = self._elmo(sent['elmo'])
        if "words" in sent:
            sent_embs = torch.cat([sent_embs, elmo_embs['elmo_representations'][0]], dim=-1)
        else:
            # ELMo-only input: the ELMo representation replaces the word embeddings.
            sent_embs = elmo_embs['elmo_representations'][0]
    sent_embs = self._dropout(sent_embs)

    sent_mask = util.get_text_field_mask(sent).float()
    sent_lstm_mask = sent_mask if self._mask_lstms else None

    sent_enc = self._phrase_layer(sent_embs, sent_lstm_mask)
    if self._elmo is not None and len(elmo_embs['elmo_representations']) > 1:
        sent_enc = torch.cat([sent_enc, elmo_embs['elmo_representations'][1]], dim=-1)
    sent_enc = self._dropout(sent_enc)

    # Padded positions are set to -inf so they never win the max. ``mask == 0`` is a
    # boolean mask; the former ``1 - mask.byte()`` was uint8, which masked_fill rejects
    # in current PyTorch.
    sent_mask = sent_mask.unsqueeze(dim=-1)
    sent_enc = sent_enc.masked_fill(sent_mask == 0, -float('inf'))
    return sent_enc.max(dim=1)[0]
def test_get_text_field_mask_returns_a_correct_mask_character_only_input(self):
    # Only character ids are supplied; per the expected values below, a token
    # position is unmasked when any of its character ids is non-zero.
    char_ids = torch.LongTensor([
        [[1, 2, 3], [3, 0, 1], [2, 1, 0], [0, 0, 0]],
        [[5, 5, 5], [4, 6, 0], [0, 0, 0], [0, 0, 0]],
    ])
    expected = [[1, 1, 1, 0], [1, 1, 0, 0]]
    mask = util.get_text_field_mask({"token_characters": char_ids})
    assert_almost_equal(mask.numpy(), expected)
def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]):
    """
    Encode ``sentence`` and build one ``RnnStatelet`` per batch element.

    Each statelet starts from the final encoder state, a zero memory cell, the
    learned first-action embedding and the attention of the final encoder state over
    the encoder outputs. Every statelet shares the same per-example lists of encoder
    outputs and masks.
    """
    embedded = self._sentence_embedder(sentence)
    # (batch_size, sentence_length)
    mask = util.get_text_field_mask(sentence).float()
    batch_size = embedded.size(0)

    # (batch_size, sentence_length, encoder_output_dim)
    encoded = self._dropout(self._encoder(embedded, mask))
    final_output = util.get_final_encoder_states(encoded, mask,
                                                 self._encoder.is_bidirectional())
    memory = encoded.new_zeros(batch_size, self._encoder.get_output_dim())
    attended, _ = self._decoder_step.attend_on_question(final_output, encoded, mask)

    per_example_outputs = [encoded[idx] for idx in range(batch_size)]
    per_example_masks = [mask[idx] for idx in range(batch_size)]
    return [RnnStatelet(final_output[idx],
                        memory[idx],
                        self._first_action_embedding,
                        attended[idx],
                        per_example_outputs,
                        per_example_masks)
            for idx in range(batch_size)]
def test_get_text_field_mask_returns_a_correct_mask_list_field(self):
    # With one wrapping dimension, the mask should be 1 exactly where the token id
    # is positive and 0 elsewhere.
    list_tokens = torch.LongTensor([
        [[1, 2], [3, 0], [2, 0], [0, 0], [0, 0]],
        [[5, 0], [4, 6], [0, 0], [0, 0], [0, 0]],
    ])
    tensors = {"list_tokens": list_tokens}
    expected = (tensors['list_tokens'].numpy() > 0).astype('int32')
    actual = util.get_text_field_mask(tensors, num_wrapping_dims=1).numpy()
    assert_almost_equal(actual, expected)
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            tags: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict a tag distribution for every token.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``: a dictionary mapping the keys of the
        ``TokenIndexers`` used to build the ``TextField`` to their tensors (with a
        ``SingleIdTokenIndexer`` this is ``{"tokens": Tensor(batch_size, num_tokens)}``).
        It is passed directly to the ``TextFieldEmbedder``.
    tags : torch.LongTensor, optional (default = None)
        Gold integer tag ids of shape ``(batch_size, num_tokens)``.
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Metadata containing the original words of each sentence under a 'words' key.

    Returns
    -------
    An output dictionary consisting of:
    logits : torch.FloatTensor
        Shape ``(batch_size, num_tokens, tag_vocab_size)``; unnormalised log
        probabilities of the tag classes.
    class_probabilities : torch.FloatTensor
        Shape ``(batch_size, num_tokens, tag_vocab_size)``; softmax of ``logits``.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; present only when ``tags`` is given.
    words : List[List[str]], optional
        The 'words' entry of each metadata dict; present only when ``metadata`` is given.
    """
    embedded = self.text_field_embedder(tokens)
    batch_size, sequence_length, _ = embedded.size()
    mask = get_text_field_mask(tokens)
    encoded = self.encoder(embedded, mask)

    logits = self.tag_projection_layer(encoded)
    flat_logits = logits.view(-1, self.num_classes)
    class_probabilities = F.softmax(flat_logits, dim=-1).view(
        [batch_size, sequence_length, self.num_classes])

    output_dict = {"logits": logits, "class_probabilities": class_probabilities}

    if tags is not None:
        loss = sequence_cross_entropy_with_logits(logits, tags, mask)
        float_mask = mask.float()
        for metric in self.metrics.values():
            metric(logits, tags, float_mask)
        output_dict["loss"] = loss

    if metadata is not None:
        output_dict["words"] = [entry["words"] for entry in metadata]
    return output_dict
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score tags for every token and decode the best sequence with the CRF.

    Parameters
    ----------
    tokens : ``Dict[str, torch.LongTensor]``, required
        The output of ``TextField.as_array()``: a dictionary mapping the keys of the
        ``TokenIndexers`` used to build the ``TextField`` to their tensors (with a
        ``SingleIdTokenIndexer`` this is ``{"tokens": Tensor(batch_size, num_tokens)}``).
        It is passed directly to the ``TextFieldEmbedder``.
    tags : ``torch.LongTensor``, optional (default = ``None``)
        Gold integer tag ids of shape ``(batch_size, num_tokens)``.

    Returns
    -------
    An output dictionary consisting of:
    logits : ``torch.FloatTensor``
        The output of the ``tag_projection_layer``.
    mask : ``torch.LongTensor``
        The text field mask for the input tokens.
    tags : ``List[List[int]]``
        The predicted tag ids from Viterbi decoding, one list per instance.
    loss : ``torch.FloatTensor``, optional
        The negative CRF log-likelihood; only computed if gold ``tags`` are provided.
    """
    embedded_text_input = self.text_field_embedder(tokens)
    mask = util.get_text_field_mask(tokens)
    encoded_text = self.encoder(embedded_text_input, mask)

    logits = self.tag_projection_layer(encoded_text)
    predicted_tags = self.crf.viterbi_tags(logits, mask)

    output = {"logits": logits, "mask": mask, "tags": predicted_tags}

    if tags is not None:
        # Add negative log-likelihood as loss
        log_likelihood = self.crf(logits, tags, mask)
        output["loss"] = -log_likelihood

        # Represent the Viterbi tags as one-hot "class probabilities" so they can be
        # fed into the span metric. A fresh zero tensor (rather than ``logits * 0.``)
        # stays out of the autograd graph and cannot inherit NaN/inf from the logits.
        class_probabilities = logits.new_zeros(logits.size())
        for i, instance_tags in enumerate(predicted_tags):
            for j, tag_id in enumerate(instance_tags):
                class_probabilities[i, j, tag_id] = 1

        self.span_metric(class_probabilities, tags, mask)

    return output
def forward(self,
            sentence: Dict[str, torch.Tensor],
            labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
    """
    Predict tag logits for every token of ``sentence``.

    Parameters
    ----------
    sentence : Dict[str, torch.Tensor]
        Token tensors from a ``TextField``.
    labels : torch.Tensor, optional
        Gold tag ids; when given, accuracy is updated and a loss is computed.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``{"tag_logits": ...}``, plus ``"loss"`` when ``labels`` is provided.
        (The previous ``-> torch.Tensor`` annotation did not match this dict.)
    """
    mask = get_text_field_mask(sentence)
    embeddings = self.word_embeddings(sentence)
    encoder_out = self.encoder(embeddings, mask)
    tag_logits = self.hidden2tag(encoder_out)
    output = {"tag_logits": tag_logits}
    if labels is not None:
        self.accuracy(tag_logits, labels, mask)
        output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)
    return output
def _update_recall(self,
                   all_top_k_predictions: torch.Tensor,
                   target_tokens: Dict[str, torch.LongTensor],
                   target_recall: UnigramRecall) -> None:
    """
    Feed the top-k predictions and the shifted targets into ``target_recall``.

    The first target position is dropped from both the token ids and the mask before
    calling ``target_recall``, and ``self._end_index`` is passed along.
    """
    token_ids = target_tokens["tokens"]
    mask = get_text_field_mask(target_tokens)
    # See comment in _get_loss.
    # TODO(brendanr): Do we need contiguous here?
    shifted_targets, shifted_mask = (t[:, 1:].contiguous() for t in (token_ids, mask))
    target_recall(all_top_k_predictions, shifted_targets, shifted_mask, self._end_index)
def get_sample_encoded_output(self):
    """
    Return the encoded vector for a sample event.

    Scans ``self.dataset`` for the first instance whose ``"source"`` field has the
    tokens ``@start@ personx calls personx 's brother @end@``. It embeds that field
    with the trained model's source embedder followed by its embedding dropout, and
    returns the trained model's encoder output.

    Raises
    ------
    ValueError
        If no instance in ``self.dataset`` has that source text. Previously this case
        failed with an UnboundLocalError.
    """
    sample_tokens = ["@start@", "personx", "calls", "personx", "'s", "brother", "@end@"]
    sample_text_field = None
    for instance in self.dataset:
        candidate = instance.fields["source"]
        if [token.text for token in candidate.tokens] == sample_tokens:
            sample_text_field = candidate
            break
    if sample_text_field is None:
        raise ValueError("No instance in the dataset has the sample source text: "
                         + " ".join(sample_tokens))

    source = sample_text_field.as_tensor(sample_text_field.get_padding_lengths())
    # Add a leading batch dimension of size 1 to the token ids.
    source['tokens'] = source['tokens'].unsqueeze(0)
    embedded_input = self.trained_model._embedding_dropout(
        self.trained_model._source_embedder(source)
    )
    source_mask = get_text_field_mask(source)
    return self.trained_model._encoder(embedded_input, source_mask)
def forward(self,  # type: ignore
            inputs: torch.Tensor) -> torch.Tensor:
    """
    Parameters
    ----------
    inputs: ``torch.Tensor``
        Shape ``(batch_size, timesteps, ...)`` of token ids representing the current batch.
        These must have been produced using the same indexer the LM was trained on.

    Returns
    -------
    The bidirectional language model representations for the input sequence, shape
    ``(batch_size, timesteps, embedding_dim)``. (The return annotation was previously
    ``Dict[str, torch.Tensor]``, which did not match the returned tensor.)
    """
    # pylint: disable=arguments-differ
    if self._bos_indices is not None:
        mask = get_text_field_mask({"": inputs})
        # Only the boundary-augmented ids are needed here; the LM returns its own mask.
        inputs, _ = add_sentence_boundary_token_ids(
            inputs, mask, self._bos_indices, self._eos_indices
        )

    source = {self._token_name: inputs}
    result_dict = self._lm(source)

    # shape (batch_size, timesteps, embedding_size)
    noncontextual_token_embeddings = result_dict["noncontextual_token_embeddings"]
    contextual_embeddings = result_dict["lm_embeddings"]

    # Typically the non-contextual embeddings are smaller than the contextualized embeddings.
    # Since we're averaging all the layers we need to make their dimensions match. Simply
    # repeating the non-contextual embeddings is a crude, but effective, way to do this.
    duplicated_character_embeddings = torch.cat(
        [noncontextual_token_embeddings] * self._character_embedding_duplication_count, -1
    )
    averaged_embeddings = self._scalar_mix(
        [duplicated_character_embeddings] + contextual_embeddings
    )

    # Add dropout
    averaged_embeddings = self._dropout(averaged_embeddings)
    if self._remove_bos_eos:
        averaged_embeddings, _ = remove_sentence_boundaries(
            averaged_embeddings, result_dict["mask"]
        )

    return averaged_embeddings
def forward(self,
            sentence: Dict[str, torch.Tensor],
            labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
    """
    Return ``{"tag_logits": ...}``, plus ``"loss"`` when ``labels`` is provided.

    The return annotation is a dict; it was previously (incorrectly) ``torch.Tensor``.
    """
    #### AllenNLP is designed to operate on batched inputs, but different input sequences have different lengths. Behind the scenes AllenNLP is padding the shorter inputs so that the batch has uniform shape, which means our computations need to use a mask to exclude the padding. Here we just use the utility function <code>get_text_field_mask</code>, which returns a tensor of 0s and 1s corresponding to the padded and unpadded locations.
    mask = get_text_field_mask(sentence)
    #### We start by passing the <code>sentence</code> tensor (each sentence a sequence of token ids) to the <code>word_embeddings</code> module, which converts each sentence into a sequence of embedded tensors.
    embeddings = self.word_embeddings(sentence)
    #### We next pass the embedded tensors (and the mask) to the LSTM, which produces a sequence of encoded outputs.
    encoder_out = self.encoder(embeddings, mask)
    #### Finally, we pass each encoded output tensor to the feedforward layer to produce logits corresponding to the various tags.
    tag_logits = self.hidden2tag(encoder_out)
    output = {"tag_logits": tag_logits}

    #### As before, the labels were optional, as we might want to run this model to make predictions on unlabeled data. If we do have labels, then we use them to update our accuracy metric and compute the "loss" that goes in our output.
    if labels is not None:
        self.accuracy(tag_logits, labels, mask)
        output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

    return output
def greedy_search(self,
                  final_encoder_output: torch.FloatTensor,
                  target_tokens: Dict[str, torch.LongTensor],
                  target_embedder: Embedding,
                  decoder_cell: GRUCell,
                  output_projection_layer: Linear) -> torch.FloatTensor:
    """
    Run ``decoder_cell`` over the target sequence and return the loss against it.

    Despite the method name, this is not free-running greedy decoding. At every step
    the decoder input is the *gold* target token ``targets[:, timestep]`` (teacher
    forcing; see https://github.com/allenai/allennlp/issues/1134), not the model's own
    previous prediction.

    Parameters
    ----------
    final_encoder_output : ``torch.FloatTensor``, required
        Vector produced by ``self._encoder``; used as the initial decoder hidden state.
    target_tokens : ``Dict[str, torch.LongTensor]``, required
        The output of ``TextField.as_array()`` applied on some target ``TextField``.
    target_embedder : ``Embedding``, required
        Used to embed the target tokens.
    decoder_cell: ``GRUCell``, required
        The recurrent cell used at each time step.
    output_projection_layer: ``Linear``, required
        Linear layer mapping to the desired number of classes.

    Returns
    -------
    The value of ``self._get_loss(logits, targets, target_mask)`` (presumably a
    cross-entropy loss) for the per-step logits.
    """
    num_decoding_steps = self._get_num_decoding_steps(target_tokens)
    targets = target_tokens["tokens"]
    decoder_hidden = final_encoder_output
    step_logits = []
    for timestep in range(num_decoding_steps):
        # See https://github.com/allenai/allennlp/issues/1134.
        input_choices = targets[:, timestep]
        decoder_input = target_embedder(input_choices)
        decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
        # (batch_size, num_classes)
        output_projections = output_projection_layer(decoder_hidden)
        # list of (batch_size, 1, num_classes)
        step_logits.append(output_projections.unsqueeze(1))
    # (batch_size, num_decoding_steps, num_classes)
    logits = torch.cat(step_logits, 1)
    target_mask = get_text_field_mask(target_tokens)
    return self._get_loss(logits, targets, target_mask)
def forward(self, question):
    # pylint: disable=arguments-differ
    """
    Embed a question and average its token embeddings, ignoring padding.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.

    Returns
    -------
    torch.FloatTensor
        The mean of the token embeddings over dimension 1, computed only over unmasked
        (non-padding) positions. The old ``.mean(1)`` also averaged in the padding.
    """
    word_char_embs = self._text_field_embedder(question)
    question_mask = util.get_text_field_mask(question).float()
    # Assumes embeddings are (batch, num_tokens, dim) and the mask is (batch, num_tokens),
    # matching the original ``.mean(1)`` over the token axis.
    masked_sum = (word_char_embs * question_mask.unsqueeze(-1)).sum(1)
    # Clamp so an all-padding row yields zeros instead of dividing by zero.
    num_tokens = question_mask.sum(1, keepdim=True).clamp(min=1)
    return masked_sum / num_tokens
def _get_initial_state_and_scores(self,
                                  question: Dict[str, torch.LongTensor],
                                  table: Dict[str, torch.LongTensor],
                                  world: List[WikiTablesWorld],
                                  actions: List[List[ProductionRuleArray]],
                                  example_lisp_string: List[str] = None,
                                  add_world_to_initial_state: bool = False,
                                  checklist_states: List[ChecklistState] = None) -> Dict:
    """
    Does initial preparation and creates an initial state for both the semantic parsers.
    Note that the checklist state is optional, and the ``WikiTablesMmlParser`` is not
    expected to pass it.

    Returns
    -------
    A dict with keys ``"initial_state"`` (a ``WikiTablesDecoderState``),
    ``"linking_scores"``, ``"feature_scores"`` (``None`` when ``self._linking_params``
    is ``None``) and ``"similarity_scores"`` (the max question/entity-token similarity).
    """
    table_text = table['text']
    # (batch_size, question_length, embedding_dim)
    embedded_question = self._question_embedder(question)
    question_mask = util.get_text_field_mask(question).float()
    # (batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
    table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

    batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
    num_question_tokens = embedded_question.size(1)

    # (batch_size, num_entities, embedding_dim)
    encoded_table = self._entity_encoder(embedded_table, table_mask)
    # (batch_size, num_entities, num_neighbors)
    neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

    # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
    # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
    # be added for the mask since that method expects 0 for padding.
    # (batch_size, num_entities, num_neighbors, embedding_dim)
    embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

    neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                             num_wrapping_dims=1).float()

    # Encoder initialized to easily obtain a masked average.
    neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
    # (batch_size, num_entities, embedding_dim)
    embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

    # entity_types: one-hot tensor with shape (batch_size, num_entities, num_types)
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

    entity_type_embeddings = self._type_params(entity_types.float())
    projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
    # (batch_size, num_entities, embedding_dim)
    # torch.tanh replaces the deprecated torch.nn.functional.tanh (same values).
    entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)

    # Compute entity and question word similarity. We tried using cosine distance here, but
    # because this similarity is the main mechanism that the model can use to push apart logit
    # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
    # output range than [-1, 1].
    question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                               num_entities * num_entity_tokens,
                                                               self._embedding_dim),
                                           torch.transpose(embedded_question, 1, 2))

    question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                 num_entities,
                                                                 num_entity_tokens,
                                                                 num_question_tokens)

    # (batch_size, num_entities, num_question_tokens)
    question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

    # (batch_size, num_entities, num_question_tokens, num_features)
    linking_features = table['linking']

    linking_scores = question_entity_similarity_max_score

    if self._use_neighbor_similarity_for_linking:
        # The linking score is computed as a linear projection of two terms. The first is the
        # maximum similarity score over the entity's words and the question token. The second
        # is the maximum similarity over the words in the entity's neighbors and the question
        # token.
        #
        # The second term, projected_question_neighbor_similarity, is useful when a column
        # needs to be selected. For example, the question token might have no similarity with
        # the column name, but is similar with the cells in the column.
        #
        # Note that projected_question_neighbor_similarity is intended to capture the same
        # information as the related_column feature.
        #
        # Also note that this block needs to be _before_ the `linking_params` block, because
        # we're overwriting `linking_scores`, not adding to it.

        # (batch_size, num_entities, num_neighbors, num_question_tokens)
        question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                 torch.abs(neighbor_indices))
        # (batch_size, num_entities, num_question_tokens)
        question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)
        projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

    feature_scores = None
    if self._linking_params is not None:
        feature_scores = self._linking_params(linking_features).squeeze(3)
        linking_scores = linking_scores + feature_scores

    # (batch_size, num_question_tokens, num_entities)
    linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                            question_mask, entity_type_dict)

    # (batch_size, num_question_tokens, embedding_dim)
    link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
    encoder_input = torch.cat([link_embedding, embedded_question], 2)

    # (batch_size, question_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                         question_mask,
                                                         self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

    initial_score = embedded_question.data.new_zeros(batch_size)

    action_embeddings, output_action_embeddings, action_biases, action_indices = self._embed_actions(actions)

    flattened_linking_scores, actions_to_entities = self._map_entity_productions(linking_scores,
                                                                                world,
                                                                                actions)

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list. For instance, the encoder outputs have shape
    # `(batch_size, question_length, encoder_output_dim)`. We need to convert this into a list
    # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`. Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    initial_score_list = [initial_score[i] for i in range(batch_size)]
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    question_mask_list = [question_mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        initial_rnn_state.append(RnnState(final_encoder_output[i],
                                          memory_cell[i],
                                          self._first_action_embedding,
                                          self._first_attended_question,
                                          encoder_output_list,
                                          question_mask_list))
    initial_grammar_state = [self._create_grammar_state(world[i], actions[i])
                             for i in range(batch_size)]
    initial_state_world = world if add_world_to_initial_state else None
    initial_state = WikiTablesDecoderState(batch_indices=list(range(batch_size)),
                                           action_history=[[] for _ in range(batch_size)],
                                           score=initial_score_list,
                                           rnn_state=initial_rnn_state,
                                           grammar_state=initial_grammar_state,
                                           action_embeddings=action_embeddings,
                                           output_action_embeddings=output_action_embeddings,
                                           action_biases=action_biases,
                                           action_indices=action_indices,
                                           possible_actions=actions,
                                           flattened_linking_scores=flattened_linking_scores,
                                           actions_to_entities=actions_to_entities,
                                           entity_types=entity_type_dict,
                                           world=initial_state_world,
                                           example_lisp_string=example_lisp_string,
                                           checklist_state=checklist_states,
                                           debug_info=None)
    return {"initial_state": initial_state,
            "linking_scores": linking_scores,
            "feature_scores": feature_scores,
            "similarity_scores": question_entity_similarity_max_score}
def forward(self, s1, s2):
    # pylint: disable=arguments-differ
    """
    Encode a sentence pair with bidirectional attention and pool each side.

    Parameters
    ----------
    s1 : Dict[str, torch.LongTensor]
        From a ``TextField``.
    s2 : Dict[str, torch.LongTensor]
        From a ``TextField``.

    Returns
    -------
    pair_rep : torch.FloatTensor
        ``[s1_rep, s2_rep, |s1_rep - s2_rep|, s1_rep * s2_rep]`` concatenated along
        dimension 1. Each ``*_rep`` is the max over time of the modeling-layer output
        run on that sentence's encoding, concatenated with its attention-weighted view
        of the other sentence. Padded positions are excluded from the max.
    """
    s1_embs = self._highway_layer(self._text_field_embedder(s1))
    s2_embs = self._highway_layer(self._text_field_embedder(s2))
    if self._elmo is not None:
        s1_elmo_embs = self._elmo(s1['elmo'])
        s2_elmo_embs = self._elmo(s2['elmo'])
        if "words" in s1:
            s1_embs = torch.cat([s1_embs, s1_elmo_embs['elmo_representations'][0]], dim=-1)
            s2_embs = torch.cat([s2_embs, s2_elmo_embs['elmo_representations'][0]], dim=-1)
        else:
            s1_embs = s1_elmo_embs['elmo_representations'][0]
            s2_embs = s2_elmo_embs['elmo_representations'][0]
    if self._cove is not None:
        s1_lens = torch.ne(s1['words'], self.pad_idx).long().sum(dim=-1).data
        s2_lens = torch.ne(s2['words'], self.pad_idx).long().sum(dim=-1).data
        s1_cove_embs = self._cove(s1['words'], s1_lens)
        s1_embs = torch.cat([s1_embs, s1_cove_embs], dim=-1)
        s2_cove_embs = self._cove(s2['words'], s2_lens)
        s2_embs = torch.cat([s2_embs, s2_cove_embs], dim=-1)
    s1_embs = self._dropout(s1_embs)
    s2_embs = self._dropout(s2_embs)

    # The masks are needed for the attention softmax and the max pooling regardless of
    # whether the LSTMs are masked. Previously they were only defined when
    # self._mask_lstms was True, so the unmasked configuration raised NameError.
    s1_mask = util.get_text_field_mask(s1).float()
    s2_mask = util.get_text_field_mask(s2).float()
    s1_lstm_mask = s1_mask if self._mask_lstms else None
    s2_lstm_mask = s2_mask if self._mask_lstms else None

    s1_enc = self._phrase_layer(s1_embs, s1_lstm_mask)
    s2_enc = self._phrase_layer(s2_embs, s2_lstm_mask)

    # Similarity matrix
    # Shape: (batch_size, s2_length, s1_length)
    similarity_mat = self._matrix_attention(s2_enc, s1_enc)

    # s2 representation
    # Shape: (batch_size, s2_length, s1_length)
    s2_s1_attention = util.last_dim_softmax(similarity_mat, s1_mask)
    # Shape: (batch_size, s2_length, encoding_dim)
    s2_s1_vectors = util.weighted_sum(s1_enc, s2_s1_attention)
    # Shape: (batch_size, s2_length, 2 * encoding_dim), assuming s2_enc is encoding_dim wide
    s2_w_context = torch.cat([s2_enc, s2_s1_vectors], 2)

    # s1 representation, using same attn method as for the s2 representation
    s1_s2_attention = util.last_dim_softmax(similarity_mat.transpose(1, 2).contiguous(), s2_mask)
    # Shape: (batch_size, s1_length, encoding_dim)
    s1_s2_vectors = util.weighted_sum(s2_enc, s1_s2_attention)
    s1_w_context = torch.cat([s1_enc, s1_s2_vectors], 2)

    if self._elmo is not None and self._deep_elmo:
        s1_w_context = torch.cat([s1_w_context, s1_elmo_embs['elmo_representations'][1]], dim=-1)
        s2_w_context = torch.cat([s2_w_context, s2_elmo_embs['elmo_representations'][1]], dim=-1)
    s1_w_context = self._dropout(s1_w_context)
    s2_w_context = self._dropout(s2_w_context)

    # Max pooling: padded positions are set to -inf so they never win the max.
    # ``mask == 0`` gives a boolean mask; the former ``1 - mask.byte()`` was uint8,
    # which masked_fill rejects in current PyTorch.
    modeled_s2 = self._dropout(self._modeling_layer(s2_w_context, s2_lstm_mask))
    modeled_s2 = modeled_s2.masked_fill(s2_mask.unsqueeze(dim=-1) == 0, -float('inf'))
    s2_enc_attn = modeled_s2.max(dim=1)[0]

    modeled_s1 = self._dropout(self._modeling_layer(s1_w_context, s1_lstm_mask))
    modeled_s1 = modeled_s1.masked_fill(s1_mask.unsqueeze(dim=-1) == 0, -float('inf'))
    s1_enc_attn = modeled_s1.max(dim=1)[0]

    return torch.cat([s1_enc_attn, s2_enc_attn, torch.abs(s1_enc_attn - s2_enc_attn),
                      s1_enc_attn * s2_enc_attn], 1)
def forward(
    self,  # type: ignore
    words: TextFieldTensors,
    pos_tags: torch.LongTensor,
    metadata: List[Dict[str, Any]],
    head_tags: torch.LongTensor = None,
    head_indices: torch.LongTensor = None,
) -> Dict[str, torch.Tensor]:
    """
    # Parameters

    words : `TextFieldTensors`, required
        The output of `TextField.as_array()`, which should typically be passed directly to a
        `TextFieldEmbedder`. This output is a dictionary mapping keys to `TokenIndexer`
        tensors. At its most basic, using a `SingleIdTokenIndexer` this is : `{"tokens":
        Tensor(batch_size, sequence_length)}`. This dictionary will have the same keys as were used
        for the `TokenIndexers` when you created the `TextField` representing your
        sequence. The dictionary is designed to be passed directly to a `TextFieldEmbedder`,
        which knows how to combine different word representations into a single vector per
        token in your input.
    pos_tags : `torch.LongTensor`, required
        The output of a `SequenceLabelField` containing POS tags.
        POS tags are required regardless of whether they are used in the model,
        because they are used to filter the evaluation metric to only consider
        heads of words which are not punctuation.
    metadata : `List[Dict[str, Any]]`, required
        One dictionary of metadata per batch element, with keys:
            words : `List[str]`, required.
                The tokens in the original sentence.
            pos : `List[str]`, required.
                The dependencies POS tags for each word.
    head_tags : `torch.LongTensor`, optional (default = None)
        A torch tensor representing the sequence of integer gold class labels for the arcs
        in the dependency parse. Has shape `(batch_size, sequence_length)`.
    head_indices : `torch.LongTensor`, optional (default = None)
        A torch tensor representing the sequence of integer indices denoting the parent of every
        word in the dependency parse. Has shape `(batch_size, sequence_length)`.

    # Returns

    An output dictionary consisting of:

    loss : `torch.FloatTensor`
        The sum of `arc_loss` and `tag_loss`.
    arc_loss : `torch.FloatTensor`
        The loss contribution from the unlabeled arcs.
    tag_loss : `torch.FloatTensor`
        The loss contribution from predicting the dependency tags for the gold arcs.
    heads : `torch.Tensor`
        The predicted head indices for each word. A tensor of shape (batch_size, sequence_length).
    head_tags : `torch.Tensor`
        The predicted head tags for each arc. A tensor of shape (batch_size, sequence_length).
    mask : `torch.BoolTensor`
        The mask returned by `self._parse`, denoting the padded elements in the batch.
    words : `List[List[str]]`
        The "words" entry of each metadata dictionary.
    pos : `List[List[str]]`
        The "pos" entry of each metadata dictionary.
    """
    embedded_text_input = self.text_field_embedder(words)
    if pos_tags is not None and self._pos_tag_embedding is not None:
        embedded_pos_tags = self._pos_tag_embedding(pos_tags)
        embedded_text_input = torch.cat(
            [embedded_text_input, embedded_pos_tags], -1)
    elif self._pos_tag_embedding is not None:
        raise ConfigurationError(
            "Model uses a POS embedding, but no POS tags were passed.")

    mask = get_text_field_mask(words)

    predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll = self._parse(
        embedded_text_input, mask, head_tags, head_indices)

    loss = arc_nll + tag_nll

    if head_indices is not None and head_tags is not None:
        evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
        # We calculate attachment scores for the whole sentence
        # but excluding the symbolic ROOT token at the start,
        # which is why we start from the second element in the sequence.
        self._attachment_scores(
            predicted_heads[:, 1:],
            predicted_head_tags[:, 1:],
            head_indices,
            head_tags,
            evaluation_mask,
        )

    output_dict = {
        "heads": predicted_heads,
        "head_tags": predicted_head_tags,
        "arc_loss": arc_nll,
        "tag_loss": tag_nll,
        "loss": loss,
        "mask": mask,
        "words": [meta["words"] for meta in metadata],
        "pos": [meta["pos"] for meta in metadata],
    }

    return output_dict
def _get_initial_state(self,
                       utterance: Dict[str, torch.LongTensor],
                       worlds: List[AtisWorld],
                       actions: List[List[ProductionRule]],
                       linking_scores: torch.Tensor) -> GrammarBasedState:
    """Build the decoder's starting ``GrammarBasedState`` for one batch.

    The utterance is embedded and encoded. Each batch element gets an
    ``RnnStatelet`` whose hidden state is the final encoder output and whose
    memory cell is zero. When the decoder has several layers, both are tiled
    ``self._decoder_num_layers`` times. Each element also gets a grammar state
    built from its world, action list, linking scores and entity types.
    """
    embedded_utterance = self._utterance_embedder(utterance)
    utterance_mask = util.get_text_field_mask(utterance).float()
    batch_size = embedded_utterance.size(0)
    batch_range = range(batch_size)

    # Pad entity types to the largest entity count in the batch.
    # entity_types: (batch_size, max_entity_count)
    max_entity_count = max([len(world.entities) for world in worlds])
    entity_types, _ = self._get_type_vector(worlds, max_entity_count, embedded_utterance)

    # (batch_size, utterance_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(embedded_utterance, utterance_mask))

    # Seeds for the decoder LSTM's hidden state and memory cell.
    final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                         utterance_mask,
                                                         self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())
    initial_score = embedded_utterance.data.new_zeros(batch_size)

    # The decoder groups states as Python lists indexed by batch element, so
    # split the batch dimension off now; later grouping is just `torch.cat()`.
    initial_score_list = [initial_score[b] for b in batch_range]
    encoder_output_list = [encoder_outputs[b] for b in batch_range]
    utterance_mask_list = [utterance_mask[b] for b in batch_range]

    layer_count = self._decoder_num_layers
    initial_rnn_state = []
    for b in batch_range:
        hidden = final_encoder_output[b]
        memory = memory_cell[b]
        if layer_count > 1:
            # One copy of the seed per decoder layer.
            hidden = hidden.repeat(layer_count, 1)
            memory = memory.repeat(layer_count, 1)
        initial_rnn_state.append(RnnStatelet(hidden,
                                             memory,
                                             self._first_action_embedding,
                                             self._first_attended_utterance,
                                             encoder_output_list,
                                             utterance_mask_list))

    initial_grammar_state = []
    for b in batch_range:
        initial_grammar_state.append(self._create_grammar_state(worlds[b],
                                                                actions[b],
                                                                linking_scores[b],
                                                                entity_types[b]))

    return GrammarBasedState(batch_indices=list(batch_range),
                             action_history=[[] for _ in batch_range],
                             score=initial_score_list,
                             rnn_state=initial_rnn_state,
                             grammar_state=initial_grammar_state,
                             possible_actions=actions,
                             debug_info=None)
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    BiDAF-style span prediction that uses only context-to-query attention.

    Each passage position is merged with its question-attended vector as
    ``[p; q~; p * q~]``. The query-to-context branch of the original BiDAF does
    not feed this merge, so it is not computed.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``. The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``. The (inclusive) token index where the answer begins.
        If given together with ``span_end``, a loss is computed and included in the
        output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``. The (inclusive) token index where the answer ends.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, one dictionary per instance. It must contain ``question_tokens``,
        ``passage_tokens``, ``original_passage`` and ``token_offsets``; all of these are
        read unconditionally. It may also contain an ``answer_texts`` list, which is used
        to update the SQuAD EM/F1 metric.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        Context-to-query attention, shape ``(batch_size, passage_length, question_length)``.
    span_start_logits : torch.FloatTensor
        Shape ``(batch_size, passage_length)``; unnormalized log probabilities of the span
        start position, with padded positions set to ``-1e7``.
    span_start_probs : torch.FloatTensor
        Masked softmax of the start logits.
    span_end_logits : torch.FloatTensor
        Shape ``(batch_size, passage_length)``; unnormalized log probabilities of the span
        end position (inclusive), with padded positions set to ``-1e7``.
    span_end_probs : torch.FloatTensor
        Masked softmax of the end logits.
    best_span : torch.IntTensor
        Result of ``self.get_best_span`` over the start/end logits, shape ``(batch_size, 2)``.
    loss : torch.FloatTensor, optional
        Sum of the start and end negative log likelihoods, when ``span_start`` is given.
    best_span_str, question_tokens, passage_tokens : List, optional
        Present when ``metadata`` is given.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    # Shape: (batch_size, question_length, encoding_dim)
    encoded_question = self._dropout(
        self._phrase_layer(embedded_question, question_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_passage = self._dropout(
        self._phrase_layer(embedded_passage, passage_lstm_mask))

    # Shape: (batch_size, passage_length, question_length)
    similarity_matrix = self._matrix_attention(encoded_passage, encoded_question)
    # Context-to-query attention over question tokens.
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.last_dim_softmax(similarity_matrix, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question,
                                                 passage_question_attention)

    # Shape: (batch_size, passage_length, encoding_dim * 3)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors],
                                     dim=-1)

    modeled_passage = self._dropout(
        self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 3 + modeling_dim)
    span_start_input = self._dropout(
        torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Expected passage representation at the start position.
    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(
        batch_size, passage_length, modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 3 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, span_end_encoding_dim)
    encoded_span_end = self._dropout(
        self._span_end_encoder(span_end_representation, passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 3 + span_end_encoding_dim)
    span_end_input = self._dropout(
        torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    # Make padded positions very unlikely so best-span search never selects them.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                        span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                         span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Recover answer strings, update SQuAD EM/F1, and echo the tokenized input.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # Map token indices back to character offsets in the original passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(
    self,  # type: ignore
    tokens: TextFieldTensors,
    label: torch.LongTensor = None,
) -> Dict[str, torch.Tensor]:
    """
    Classify each instance with a biattentive encoder followed by max / min /
    mean / self-attentive pooling.

    # Parameters

    tokens : TextFieldTensors, required
        The output of `TextField.as_array()`. An optional `"elmo"` entry is
        temporarily popped so the text-field embedder never sees it, then put
        back, so the caller's dictionary is unchanged after a successful call.
    label : torch.LongTensor, optional (default = None)
        A variable representing the label for each instance in the batch.

    # Returns

    An output dictionary consisting of:
    - `logits` (`torch.FloatTensor`) :
        Unnormalised scores from `self._output_layer`.
    - `class_probabilities` (`torch.FloatTensor`) :
        `softmax(logits)`, a distribution over the label classes for each
        instance (presumably shape `(batch_size, num_classes)`).
    - `loss` (`torch.FloatTensor`, optional) :
        A scalar loss to be optimised; present only when `label` is given.

    # Raises

    ConfigurationError
        If the model was built with ELMo but `tokens` has no `"elmo"` entry.
    """
    text_mask = util.get_text_field_mask(tokens)

    # Pop elmo tokens, since elmo embedder should not be present.
    elmo_tokens = tokens.pop("elmo", None)
    if tokens:
        embedded_text = self._text_field_embedder(tokens)
    else:
        # only using "elmo" for input
        embedded_text = None

    # Add the "elmo" key back to "tokens" if not None, since the tests and the
    # subsequent training epochs rely not being modified during forward()
    # NOTE(review): if the embedder above raises, the key is not restored.
    if elmo_tokens is not None:
        tokens["elmo"] = elmo_tokens

    # Create ELMo embeddings if applicable
    if self._elmo:
        if elmo_tokens is not None:
            elmo_representations = self._elmo(
                elmo_tokens["elmo_tokens"])["elmo_representations"]
            # Pop from the end is more performant with list.
            # The integrator-output representation is last, the input one before it;
            # the assert checks that no representation is left unused.
            if self._use_integrator_output_elmo:
                integrator_output_elmo = elmo_representations.pop()
            if self._use_input_elmo:
                input_elmo = elmo_representations.pop()
            assert not elmo_representations
        else:
            raise ConfigurationError(
                "Model was built to use Elmo, but input text is not tokenized for Elmo."
            )

    # `input_elmo` is only bound when self._elmo is set; this assumes the
    # use_input_elmo flag is never enabled without ELMo.
    if self._use_input_elmo:
        if embedded_text is not None:
            embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)
        else:
            embedded_text = input_elmo

    dropped_embedded_text = self._embedding_dropout(embedded_text)
    pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
    encoded_tokens = self._encoder(pre_encoded_text, text_mask)

    # Compute biattention. This is a special case since the inputs are the same.
    # Logits are token-token dot products: encoded_tokens @ encoded_tokens^T.
    attention_logits = encoded_tokens.bmm(
        encoded_tokens.permute(0, 2, 1).contiguous())
    attention_weights = util.masked_softmax(attention_logits, text_mask)
    encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

    # Build the input to the integrator: [x; x - attended; x * attended]
    integrator_input = torch.cat([
        encoded_tokens, encoded_tokens - encoded_text,
        encoded_tokens * encoded_text
    ], 2)
    integrated_encodings = self._integrator(integrator_input, text_mask)

    # Concatenate ELMo representations to integrated_encodings if specified
    if self._use_integrator_output_elmo:
        integrated_encodings = torch.cat(
            [integrated_encodings, integrator_output_elmo], dim=-1)

    # Simple Pooling layers.
    # Padded positions are filled with the dtype's min (for max) or max (for min)
    # so they never win the reduction.
    max_masked_integrated_encodings = util.replace_masked_values(
        integrated_encodings,
        text_mask.unsqueeze(2),
        util.min_value_of_dtype(integrated_encodings.dtype),
    )
    max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
    min_masked_integrated_encodings = util.replace_masked_values(
        integrated_encodings,
        text_mask.unsqueeze(2),
        util.max_value_of_dtype(integrated_encodings.dtype),
    )
    min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
    # NOTE(review): the numerator sums every position, padding included, and
    # divides by the number of real tokens; this is only a true mean if the
    # integrator (and any concatenated ELMo) output zeros at padded positions — confirm.
    mean_pool = torch.sum(integrated_encodings, 1) / torch.sum(
        text_mask, 1, keepdim=True)

    # Self-attentive pooling layer
    # Run through linear projection. Shape: (batch_size, sequence length, 1)
    # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
    self_attentive_logits = self._self_attentive_pooling_projection(
        integrated_encodings).squeeze(2)
    self_weights = util.masked_softmax(self_attentive_logits, text_mask)
    self_attentive_pool = util.weighted_sum(integrated_encodings,
                                            self_weights)

    pooled_representations = torch.cat(
        [max_pool, min_pool, mean_pool, self_attentive_pool], 1)
    pooled_representations_dropped = self._integrator_dropout(
        pooled_representations)

    logits = self._output_layer(pooled_representations_dropped)
    class_probabilities = F.softmax(logits, dim=-1)

    output_dict = {
        "logits": logits,
        "class_probabilities": class_probabilities
    }
    if label is not None:
        loss = self.loss(logits, label)
        for metric in self.metrics.values():
            metric(logits, label)
        output_dict["loss"] = loss

    return output_dict
def forward(self, tags, history, next_sym, source_tokens, his_symptoms,
            target_tokens, **args):
    """Predict the next symptoms (topics) of a dialogue and decode a response.

    Steps performed here:

    1. Encode every history utterance (``self._vecoder``) and the dialogue
       as a whole (``self._sen_encoder``, HRED-style).
    2. Build a graph whose first ``self.topic`` nodes are topic/symptom nodes
       and whose following nodes are the history utterances, then run
       ``self.attn_one`` (sentence edges) and ``self.attn_three`` (topic-topic
       edges) over it.
    3. Predict the next topics (sigmoid + weighted BCE) and update the topic
       accuracy / recall / F1 metrics.
    4. Decode the response with ``self._forward_loop`` and feed the
       detokenised hypotheses/references to the BLEU, ``kd_metric`` and
       ``dink*`` metrics.

    Side effects: mutates ``self.evovl_mat`` in place and appends a dump of
    the topic adjacency matrices to ``visulize_graph.txt`` on every call.

    Parameters
    ----------
    tags :
        Per batch element, one collection of symptom ids per history utterance;
        ``-1`` entries are ignored. Presumably plain Python ints, since they are
        put into sets and matched with ``in`` (TODO confirm).
    history :
        Text field of the dialogue history with one extra nesting level per
        utterance (hence ``num_wrapping_dims=1``).
    next_sym :
        Multi-hot tensor of the topics to predict; the BCE target.
    source_tokens :
        Text field; only its mask is used here.
    his_symptoms :
        Iterable of symptom ids already mentioned in the dialogue.
    target_tokens :
        Gold response; ``target_tokens['tokens'][i]`` is one row of ids whose
        first position is skipped (presumably a start symbol).

    Returns
    -------
    The dictionary from ``self._forward_loop``, with ``loss`` set to
    ``decoder loss + 8 * topic_loss`` in training and to ``topic_loss`` otherwise.
    """
    bs = len(tags)

    # Embed the dialogue history; the extra wrapping dimension is the utterance axis.
    embeddings = self._source_embedder(history)
    mask = get_text_field_mask(history, num_wrapping_dims=1)
    sz = list(embeddings.size())
    # Fold utterances into the batch so each one is encoded independently.
    embeddings = embeddings.view(sz[0] * sz[1], sz[2], sz[3])
    mask = mask.view(sz[0] * sz[1], sz[2])
    # One vector per utterance: (bs, sen_num, hidden).
    utter_hidden = self._vecoder(embeddings, mask)
    utter_hidden = utter_hidden.view(sz[0], sz[1], -1)
    dialog_mask = get_text_field_mask(history)
    dialog_hidden = self._sen_encoder(utter_hidden, dialog_mask)  # HRED-style

    # Initialise every graph node: (bs, symp_size, outfeature). All nodes start
    # from the same learned embedding (an open design question from the author).
    symp_state = torch.zeros(bs, self.symp_size, self.outfeature).cuda()
    symp_state += self.symp_state

    # Sentence-sentence edges. Utterance j lives at node self.topic + j. The edge
    # k -> j (k < j) counts the non-padding tags the two utterances share.
    sym_mat = torch.zeros(bs, self.symp_size, self.symp_size)
    for i in range(bs):
        dic = {}
        for j in range(len(tags[i])):
            symp_state[i][self.topic + j] = utter_hidden[i][j]
            dic[j] = set(list(tags[i][j]))
            for k in range(j):
                for aa in dic[j]:
                    if aa in dic[k] and aa != -1:
                        sym_mat[i][self.topic + k][self.topic + j] += 1
    for i in range(bs):
        for j in range(len(tags[i])):
            # Connect to the following `self.last_sen` positions.
            last = min(j + self.last_sen, len(tags[i]))
            # NOTE(review): this indexes rows/columns j..last-1 directly, i.e. the
            # topic block, not the utterance nodes at self.topic + j used above.
            # It is left as-is because changing it alters the model's graph;
            # confirm whether `self.topic + j` was intended.
            sym_mat[i][j][j:last] = torch.ones(1)
    last_h = self.attn_one(symp_state, sym_mat)

    # (The sentence-topic attention pass `attn_two` is disabled, so its
    # adjacency matrix is not built.)

    # Topic-topic edges: learned relations (ReLU of self.symp_mat) plus the
    # co-occurrence matrix self.evovl_mat, which is updated in place with every
    # pair of symptoms already mentioned.
    sym_mat = torch.zeros(bs, self.symp_size, self.symp_size)
    for symp_i in his_symptoms:
        for symp_j in his_symptoms:
            self.evovl_mat[symp_i][symp_j] = 1
    temp_mat = (torch.nn.functional.relu(self.symp_mat) +
                self.evovl_mat).cpu()
    # Debug dump, appended on every forward pass (the file grows without bound).
    with open('visulize_graph.txt', 'a') as fout:
        fout.write('evovl_mat is: \n')
        for row in self.evovl_mat.detach().cpu().numpy():
            fout.write(str(row) + '\n')
        fout.write('temp_mat is: \n')
        for row in temp_mat.detach().cpu().numpy():
            fout.write(str(row) + '\n')
    sym_mat[:, :self.topic, :self.topic] += temp_mat
    last_h = self.attn_three(last_h, sym_mat)

    # Next-topic prediction. Entries where next_sym is 1 get weight 1 + 5 = 6 in the BCE.
    topic_pre = torch.sum(self.predict_layer * last_h, dim=-1) + self.predict_bias
    topic_probs = torch.sigmoid(topic_pre)
    topics_weight = torch.ones_like(topic_probs) + 5 * next_sym.float()
    topic_loss = torch.nn.functional.binary_cross_entropy(
        topic_probs, next_sym.float(), weight=topics_weight)
    ans = (topic_probs > 0.5).long()

    # Decoder mask over graph nodes: topic nodes visible, sentence nodes hidden.
    his_sentence_mask = torch.zeros(bs, self.sen_num).long()
    total_mask = torch.cat(
        (torch.ones(bs, self.topic).long(), his_sentence_mask), -1)
    # Use the gold topics while training, the predicted ones otherwise.
    if self.training:
        aa = next_sym.long()
    else:
        aa = ans
    topic_embedding = aa.float().matmul(self.symp_state)
    topic_hidden = last_h

    # Topic metrics: topic_acc gets (correct, predicted) counts and topic_rec gets
    # (correct, gold) counts; F1 is their harmonic mean.
    pre_total = torch.sum(ans).item()
    true_total = torch.sum(next_sym).item()
    pre_right = torch.sum((ans == next_sym).long() * next_sym).item()
    self.topic_acc(pre_right, pre_total)
    self.topic_rec(pre_right, true_total)
    acc = self.topic_acc.get_metric(False)
    rec = self.topic_rec.get_metric(False)
    f1 = 0.
    if acc + rec > 0:
        f1 = acc * rec * 2 / (acc + rec)
    self.topic_f1(f1)

    source_mask = util.get_text_field_mask(source_tokens)
    state = {
        "source_mask": source_mask,
        "encoder_outputs": topic_embedding,
        # The decoder starts from the topic embedding joined with the HRED output.
        "decoder_hidden": torch.cat((topic_embedding, dialog_hidden), -1),
        "decoder_context": topic_embedding.new_zeros(bs, self._decoder_output_dim),
        "topic_embedding": topic_embedding,
    }
    # Run the decoder once.
    output_dict = self._forward_loop(state, topic_hidden, total_mask.cuda(),
                                     target_tokens)
    best_predictions = output_dict["predictions"]

    # Detokenise hypotheses, cutting at the first end symbol.
    references, hypothesis = [], []
    for i in range(bs):
        pred_ids = best_predictions[i]
        cut_hypo = pred_ids[:]
        if self._end_index in list(pred_ids):
            cut_hypo = pred_ids[:list(pred_ids).index(self._end_index)]
        hypothesis.append([
            self.vocab.get_token_from_index(idx.item()) for idx in cut_hypo
        ])

    # Detokenise references row by row: skip the first position and stop at the
    # first end symbol. Slice row i, never the whole batch.
    flag = 1
    for i in range(bs):
        ref_ids = target_tokens['tokens'][i]
        cut_ref = ref_ids[1:]
        if self._end_index in list(ref_ids):
            cut_ref = ref_ids[1:list(ref_ids).index(self._end_index)]
        references.append([
            self.vocab.get_token_from_index(idx.item()) for idx in cut_ref
        ])
        # Rarely print a sample of decoded pairs during evaluation.
        if random.random() <= 0.001 and not self.training and flag == 1:
            flag = 0
            for jj in range(i):
                print('___hypo___', ''.join(hypothesis[jj]), end=' ## ')
                print(''.join(references[jj]))
            print("")

    self.bleu_aver(references, hypothesis)
    self.bleu1(references, hypothesis)
    self.bleu2(references, hypothesis)
    self.bleu4(references, hypothesis)
    self.kd_metric(references, hypothesis)
    self.dink1(hypothesis)
    self.dink2(hypothesis)

    if self.training:
        output_dict['loss'] = output_dict['loss'] + 8 * topic_loss
    else:
        output_dict['loss'] = topic_loss
    return output_dict
def forward(
    self,  # type: ignore
    source: TextFieldTensors,
    **target_tokens: Dict[str, TextFieldTensors],
) -> Dict[str, torch.Tensor]:
    """
    Encode `source` once, then decode it for every configured target.

    # Parameters

    source : `TextFieldTensors`
        The output of `TextField.as_array()` applied on the source `TextField`. It is
        embedded, passed through embedding dropout and encoded into one vector per instance.
    target_tokens : `Dict[str, TextFieldTensors]`
        Keyword arguments mapping each target name to the output of `TextField.as_array()`
        on that target's `TextField`. When any are given, their names must match the keys
        of `self._states` exactly.

    # Returns

    A dictionary that, when targets are given, holds `"<name>_loss"` per target and
    `"loss"` (the mean over targets). When not training, it also holds
    `"<name>_top_k_predictions"` and `"<name>_top_k_log_probabilities"` from beam search.

    # Raises

    Exception
        If the target names and the keys of `self._states` differ.
    """
    # (batch_size, input_sequence_length, embedding_dim)
    embedded_input = self._embedding_dropout(self._source_embedder(source))
    source_mask = get_text_field_mask(source)
    # (batch_size, encoder_output_dim)
    final_encoder_output = self._encoder(embedded_input, source_mask)

    output_dict = {}

    if target_tokens:
        target_names = target_tokens.keys()
        state_names = self._states.keys()
        if target_names != state_names:
            raise Exception(
                "Mismatch between target_tokens and self._states. Keys in " +
                f"targets only: {target_names - state_names} Keys in states only: {state_names - target_names}"
            )

        # Greedy decoding against each gold target yields that target's loss.
        total_loss = 0
        for name, state in self._states.items():
            target_loss = self.greedy_search(
                final_encoder_output=final_encoder_output,
                target_tokens=target_tokens[name],
                target_embedder=state.embedder,
                decoder_cell=state.decoder_cell,
                output_projection_layer=state.output_projection_layer,
            )
            total_loss = total_loss + target_loss
            output_dict[f"{name}_loss"] = target_loss
        # Mean (not sum) over targets, to be comparable to the paper.
        output_dict["loss"] = total_loss / len(self._states)

    if self.training:
        return output_dict

    # Beam search per target, each seeded with the encoder output.
    batch_size = final_encoder_output.size(0)
    for name, state in self._states.items():
        start_predictions = final_encoder_output.new_full(
            (batch_size,), fill_value=self._start_index, dtype=torch.long)
        start_state = {"decoder_hidden": final_encoder_output}
        # (batch_size, beam_size, num_decoding_steps)
        top_k_predictions, top_k_log_probs = self._beam_search.search(
            start_predictions, start_state, state.take_step)
        if target_tokens:
            self._update_recall(top_k_predictions, target_tokens[name], state.recall)
        output_dict[f"{name}_top_k_predictions"] = top_k_predictions
        output_dict[f"{name}_top_k_log_probabilities"] = top_k_log_probs
    return output_dict
def forward(
    self,  # type: ignore
    question_field: Dict[str, torch.LongTensor],
    visual_feat: torch.Tensor,
    pos: torch.Tensor,
    image_id: List[str],
    gold_question_attentions: torch.Tensor = None,
    identifier: List[str] = None,
    logical_form: List[str] = None,
    actions: List[List[ProductionRule]] = None,
    target_action_sequence: torch.LongTensor = None,
    gold_object_choices: torch.Tensor = None,
    denotation: torch.Tensor = None,
) -> Dict[str, torch.Tensor]:  # pylint: disable=arguments-differ
    """
    Parse the question into a module program, execute it on the image
    features, and accumulate program, denotation and object-proposal losses.

    Parameters (only those read here are described; ``image_id``,
    ``identifier`` and ``logical_form`` are accepted but unused):

    question_field : token ids under ``self._tokens_namespace``.
    visual_feat : object features; asserted to have 36 objects of size 2048.
    pos : per-object position features, indexed per batch element.
    gold_question_attentions : indexed as (batch, module, question_token). It is
        indexed unconditionally while executing programs, so it must not be None
        whenever that loop runs.
    actions : production rules; ``actions[0]`` builds the index -> rule mapping.
    target_action_sequence : gold program action indices (trailing dim squeezed).
        Used for constrained search when training or when
        ``self._use_gold_program_for_eval`` is set.
    gold_object_choices : per-module proposal labels, mapped with ``(x + 1) / 2``;
        presumably values in {-1, 1} (TODO confirm).
    denotation : gold answer per batch element.

    Returns a dict with ``action_mapping``, ``best_action_sequence``,
    ``debug_info`` and, when any loss was produced, ``loss`` (mean of the losses).
    """
    batch_size, obj_num, feat_size = visual_feat.size()
    assert obj_num == 36 and feat_size == 2048
    text_masks = util.get_text_field_mask(question_field)
    (l_orig, v_orig, text, vis_only), x_orig = self._encoder(
        question_field[self._tokens_namespace], text_masks, visual_feat, pos)
    text_masks = text_masks.float()
    # NOTE: Taking the lxmert output before cross modality layer (which is the same for both images)
    # Can also try concatenating (dim=-1) the two encodings
    encoded_sentence = text

    initial_state = self._get_initial_state(encoded_sentence, text_masks, actions)
    initial_state.debug_info = [[] for _ in range(batch_size)]

    if target_action_sequence is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        target_action_sequence = target_action_sequence.squeeze(-1)
        target_mask = target_action_sequence != self._action_padding_index
    else:
        target_mask = None

    outputs: Dict[str, torch.Tensor] = {}
    losses = []
    if (self.training or self._use_gold_program_for_eval
        ) and target_action_sequence is not None:
        # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we
        # unsqueeze it for the MML trainer.
        search = ConstrainedBeamSearch(
            beam_size=None,
            allowed_sequences=target_action_sequence.unsqueeze(1),
            allowed_sequence_mask=target_mask.unsqueeze(1),
        )
        final_states = search.search(initial_state, self._transition_function)
        # During the first `_num_parse_only_batches` batches only the parser is
        # trained: the loss is the negative score of the gold program.
        if self._training_batches_so_far < self._num_parse_only_batches:
            for batch_index in range(batch_size):
                if not final_states[batch_index]:
                    logger.error(
                        f"No pogram found for batch index {batch_index}")
                    continue
                losses.append(-final_states[batch_index][0].score[0])
    else:
        final_states = self._beam_search.search(
            self._max_decoding_steps,
            initial_state,
            self._transition_function,
            keep_final_unfinished_states=False,
        )

    # Index -> production-rule string; built from the first instance's actions,
    # which assumes every instance shares the same action list.
    action_mapping = {}
    for action_index, action in enumerate(actions[0]):
        action_mapping[action_index] = action[0]

    outputs: Dict[str, Any] = {"action_mapping": action_mapping}
    outputs["best_action_sequence"] = []
    outputs["debug_info"] = []

    if self._nmn_settings["mask_non_attention"]:
        # Re-encode the question once per module, masking it to that module's
        # gold attention, so each module gets its own encoder outputs.
        zero_one_mult = torch.zeros_like(gold_question_attentions)
        zero_one_mult.copy_(gold_question_attentions)
        # Always keep the first token visible.
        zero_one_mult[:, :, 0] = 1.0
        # sep_indices = text_masks.argmax(1).long()
        # Index of the last unmasked token: mask * (1 + position) peaks there.
        sep_indices = (
            (text_masks.long() *
             (1 + torch.arange(text_masks.shape[1]).unsqueeze(0).repeat(
                 batch_size, 1).to(text_masks.device))).argmax(1).long())
        # Broadcast to (batch, module, question_token) for comparison below.
        sep_indices = (sep_indices.unsqueeze(1).repeat(
            1, gold_question_attentions.shape[2]).unsqueeze(1).repeat(
                1, gold_question_attentions.shape[1], 1))
        indices_dim2 = (torch.arange(
            gold_question_attentions.shape[2]).unsqueeze(0).repeat(
                gold_question_attentions.shape[0],
                gold_question_attentions.shape[1],
                1,
            ).to(sep_indices.device).long())
        # Also keep the last unmasked token visible.
        zero_one_mult = torch.where(
            sep_indices == indices_dim2,
            torch.ones_like(zero_one_mult),
            zero_one_mult,
        ).float()
        # Flatten (batch, module) into one axis for questions, features and positions.
        reshaped_questions = (
            question_field[self._tokens_namespace].unsqueeze(1).repeat(
                1, gold_question_attentions.shape[1],
                1).view(-1, gold_question_attentions.shape[-1]))
        reshaped_visual_feat = (visual_feat.unsqueeze(1).repeat(
            1, gold_question_attentions.shape[1], 1,
            1).view(-1, obj_num, visual_feat.shape[-1]))
        reshaped_pos = (pos.unsqueeze(1).repeat(
            1, gold_question_attentions.shape[1], 1,
            1).view(-1, obj_num, pos.shape[-1]))
        zero_one_mult = zero_one_mult.view(
            -1, gold_question_attentions.shape[-1])
        # Only re-encode modules whose mask covers more than the two forced tokens
        # (assuming 0/1 gold attentions).
        q_att_filter = zero_one_mult.sum(1) > 2
        (l_relevant, v_relevant, _, _), x_relevant = self._encoder(
            reshaped_questions[q_att_filter, :],
            zero_one_mult[q_att_filter, :],
            reshaped_visual_feat[q_att_filter, :, :],
            reshaped_pos[q_att_filter, :, :],
        )
        # Per batch element: module index -> re-encoded tensor.
        l = [{} for _ in range(batch_size)]
        v = [{} for _ in range(batch_size)]
        x = [{} for _ in range(batch_size)]
        count = 0
        batch_index = -1
        # NOTE(review): assumes gold_question_attentions.shape[1] equals
        # target_action_sequence.shape[1] and that target_action_sequence is not None here.
        for i in range(zero_one_mult.shape[0]):
            module_num = i % target_action_sequence.shape[1]
            if module_num == 0:
                batch_index += 1
            if q_att_filter[i].item():
                l[batch_index][module_num] = l_relevant[count]
                v[batch_index][module_num] = v_relevant[count]
                x[batch_index][module_num] = x_relevant[count]
                count += 1
    else:
        l = l_orig
        v = v_orig
        x = x_orig

    for batch_index in range(batch_size):
        # Parse-only warm-up: skip execution entirely.
        if (self.training and
                self._training_batches_so_far < self._num_parse_only_batches):
            continue
        if not final_states[batch_index]:
            logger.error(f"No pogram found for batch index {batch_index}")
            outputs["best_action_sequence"].append([])
            outputs["debug_info"].append([])
            continue
        world = VisualReasoningGqaLanguage(
            l[batch_index],
            v[batch_index],
            x[batch_index],
            # initial_state.rnn_state[batch_index].encoder_outputs[batch_index],
            self._language_parameters,
            pos[batch_index],
            nmn_settings=self._nmn_settings,
        )
        denotation_log_prob_list = []
        # TODO(mattg): maybe we want to limit the number of states we evaluate (programs we
        # execute) at test time, just for efficiency.
        for state_index, state in enumerate(final_states[batch_index]):
            action_indices = state.action_history[0]
            action_strings = [
                action_mapping[action_index] for action_index in action_indices
            ]
            # Shape: (num_denotations,)
            assert len(action_strings) == len(state.debug_info[0])
            # Plug in gold question attentions
            for i in range(len(state.debug_info[0])):
                if gold_question_attentions[batch_index, i, :].sum() > 0:
                    # Normalised gold attention replaces the predicted one.
                    state.debug_info[0][i]["question_attention"] = (
                        gold_question_attentions[batch_index, i, :].float() /
                        gold_question_attentions[batch_index, i, :].sum())
                elif self._nmn_settings["mask_non_attention"] and (
                        action_strings[i][-4:] == "find"
                        or action_strings[i][-6:] == "filter"
                        or action_strings[i][-13:] == "with_relation"):
                    # No gold attention for an attending module: use a uniform
                    # attention and fall back to the unmasked encodings.
                    state.debug_info[0][i]["question_attention"] = (
                        torch.ones_like(
                            gold_question_attentions[batch_index, i, :]).float() /
                        gold_question_attentions[batch_index, i, :].numel())
                    l[batch_index][i] = l_orig[batch_index]
                    v[batch_index][i] = v_orig[batch_index]
                    x[batch_index][i] = x_orig[batch_index]
            # Rebuild the executor with the (possibly updated) per-module encodings.
            # NOTE(review): confirm whether this rebuild belongs inside the elif above.
            world = VisualReasoningGqaLanguage(
                l[batch_index],
                v[batch_index],
                x[batch_index],
                # initial_state.rnn_state[batch_index].encoder_outputs[batch_index],
                self._language_parameters,
                pos[batch_index],
                nmn_settings=self._nmn_settings,
            )
            # print(action_strings)
            state_denotation_log_probs = world.execute_action_sequence(
                action_strings, state.debug_info[0])
            # prob2 = world.execute_action_sequence(action_strings, state.debug_info[0])

            # P(denotation | parse) * P(parse | question)
            denotation_log_prob_list.append(state_denotation_log_probs)
            # NOTE(review): if execute_action_sequence returns a tensor, this `+=` is
            # in place, so state_denotation_log_probs (used in the loss below) also
            # includes the parse score — confirm this is intended.
            if not self._use_gold_program_for_eval:
                denotation_log_prob_list[-1] += state.score[0]
            if state_index == 0:
                outputs["best_action_sequence"].append(action_strings)
                outputs["debug_info"].append(state.debug_info[0])
                if target_action_sequence is not None:
                    targets = target_action_sequence[batch_index].data
                    program_correct = self._action_history_match(
                        action_indices, targets)
                    self._program_accuracy(program_correct)

        # P(denotation | parse) * P(parse | question) for the all programs on the beam.
        # Shape: (beam_size, num_denotations)
        denotation_log_probs = torch.stack(denotation_log_prob_list)
        # \Sum_parse P(denotation | parse) * P(parse | question) = P(denotation | question)
        # Shape: (num_denotations,)
        # NOTE(review): this marginal is computed but not used; the loss below uses
        # the last beam state's value.
        marginalized_denotation_log_probs = util.logsumexp(
            denotation_log_probs, dim=0)
        if denotation is not None:
            loss = (self.loss(
                state_denotation_log_probs.unsqueeze(0),
                denotation[batch_index].unsqueeze(0).float(),
            ).view(1) * self._denotation_loss_multiplier)
            losses.append(loss)
            self._denotation_accuracy(
                torch.tensor([
                    1 - state_denotation_log_probs, state_denotation_log_probs
                ]).to(denotation.device),
                denotation[batch_index],
            )
            if gold_object_choices is not None:
                gold_objects = gold_object_choices[batch_index, :, :]
                predicted_objects = torch.zeros_like(gold_objects)
                for index in world.object_scores:
                    predicted_objects[
                        index, :] = world.object_scores[index]
                obj_exists = gold_objects.max(1)[0] > 0
                # Only look at modules where at least one of the proposals has the object of interest
                predicted_objects = predicted_objects[obj_exists, :]
                gold_objects = gold_objects[obj_exists, :]
                gold_objects = gold_objects.view(-1)
                predicted_objects = predicted_objects.view(-1)
                if gold_objects.numel() > 0:
                    # In-place add: `loss` is the tensor already appended to
                    # `losses`, so the object loss is included there too.
                    loss += self._obj_loss_multiplier * self.loss(
                        predicted_objects, (gold_objects.float() + 1) / 2)
                    self._proposal_accuracy(
                        torch.cat(
                            (
                                1.0 - predicted_objects.view(-1, 1),
                                predicted_objects.view(-1, 1),
                            ),
                            dim=-1,
                        ),
                        (gold_objects + 1) / 2,
                    )
    if losses:
        outputs["loss"] = torch.stack(losses).mean()
    if self.training:
        self._training_batches_so_far += 1
    return outputs
def forward(
        self,
        query: Dict[str, torch.LongTensor],
        docs: Dict[str, torch.LongTensor],
        dataset: Optional[List[str]] = None,
        labels: Optional[torch.Tensor] = None,
        scores: Optional[torch.Tensor] = None,
        relevant_ignored: Optional[torch.Tensor] = None,
        irrelevant_ignored: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
    """
    Score every candidate document of a batch row against that row's query.

    Each query is tiled once per candidate document, scored with ``self.scorer``,
    concatenated with ``exp(scores / self.temperature)`` and reduced to one logit per
    document by ``self.final_scorer``.

    Parameters
    ----------
    query : Dict[str, torch.LongTensor]
        Text-field tensors for the queries; their mask is ``(batch_size, query_length)``.
    docs : Dict[str, torch.LongTensor]
        Text-field tensors for the candidates; their mask (one wrapping dimension)
        is ``(batch_size, num_docs, doc_length)``.
    dataset : List[str], optional
        Not used by this method; kept for interface compatibility.
    labels : torch.Tensor, optional
        Targets passed unchanged to ``self.loss`` and to the ``'accuracy'`` metric.
    scores : torch.Tensor
        Extra per-document scores. Required despite the ``None`` default: they are
        reshaped to ``(batch_size * num_docs, -1)`` and concatenated with the scorer
        output before ``self.final_scorer``.
    relevant_ignored, irrelevant_ignored : torch.Tensor, optional
        Not used by this method; kept for interface compatibility.

    Returns
    -------
    Dict with ``'logits'`` of shape ``(batch_size, num_docs)`` and, when ``labels``
    is given, ``'loss'`` and ``'accuracy'`` (the return value of the metric call).

    Raises
    ------
    ValueError
        If ``scores`` is not provided.
    """
    if scores is None:
        raise ValueError("`scores` is required: it is concatenated with the scorer "
                         "output before `final_scorer`.")

    # (batch_size, num_docs, doc_length)
    ds_mask = get_text_field_mask(docs, num_wrapping_dims=1)
    # (batch_size, num_docs, doc_length, embedding_dim)
    ds_embedded = self.embedder(docs)
    batch_size, num_docs, doc_length, embedding_dim = ds_embedded.shape
    flat_size = batch_size * num_docs

    # Flatten documents: (batch_size * num_docs, doc_length[, embedding_dim])
    ds_embedded = ds_embedded.view(flat_size, doc_length, embedding_dim)
    ds_mask = ds_mask.view(flat_size, doc_length)
    if self.idf_embedder is not None:
        # A trailing size-1 dim broadcasts over embedding_dim, giving the same
        # product as explicitly repeating the idf weight.
        ds_idfs = self.idf_embedder(docs).view(flat_size, doc_length, 1)
        ds_embedded = ds_embedded * ds_idfs

    # (batch_size, query_length)
    qs_mask = get_text_field_mask(query)
    _, query_length = qs_mask.shape
    # Tile each query once per candidate: (batch_size * num_docs, query_length)
    qs_mask = qs_mask.unsqueeze(1).repeat(1, num_docs, 1).view(flat_size, -1)

    # (batch_size, query_length, embedding_dim) -> (batch_size * num_docs, query_length, embedding_dim)
    qs_embedded = self.embedder(query)
    qs_embedded = qs_embedded.unsqueeze(1).repeat(1, num_docs, 1, 1)
    qs_embedded = qs_embedded.view(flat_size, query_length, embedding_dim)
    if self.idf_embedder is not None:
        qs_idfs = self.idf_embedder(query).unsqueeze(1).repeat(1, num_docs, 1, 1)
        qs_embedded = qs_embedded * qs_idfs.view(flat_size, query_length, 1)

    logits = self.scorer(qs_embedded, ds_embedded, qs_mask, ds_mask)
    scaled_scores = torch.exp(scores / self.temperature).view(flat_size, -1)
    logits = torch.cat([logits, scaled_scores], dim=1)
    logits = self.final_scorer(logits).view(batch_size, num_docs)

    output_dict = {'logits': logits}
    if labels is not None:
        output_dict['loss'] = self.loss(logits, labels)
        log_probs = F.log_softmax(logits, dim=1)
        output_dict['accuracy'] = self.metrics['accuracy'](log_probs, labels)
    return output_dict
def forward(
        self,  # type: ignore
        user_utterance: Dict[str, torch.LongTensor],
        prev_user_utterance: Dict[str, torch.LongTensor],
        prev_sys_utterance: Dict[str, torch.LongTensor],
        label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict the intent of the current user turn given the previous user and system turns.

    Each utterance is embedded with the shared ``text_field_embedder`` and encoded
    with its own encoder. The three encodings are concatenated and scored by
    ``classifier_feedforward``.

    Parameters
    ----------
    user_utterance : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()`` for the current user turn.
    prev_user_utterance : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()`` for the previous user turn.
    prev_sys_utterance : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()`` for the previous system turn.
    label : torch.LongTensor, optional (default = None)
        The intent label for each instance in the batch. When given, the loss is
        computed and the per-label F1 and accuracy metrics are updated.

    Returns
    -------
    An output dictionary consisting of:
    logits : torch.FloatTensor
        Unnormalised class scores (softmax is taken over dim 1, so presumably
        ``(batch_size, num_classes)``).
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; present only when ``label`` is given.
    """
    # Same call order as encoding each turn by hand: embed, mask, encode.
    encodings = []
    for utterance, encoder in ((user_utterance, self.user_utterance_encoder),
                               (prev_user_utterance, self.prev_user_utterance_encoder),
                               (prev_sys_utterance, self.prev_sys_utterance_encoder)):
        embedded = self.text_field_embedder(utterance)
        mask = util.get_text_field_mask(utterance)
        encodings.append(encoder(embedded, mask))

    logits = self.classifier_feedforward(torch.cat(encodings, dim=-1))
    output_dict = {'logits': logits}

    if label is not None:
        output_dict["loss"] = self.loss(logits, label)
        # Per-label F1 metrics consume probabilities; accuracy consumes logits.
        class_probs = F.softmax(logits, dim=1)
        for i in range(self.num_classes):
            label_name = self.vocab.get_token_from_index(index=i, namespace="labels")
            self.label_f1_metrics[label_name](class_probs, label)
        self.label_accuracy(logits, label)

    return output_dict
def forward(self, text, spans, ner_labels, coref_labels, relation_labels,
            trigger_labels, argument_labels, metadata):
    """
    Encode a batch once and run the coref / NER / relation / event heads on it.

    Steps performed here:
      1. When ``self._co_train`` is set, choose loss weights per batch: training
         batches whose metadata "dataset" is "ontonotes" train coref only; all other
         batches (and every non-training batch) use ``self._permanent_loss_weights``.
      2. Embed ``text``. If a "bert-offsets" key is present, a zero offset is
         prepended so that position 0 of the embedder output becomes
         ``cls_embeddings`` (presumably the [CLS] wordpiece); otherwise
         ``cls_embeddings`` is all zeros.
      3. Copy each instance's sentence (``metadata[i]["start_ix"]`` to
         ``metadata[i]["end_ix"]``) out of the embedded text to the front of a new
         tensor, and run the context layer over it.
      4. Build span embeddings for ``spans`` (endpoint, plus attentive if configured).
      5. Optionally propagate coref / relation information into the span embeddings.
      6. Run each head whose loss weight is positive and sum the weighted losses.

    Parameters
    ----------
    text : dict of text-field tensors (may contain "bert-offsets"); not mutated.
    spans : span index tensor, indexed below as ``spans[:, :, 0]``; entries with a
        negative start are treated as padding. Presumably ``(batch_size, num_spans, 2)``.
    ner_labels, coref_labels, trigger_labels : labels forwarded to the matching heads.
    relation_labels, argument_labels : adjacency labels, cast to ``long`` here.
    metadata : list with one dict per instance; reads "dataset" (co-training only),
        "start_ix" and "end_ix".

    Returns
    -------
    dict with "coref", "relation", "ner" and "events" (each head's output dict) and
    "loss", the loss-weighted sum of the four heads' "loss" entries.
    """
    # For co-training on Ontonotes, need to change the loss weights depending on the data coming
    # in. This is a hack but it will do for now.
    if self._co_train:
        if self.training:
            dataset = [entry["dataset"] for entry in metadata]
            # A training batch must come from a single dataset.
            assert len(set(dataset)) == 1
            dataset = dataset[0]
            assert dataset in ["ace", "ontonotes"]
            if dataset == "ontonotes":
                self._loss_weights = dict(coref=1, ner=0, relation=0, events=0)
            else:
                self._loss_weights = self._permanent_loss_weights
        # This assumes that there won't be any co-training data in the dev and test sets, and that
        # coref propagation will still happen even when the coref weight is set to 0.
        else:
            self._loss_weights = self._permanent_loss_weights

    # In AllenNLP, AdjacencyFields are passed in as floats. This fixes it.
    relation_labels = relation_labels.long()
    argument_labels = argument_labels.long()

    # If we're doing Bert, get the sentence class token as part of the text embedding. This will
    # break if we use Bert together with other embeddings, but that won't happen much.
    if "bert-offsets" in text:
        # NOTE(dwadden) This operation mutates the text. We clone it so that the input isn't
        # mutated; otherwise, successive `forward` calls on the same data variable would give
        # different results because the data got mutated silently.
        # NOTE(review): only the "bert-offsets" dict entry is reassigned below, so a shallow
        # dict copy may be enough; the per-tensor clone looks unnecessary unless the embedder
        # modifies tensors in place — confirm.
        new_text = {}
        for k, v in text.items():
            new_text[k] = v.clone()
        text = new_text

        offsets = text["bert-offsets"]
        # Prepend offset 0 to every row so the first embedded position is wordpiece 0.
        sent_ix = torch.zeros(offsets.size(0), device=offsets.device, dtype=torch.long).unsqueeze(1)
        padded_offsets = torch.cat([sent_ix, offsets], dim=1)
        text["bert-offsets"] = padded_offsets
        padded_embeddings = self._text_field_embedder(text)
        cls_embeddings = padded_embeddings[:, 0, :]
        text_embeddings = padded_embeddings[:, 1:, :]
    else:
        text_embeddings = self._text_field_embedder(text)
        # No class token available: use zeros of shape (batch_size, embedding_dim).
        cls_embeddings = torch.zeros(
            [text_embeddings.size(0), text_embeddings.size(2)],
            device=text_embeddings.device)

    text_embeddings = self._lexical_dropout(text_embeddings)

    # Shape: (batch_size, max_sentence_length)
    text_mask = util.get_text_field_mask(text).float()
    # Number of unmasked tokens per row of `text` (the whole context, not just the sentence).
    sentence_group_lengths = text_mask.sum(dim=1).long()

    # Integer tensor of zeros with one entry per batch row, filled in below.
    sentence_lengths = 0 * text_mask.sum(dim=1).long()
    for i in range(len(metadata)):
        sentence_lengths[i] = metadata[i]["end_ix"] - metadata[i]["start_ix"]
        # Mask out positions past this sentence's length; the sentence tokens are moved
        # to the front of each row below, so only the first sentence_lengths[i] stay on.
        for k in range(sentence_lengths[i], sentence_group_lengths[i]):
            text_mask[i][k] = 0

    max_sentence_length = sentence_lengths.max().item()

    # TODO(Ulme) Speed this up by tensorizing
    # Copy tokens start_ix:end_ix of each row to positions 0:(end_ix - start_ix).
    new_text_embeddings = torch.zeros([
        text_embeddings.shape[0], max_sentence_length, text_embeddings.shape[2]
    ], device=text_embeddings.device)
    for i in range(len(new_text_embeddings)):
        new_text_embeddings[i][0:metadata[i]["end_ix"] - metadata[i]["start_ix"]] = \
            text_embeddings[i][metadata[i]["start_ix"]:metadata[i]["end_ix"]]

    #max_sent_len = max(sentence_lengths)
    #the_list = [list(k+metadata[i]["start_ix"] if k < max_sent_len else 0 for k in range(text_embeddings.shape[1])) for i in range(len(metadata))]
    #text_embeddings = torch.gather(text_embeddings, 1, torch.tensor(the_list, device=text_embeddings.device).unsqueeze(2).repeat(1, 1, text_embeddings.shape[2]))

    text_embeddings = new_text_embeddings

    # Only keep the text embeddings that correspond to actual tokens.
    # text_embeddings = text_embeddings[:, :max_sentence_length, :].contiguous()
    text_mask = text_mask[:, :max_sentence_length].contiguous()

    # Shape: (batch_size, max_sentence_length, encoding_dim)
    contextualized_embeddings = self._lstm_dropout(
        self._context_layer(text_embeddings, text_mask))
    # Every span end index must fall inside the (trimmed) sentence.
    assert spans.max() < contextualized_embeddings.shape[1]

    if self._attentive_span_extractor is not None:
        # Shape: (batch_size, num_spans, emebedding_size)
        # NOTE(review): this runs on the raw `spans`, before the relu clamp below, so padded
        # spans still carry -1 here — confirm the extractor tolerates that.
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

    # Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be <= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(
        contextualized_embeddings, spans)

    if self._attentive_span_extractor is not None:
        # Shape: (batch_size, num_spans, emebedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)
    else:
        span_embeddings = endpoint_span_embeddings

    # TODO(Ulme) try normalizing span embeddeings
    #span_embeddings = span_embeddings.abs().sum(dim=-1).unsqueeze(-1)

    # Make calls out to the modules to get results.
    # Defaults used for any head that is skipped below.
    output_coref = {'loss': 0}
    output_ner = {'loss': 0}
    output_relation = {'loss': 0}
    output_events = {'loss': 0}

    # Prune and compute span representations for coreference module
    if self._loss_weights["coref"] > 0 or self._coref.coref_prop > 0:
        output_coref, coref_indices = self._coref.compute_representations(
            spans, span_mask, span_embeddings, sentence_lengths, coref_labels, metadata)

    # Prune and compute span representations for relation module
    if self._loss_weights["relation"] > 0 or self._relation.rel_prop > 0:
        output_relation = self._relation.compute_representations(
            spans, span_mask, span_embeddings, sentence_lengths, relation_labels, metadata)

    # Propagation of global information to enhance the span embeddings
    if self._coref.coref_prop > 0:
        # TODO(Ulme) Implement Coref Propagation
        # coref_indices is always bound here: coref_prop > 0 also triggers the block above.
        output_coref = self._coref.coref_propagation(output_coref)
        span_embeddings = self._coref.update_spans(output_coref, span_embeddings, coref_indices)

    if self._relation.rel_prop > 0:
        output_relation = self._relation.relation_propagation(
            output_relation)
        span_embeddings = self.update_span_embeddings(
            span_embeddings, span_mask,
            output_relation["top_span_embeddings"],
            output_relation["top_span_mask"],
            output_relation["top_span_indices"])

    # Make predictions and compute losses for each module
    if self._loss_weights['ner'] > 0:
        output_ner = self._ner(spans, span_mask, span_embeddings, sentence_lengths,
                               ner_labels, metadata)

    if self._loss_weights['coref'] > 0:
        output_coref = self._coref.predict_labels(output_coref, metadata)

    if self._loss_weights['relation'] > 0:
        output_relation = self._relation.predict_labels(
            relation_labels, output_relation, metadata)

    if self._loss_weights['events'] > 0:
        # Make the trigger embeddings the same size as the argument embeddings to make
        # propagation easier.
        if self._events._span_prop._n_span_prop > 0:
            # Duplicate each token embedding (like a width-0 span's two endpoints) and append
            # a detached width embedding for width index 0.
            trigger_embeddings = contextualized_embeddings.repeat(1, 1, 2)
            trigger_widths = torch.zeros(
                [trigger_embeddings.size(0), trigger_embeddings.size(1)],
                device=trigger_embeddings.device, dtype=torch.long)
            trigger_width_embs = self._endpoint_span_extractor._span_width_embedding(
                trigger_widths)
            trigger_width_embs = trigger_width_embs.detach()
            trigger_embeddings = torch.cat(
                [trigger_embeddings, trigger_width_embs], dim=-1)
        else:
            trigger_embeddings = contextualized_embeddings

        output_events = self._events(text_mask, trigger_embeddings, spans, span_mask,
                                     span_embeddings, cls_embeddings, sentence_lengths,
                                     output_ner, trigger_labels, argument_labels,
                                     ner_labels, metadata)

    # Heads may return dicts without a "loss" entry; default those to 0.
    # NOTE(review): only coref and relation are defaulted; if the NER or event head
    # returns a dict without "loss", the sum below raises KeyError — confirm.
    if "loss" not in output_coref:
        output_coref["loss"] = 0
    if "loss" not in output_relation:
        output_relation["loss"] = 0

    # TODO(dwadden) just did this part.
    loss = (self._loss_weights['coref'] * output_coref['loss'] +
            self._loss_weights['ner'] * output_ner['loss'] +
            self._loss_weights['relation'] * output_relation['loss'] +
            self._loss_weights['events'] * output_events['loss'])

    output_dict = dict(coref=output_coref,
                       relation=output_relation,
                       ner=output_ner,
                       events=output_events)
    output_dict['loss'] = loss

    # Check to see if event predictions are globally compatible (argument labels are compatible
    # with NER tags and trigger tags).
    # if self._loss_weights["ner"] > 0 and self._loss_weights["events"] > 0:
    #     decoded_ner = self._ner.decode(output_dict["ner"])
    #     decoded_events = self._events.decode(output_dict["events"])
    #     self._joint_metrics(decoded_ner, decoded_events)

    return output_dict
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Classify each instance with the biattentive classification network.

    Pipeline: embed (optionally with ELMo), pre-encode, encode, self-biattention,
    integrate, then max/min/mean/self-attentive pooling and an output layer.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``. An optional ``"elmo"`` entry is sent
        to ``self._elmo`` instead of the text-field embedder. It is popped only while
        embedding and put back right afterwards, so the caller's dict ends up unchanged.
        If ``"elmo"`` is the only entry, the input embedding is the ELMo input
        representation alone (requires ``self._use_input_elmo``).
    label : torch.LongTensor, optional (default = None)
        The gold label for each instance in the batch.

    Returns
    -------
    An output dictionary consisting of:
    logits : torch.FloatTensor
        Unnormalised class scores.
    class_probabilities : torch.FloatTensor
        Softmax over ``logits``; ``(batch_size, num_classes)``.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; present only when ``label`` is given.

    Raises
    ------
    ConfigurationError
        If the model uses ELMo but ``tokens`` has no ``"elmo"`` entry.
    """
    text_mask = util.get_text_field_mask(tokens).float()

    # The text-field embedder has no "elmo" embedder, so hide that key while embedding.
    elmo_tokens = tokens.pop("elmo", None)
    # With ELMo-only input there is nothing left for the text-field embedder.
    embedded_text = self._text_field_embedder(tokens) if tokens else None
    # Put "elmo" back: tests and later training epochs rely on forward() not mutating tokens.
    if elmo_tokens is not None:
        tokens["elmo"] = elmo_tokens

    if self._elmo:
        if elmo_tokens is None:
            raise ConfigurationError(
                "Model was built to use Elmo, but input text is not tokenized for Elmo.")
        elmo_representations = self._elmo(elmo_tokens)["elmo_representations"]
        # Pop from the end is more performant with list
        if self._use_integrator_output_elmo:
            integrator_output_elmo = elmo_representations.pop()
        if self._use_input_elmo:
            input_elmo = elmo_representations.pop()
        assert not elmo_representations

    if self._use_input_elmo:
        if embedded_text is None:
            embedded_text = input_elmo
        else:
            embedded_text = torch.cat([embedded_text, input_elmo], dim=-1)

    dropped_embedded_text = self._embedding_dropout(embedded_text)
    pre_encoded_text = self._pre_encode_feedforward(dropped_embedded_text)
    encoded_tokens = self._encoder(pre_encoded_text, text_mask)

    # Biattention of the sequence with itself: (batch_size, seq_len, seq_len) logits.
    attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
    attention_weights = util.last_dim_softmax(attention_logits, text_mask)
    encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

    # Build the input to the integrator
    integrator_input = torch.cat([encoded_tokens,
                                  encoded_tokens - encoded_text,
                                  encoded_tokens * encoded_text], 2)
    integrated_encodings = self._integrator(integrator_input, text_mask)

    # Concatenate ELMo representations to integrated_encodings if specified
    if self._use_integrator_output_elmo:
        integrated_encodings = torch.cat([integrated_encodings, integrator_output_elmo], dim=-1)

    # Simple pooling layers. Padding is replaced by -1e7 / +1e7 so it never wins max / min.
    pooling_mask = text_mask.unsqueeze(2)
    max_masked_integrated_encodings = util.replace_masked_values(
        integrated_encodings, pooling_mask, -1e7)
    max_pool = torch.max(max_masked_integrated_encodings, 1)[0]
    min_masked_integrated_encodings = util.replace_masked_values(
        integrated_encodings, pooling_mask, +1e7)
    min_pool = torch.min(min_masked_integrated_encodings, 1)[0]
    # The divisor counts only real tokens, so zero the padded positions before summing.
    # Otherwise any non-zero padding output (from the integrator or the concatenated
    # ELMo features) would leak into the mean.
    mean_pool = (torch.sum(integrated_encodings * pooling_mask, 1)
                 / torch.sum(text_mask, 1, keepdim=True))

    # Self-attentive pooling layer
    # Run through linear projection. Shape: (batch_size, sequence length, 1)
    # Then remove the last dimension to get the proper attention shape (batch_size, sequence length).
    self_attentive_logits = self._self_attentive_pooling_projection(
        integrated_encodings).squeeze(2)
    self_weights = util.masked_softmax(self_attentive_logits, text_mask)
    self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)

    pooled_representations = torch.cat([max_pool, min_pool, mean_pool, self_attentive_pool], 1)
    pooled_representations_dropped = self._integrator_dropout(pooled_representations)

    logits = self._output_layer(pooled_representations_dropped)
    class_probabilities = F.softmax(logits, dim=-1)

    output_dict = {'logits': logits, 'class_probabilities': class_probabilities}
    if label is not None:
        loss = self.loss(logits, label)
        for metric in self.metrics.values():
            metric(logits, label)
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    We set up the initial state for the decoder, and pass that state off to either a
    DecoderTrainer, if gold actions are given, and/or a BeamSearch for inference, when
    the model is not in training mode.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor]
        The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
        be passed through a ``TextFieldEmbedder`` and then through an encoder.
    valid_actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``. We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    action_sequence : torch.Tensor, optional (default=None)
        The action sequence for the correct action sequence, where each action is an index into the list
        of possible actions. This tensor has shape ``(batch_size, sequence_length, 1)``. We remove the
        trailing dimension.

    Returns
    -------
    Dict with the decoder trainer's outputs (when ``action_sequence`` is given). When not
    training, it also has "action_mapping", "best_action_sequence", "debug_info",
    "predicted_sql_query" and "sql_queries". The per-instance lists have one entry per
    batch element; instances with no finished beam state get ``[]`` / ``[]`` / ``''``.
    "sql_queries" is left empty by this method.
    """
    embedded_utterance = self._utterance_embedder(tokens)
    mask = util.get_text_field_mask(tokens).float()
    batch_size = embedded_utterance.size(0)

    # (batch_size, num_tokens, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
    initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

    outputs: Dict[str, Any] = {}
    if action_sequence is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        action_sequence = action_sequence.squeeze(-1)
        target_mask = action_sequence != self._action_padding_index
        # target_action_sequence is of shape (batch_size, 1, target_sequence_length)
        # here after we unsqueeze it for the MML trainer.
        loss_output = self._decoder_trainer.decode(initial_state,
                                                   self._transition_function,
                                                   (action_sequence.unsqueeze(1),
                                                    target_mask.unsqueeze(1)))
        outputs.update(loss_output)

    if not self.training:
        # Per instance: action index -> production rule string.
        action_mapping = [{action_index: action[0]
                           for action_index, action in enumerate(batch_actions)}
                          for batch_actions in valid_actions]
        outputs['action_mapping'] = action_mapping
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                     initial_state,
                                                     self._transition_function,
                                                     keep_final_unfinished_states=True)
        outputs['best_action_sequence'] = []
        outputs['debug_info'] = []
        outputs['predicted_sql_query'] = []
        outputs['sql_queries'] = []
        for i in range(batch_size):
            # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i not in best_final_states:
                self._exact_match(0)
                self._denotation_accuracy(0)
                self._valid_sql_query(0)
                self._action_similarity(0)
                # Append to every per-instance list so entry i always belongs to instance i.
                outputs['best_action_sequence'].append([])
                outputs['debug_info'].append([])
                outputs['predicted_sql_query'].append('')
                continue

            best_state = best_final_states[i][0]
            best_action_indices = best_state.action_history[0]
            action_strings = [action_mapping[i][action_index]
                              for action_index in best_action_indices]
            predicted_sql_query = action_sequence_to_sql(action_strings)

            if action_sequence is not None:
                # Use a Tensor, not a Variable, to avoid a memory leak.
                targets = action_sequence[i].data
                self._exact_match(self._action_history_match(best_action_indices, targets))
                # Compare plain ints with padding removed: iterating the tensor would yield
                # 0-d tensors, which hash by identity and never match the predicted ints.
                target_indices = [index for index in targets.tolist()
                                  if index != self._action_padding_index]
                similarity = difflib.SequenceMatcher(None, best_action_indices, target_indices)
                self._action_similarity(similarity.ratio())

            outputs['best_action_sequence'].append(action_strings)
            outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
            outputs['debug_info'].append(best_state.debug_info[0])  # type: ignore
    return outputs
def forward(
        self,  # type: ignore
        premise: Dict[str, torch.LongTensor],
        hypothesis: Dict[str, torch.LongTensor],
        label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict the entailment relation between ``premise`` and ``hypothesis`` with the
    attend / compare / aggregate scheme.

    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        Tensors produced by a ``TextField``.
    hypothesis : Dict[str, torch.LongTensor]
        Tensors produced by a ``TextField``.
    label : torch.IntTensor, optional (default = None)
        Gold label ids from a ``LabelField``. When given, a loss is computed and
        accuracy is updated.

    Returns
    -------
    A dictionary with:
    label_logits : torch.FloatTensor
        Unnormalised log probabilities, shape ``(batch_size, num_labels)``.
    label_probs : torch.FloatTensor
        Softmax of ``label_logits``, shape ``(batch_size, num_labels)``.
    loss : torch.FloatTensor, optional
        Scalar loss, present only when ``label`` is given.
    """
    premise_embedded = self._text_field_embedder(premise)
    hypothesis_embedded = self._text_field_embedder(hypothesis)
    premise_mask = get_text_field_mask(premise).float()
    hypothesis_mask = get_text_field_mask(hypothesis).float()

    # Optional contextual encoders run before attention.
    if self._premise_encoder:
        premise_embedded = self._premise_encoder(premise_embedded, premise_mask)
    if self._hypothesis_encoder:
        hypothesis_embedded = self._hypothesis_encoder(hypothesis_embedded, hypothesis_mask)

    # Attend: similarities of shape (batch_size, premise_length, hypothesis_length).
    similarity = self._matrix_attention(self._attend_feedforward(premise_embedded),
                                        self._attend_feedforward(hypothesis_embedded))

    # Each premise token summarises the hypothesis, and each hypothesis token the premise.
    attended_hypothesis = weighted_sum(hypothesis_embedded,
                                       last_dim_softmax(similarity, hypothesis_mask))
    attended_premise = weighted_sum(premise_embedded,
                                    last_dim_softmax(similarity.transpose(1, 2).contiguous(),
                                                     premise_mask))

    def compare_and_sum(own_tokens, aligned_tokens, own_mask):
        # Compare each token with its aligned summary, zero padding, then sum over tokens.
        compared = self._compare_feedforward(torch.cat([own_tokens, aligned_tokens], dim=-1))
        return (compared * own_mask.unsqueeze(-1)).sum(dim=1)

    # Aggregate: (batch_size, 2 * compare_dim) -> label scores.
    aggregate_input = torch.cat(
        [compare_and_sum(premise_embedded, attended_hypothesis, premise_mask),
         compare_and_sum(hypothesis_embedded, attended_premise, hypothesis_mask)],
        dim=-1)
    label_logits = self._aggregate_feedforward(aggregate_input)

    output_dict = {
        "label_logits": label_logits,
        "label_probs": torch.nn.functional.softmax(label_logits, dim=-1),
    }
    if label is not None:
        output_dict["loss"] = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label.squeeze(-1))
    return output_dict
def forward(
        self,  # type: ignore
        sentence1: Dict[str, torch.LongTensor],
        span_field1: torch.LongTensor,
        span1_text: Dict[str, torch.LongTensor],
        sentence2: Dict[str, torch.LongTensor],
        span_field2: torch.LongTensor,
        label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Classify a pair of spans, one taken from each of two sentences.

    Parameters
    ----------
    sentence1 : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``.
    span_field1 : torch.LongTensor, required
        The span field for sentence 1.
    span1_text : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``. Not used by this method.
    sentence2 : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``.
    span_field2 : torch.LongTensor, required
        The span field for sentence 2.
    label : torch.LongTensor, optional (default = None)
        The label for each instance in the batch.

    Returns
    -------
    An output dictionary consisting of:
    logits : torch.FloatTensor
        Unnormalised class scores from ``classifier_feedforward``.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; present only when ``label`` is given.
    """
    embedded_sentence1 = self.text_field_embedder(sentence1)
    embedded_sentence2 = self.text_field_embedder(sentence2)

    if self.seq2seq_encoder:
        # NOTE(review): detaching means no gradient reaches the encoder or the embedder on
        # this path (the else-branch keeps gradients) — confirm this freezing is intended.
        # ``.detach()`` yields the same values as ``.data`` but stays visible to autograd's
        # in-place modification checks.
        sentence1_mask = util.get_text_field_mask(sentence1)
        encoded_sentence1 = self.seq2seq_encoder(embedded_sentence1, sentence1_mask).detach()
        sentence2_mask = util.get_text_field_mask(sentence2)
        encoded_sentence2 = self.seq2seq_encoder(embedded_sentence2, sentence2_mask).detach()
    else:
        # The embedder already returns a vector per token; extract spans from it directly.
        encoded_sentence1 = embedded_sentence1
        encoded_sentence2 = embedded_sentence2

    span1 = self.span_extractor(encoded_sentence1, span_field1)
    span2 = self.span_extractor(encoded_sentence2, span_field2)

    # When the two span representations differ in shape, fall back to the whole encoded
    # sentence 2 concatenated with zeros (doubling its last dimension).
    # NOTE(review): the original comment attributes this to single-word spans from an
    # endpoint extractor — verify the fallback shapes against the extractor in use.
    if span2.size() != span1.size():
        span2 = torch.cat(
            [encoded_sentence2, torch.zeros_like(encoded_sentence2)], dim=-1)

    pair_representation = torch.cat([span1, span2], dim=-1).squeeze(0)
    logits = self.classifier_feedforward(pair_representation)

    output_dict = {'logits': logits}
    if label is not None:
        loss = self.loss(logits, label)
        for metric in self.metrics.values():
            metric(logits, label)
        output_dict["loss"] = loss

    return output_dict
def forward(self,  # type: ignore
            sentences: Dict[str, torch.LongTensor],
            labels: torch.IntTensor = None,
            confidences: torch.Tensor = None,
            additional_features: torch.Tensor = None,
            ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict one label per sentence for a batch of multi-sentence examples.

    Sentence vectors come either from the [SEP] wordpieces of the whole batch
    (``self.use_sep``; the batch is then treated as one example of all found sentences)
    or from the first token of each sentence, passed through ``self.self_attn``. They are
    then scored by ``self.time_distributed_aggregate_feedforward``, optionally decoded
    with a CRF.

    Parameters
    ----------
    sentences : dict of wrapped text-field tensors (read with ``num_wrapping_dims=1``);
        the "bert" entry is compared against wordpiece id 103 in ``use_sep`` mode.
    labels : per-sentence gold labels; padding is -1, or 0.0 when
        ``self.labels_are_scores`` (per the masks built below).
    confidences : optional per-label weights multiplied into the non-CRF loss.
    additional_features : optional per-sentence features concatenated to the sentence
        vectors before the feedforward.

    Returns
    -------
    An output dictionary consisting of:
    action_probs : softmax of the logits (or the raw logits when ``labels_are_scores``).
    action_logits : the per-sentence logits.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised; present only when ``labels`` is given.
    """
    # ===========================================================================================================
    # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
    # Input: sentences
    # Output: embedded_sentences

    # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
    embedded_sentences = self.text_field_embedder(sentences)
    mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
    batch_size, num_sentences, _, _ = embedded_sentences.size()

    if self.use_sep:
        # The following code collects vectors of the SEP tokens from all the examples in the batch,
        # and arrange them in one list. It does the same for the labels and confidences.
        # TODO: replace 103 with '[SEP]'
        sentences_mask = sentences['bert'] == 103  # mask for all the SEP tokens in the batch
        embedded_sentences = embedded_sentences[sentences_mask]  # given batch_size x num_sentences_per_example x sent_len x vector_len
                                                                 # returns num_sentences_per_batch x vector_len
        assert embedded_sentences.dim() == 2
        num_sentences = embedded_sentences.shape[0]
        # for the rest of the code in this model to work, think of the data we have as one example
        # with so many sentences and a batch of size 1
        batch_size = 1
        embedded_sentences = embedded_sentences.unsqueeze(dim=0)
        embedded_sentences = self.dropout(embedded_sentences)

        # NOTE(review): when labels is None, `additional_features` is neither flattened nor
        # unsqueezed here, so its shape will not match embedded_sentences (1, n, dim) at the
        # concatenation below — confirm inference with additional features.
        if labels is not None:
            if self.labels_are_scores:
                labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
            else:
                labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

            labels = labels[labels_mask]  # given batch_size x num_sentences_per_example return num_sentences_per_batch
            assert labels.dim() == 1
            if confidences is not None:
                confidences = confidences[labels_mask]
                assert confidences.dim() == 1
            if additional_features is not None:
                additional_features = additional_features[labels_mask]
                assert additional_features.dim() == 2

            num_labels = labels.shape[0]
            if num_labels != num_sentences:  # bert truncates long sentences, so some of the SEP tokens might be gone
                assert num_labels > num_sentences  # but `num_labels` should be at least greater than `num_sentences`
                logger.warning(f'Found {num_labels} labels but {num_sentences} sentences')
                labels = labels[:num_sentences]  # Ignore some labels. This is ok for training but bad for testing.
                                                 # We are ignoring this problem for now.
                                                 # TODO: fix, at least for testing

            # do the same for `confidences`
            if confidences is not None:
                num_confidences = confidences.shape[0]
                if num_confidences != num_sentences:
                    assert num_confidences > num_sentences
                    confidences = confidences[:num_sentences]

            # and for `additional_features`
            if additional_features is not None:
                num_additional_features = additional_features.shape[0]
                if num_additional_features != num_sentences:
                    assert num_additional_features > num_sentences
                    additional_features = additional_features[:num_sentences]

            # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
            labels = labels.unsqueeze(dim=0)
            if confidences is not None:
                confidences = confidences.unsqueeze(dim=0)
            if additional_features is not None:
                additional_features = additional_features.unsqueeze(dim=0)
    else:
        # ['CLS'] token
        embedded_sentences = embedded_sentences[:, :, 0, :]
        embedded_sentences = self.dropout(embedded_sentences)
        batch_size, num_sentences, _ = embedded_sentences.size()
        # A sentence is real if any of its tokens is unmasked.
        sent_mask = (mask.sum(dim=2) != 0)
        embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

    if additional_features is not None:
        embedded_sentences = torch.cat((embedded_sentences, additional_features), dim=-1)

    label_logits = self.time_distributed_aggregate_feedforward(embedded_sentences)
    # label_logits: batch_size, num_sentences, num_labels

    if self.labels_are_scores:
        label_probs = label_logits
    else:
        label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    # Create output dictionary for the trainer
    # Compute loss and epoch metrics
    output_dict = {"action_probs": label_probs}

    # =====================================================================

    if self.with_crf:
        # Layer 4 = CRF layer across labels of sentences in an abstract
        # NOTE(review): with labels=None this evaluates `None != -1`, i.e. the Python bool
        # True rather than a mask tensor, so CRF decoding without labels likely fails.
        mask_sentences = (labels != -1)
        best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
        #
        # # Just get the tags and ignore the score.
        predicted_labels = [x for x, y in best_paths]
        # print(f"len(predicted_labels):{len(predicted_labels)}, (predicted_labels):{predicted_labels}")

    label_loss = 0.0
    if labels is not None:
        # Compute cross entropy loss
        flattened_logits = label_logits.view((batch_size * num_sentences), self.num_labels)
        flattened_gold = labels.contiguous().view(-1)

        if not self.with_crf:
            label_loss = self.loss(flattened_logits.squeeze(), flattened_gold)
            # Optional per-sentence weighting before averaging.
            if confidences is not None:
                label_loss = label_loss * confidences.type_as(label_loss).view(-1)
            label_loss = label_loss.mean()
            flattened_probs = torch.softmax(flattened_logits, dim=-1)
        else:
            # Padding labels (-1) are clamped to 0; the CRF ignores them via mask_sentences.
            clamped_labels = torch.clamp(labels, min=0)
            log_likelihood = self.crf(label_logits, clamped_labels, mask_sentences)
            label_loss = -log_likelihood
            # compute categorical accuracy
            # One-hot "probabilities" from the Viterbi paths, same shape as label_logits.
            crf_label_probs = label_logits * 0.
            for i, instance_labels in enumerate(predicted_labels):
                for j, label_id in enumerate(instance_labels):
                    crf_label_probs[i, j, label_id] = 1
            flattened_probs = crf_label_probs.view((batch_size * num_sentences), self.num_labels)

        if not self.labels_are_scores:
            evaluation_mask = (flattened_gold != -1)
            self.label_accuracy(flattened_probs.float().contiguous(), flattened_gold.squeeze(-1), mask=evaluation_mask)

            # compute F1 per label
            for label_index in range(self.num_labels):
                label_name = self.vocab.get_token_from_index(namespace='labels', index=label_index)
                metric = self.label_f1_metrics[label_name]
                metric(flattened_probs, flattened_gold, mask=evaluation_mask)

    if labels is not None:
        output_dict["loss"] = label_loss
    output_dict['action_logits'] = label_logits
    return output_dict
def forward(self,  # type: ignore
            question,
            passage,
            span_start=None,
            span_end=None,
            metadata=None):
    """
    Predict an answer span in ``passage`` for ``question``.

    Both inputs are embedded and encoded with a shared phrase layer. The
    passage is then merged with BiDAF-style context-to-query and
    query-to-context attention and refined by a self-attention residual
    block. Start and end logits come from separate span encoders.

    Parameters
    ----------
    question, passage :
        Text-field tensor dicts fed to ``self._text_field_embedder``.
    span_start, span_end : optional
        Gold start/end token indices. When ``span_start`` is given, a loss is
        computed and the span accuracy metrics are updated. Presumably shaped
        ``(batch_size, 1)`` given the ``squeeze(-1)`` calls below; TODO
        confirm against the producing field.
    metadata : list of dict, optional
        Per-instance dicts with ``question_tokens``, ``passage_tokens``,
        ``original_passage``, ``token_offsets`` and, optionally,
        ``answer_texts``. ``answer_texts`` is used to update the SQuAD
        metrics.

    Returns
    -------
    dict
        Always contains ``passage_question_attention``,
        ``span_start_logits``, ``span_start_probs``, ``span_end_logits``,
        ``span_end_probs`` and ``best_span``. Adds ``loss`` when
        ``span_start`` is given. Adds ``best_span_str``,
        ``question_tokens`` and ``passage_tokens`` when ``metadata`` is
        given.
    """
    # Embedding layer: dropout over the text-field embeddings (no highway
    # network is applied in this model).
    embedded_question = self._dropout(self._text_field_embedder(question))
    embedded_passage = self._dropout(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    # Contextual embedding layer (shared encoder for question and passage).
    # Shape: (batch_size, question_length, encoding_dim)
    encoded_question = self._dropout(
        self._phrase_layer(embedded_question, question_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_passage = self._dropout(
        self._phrase_layer(embedded_passage, passage_lstm_mask))
    # The per-token encoding size is the last dimension.
    encoding_dim = encoded_question.size(-1)

    # Attention flow layer.
    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(
        encoded_passage, encoded_question)
    # Context-to-query (C2Q) attention, 'a' in the BiDAF paper.
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.last_dim_softmax(
        passage_question_similarity, question_mask)
    # C2Q-weighted question vectors: \hat{U}_{:t} = \sum_j a_{tj} U_{:j}
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(
        encoded_question, passage_question_attention)

    # Query-to-context (Q2C) attention, 'b'. Masked question positions are
    # replaced by a large negative value so they never win the max below.
    masked_similarity = util.replace_masked_values(
        passage_question_similarity, question_mask.unsqueeze(1), -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0]
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(
        question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(
        encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(
        1).expand(batch_size, passage_length, encoding_dim)

    # Merge attention vectors.
    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([
        encoded_passage,
        passage_question_vectors,
        encoded_passage * passage_question_vectors,
        encoded_passage * tiled_question_passage_vector
    ], dim=-1)
    # Linear + ReLU projection of the merged attention features.
    final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

    # Self-attention residual block (a Bi-GRU in the paper).
    residual_layer = self._dropout(
        self._residual_encoder(self._dropout(final_merged_passage),
                               passage_mask))
    self_atten_matrix = self._self_atten(residual_layer, residual_layer)

    # Pairwise mask: entry (i, j) is 1 only when both tokens are real and
    # i != j, so no token attends to itself.
    # Shape: (batch_size, passage_length, passage_length)
    mask = passage_mask.unsqueeze(2) * passage_mask.unsqueeze(1)
    # Build the identity directly on the mask's device and dtype. The old
    # `.cuda()` call targeted the *default* GPU, which breaks when the mask
    # lives on another device.
    eye = torch.eye(passage_length, dtype=mask.dtype, device=mask.device)
    mask = mask * (1 - eye.unsqueeze(0))
    self_atten_probs = util.last_dim_softmax(self_atten_matrix, mask)

    # Batch matrix multiplication:
    # (batch, passage_len, passage_len) x (batch, passage_len, dim)
    #   -> (batch, passage_len, dim)
    self_atten_vecs = torch.matmul(self_atten_probs, residual_layer)

    # Shape: (batch_size, passage_length, residual_dim * 3)
    concatenated = torch.cat([
        self_atten_vecs, residual_layer, residual_layer * self_atten_vecs
    ], dim=-1)
    residual_layer = F.relu(self._merge_self_atten(concatenated))

    # Out-of-place add. `final_merged_passage` is a ReLU output, which
    # autograd keeps for ReLU's backward pass; an in-place `+=` would make
    # backpropagation fail with an "inplace operation" error.
    final_merged_passage = final_merged_passage + residual_layer
    final_merged_passage = self._dropout(final_merged_passage)

    # Span start / end prediction (Bi-GRUs in the paper).
    start_rep = self._span_start_encoder(final_merged_passage,
                                         passage_lstm_mask)
    span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    end_rep = self._span_end_encoder(
        torch.cat([final_merged_passage, start_rep], dim=-1),
        passage_lstm_mask)
    span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    # Push padded positions to a very small value so best-span search and
    # the losses ignore them.
    span_start_logits = util.replace_masked_values(span_start_logits,
                                                   passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits,
                                                 passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    if span_start is not None:
        # Loss: sum of the negative log-probabilities of the gold start and
        # end indices under the predicted distributions, averaged over the
        # batch by nll_loss.
        loss = nll_loss(
            util.masked_log_softmax(span_start_logits, passage_mask),
            span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(
            util.masked_log_softmax(span_end_logits, passage_mask),
            span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        # Pair gold start/end along the last dimension so they can be
        # compared with `best_span`. This presumably relies on the metric
        # flattening per instance; TODO confirm.
        self._span_accuracy(best_span,
                            torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].data.cpu().numpy())
            # Map token indices back to character offsets in the raw passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        span_starts: torch.IntTensor,
        span_ends: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score candidate mention spans, prune them, and predict an antecedent for
    each surviving span. A coreference loss is computed when gold labels are
    given.

    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of the document.
    span_starts : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 1), representing the start
        indices of candidate spans for mentions. Comes from a
        ``ListField[IndexField]`` of indices into the text of the document.
        Negative values mark padding spans.
    span_ends : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 1), representing the end
        indices of candidate spans for mentions. Comes from a
        ``ListField[IndexField]`` of indices into the text of the document.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster
        ids of each span, or -1 for those which do not appear in any
        clusters.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Passed unchanged to the mention-recall and CoNLL coreference metrics
        when ``span_labels`` is given. The keys those metrics expect are not
        visible here.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the
        pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)``
        representing, for each top span, the index (with respect to
        top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing,
        for each top span, the index (with respect to antecedent_indices) of
        the most likely antecedent. -1 means there was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised (present only when ``span_labels`` is
        given).
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(
        self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = span_starts.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # 1.0 for real spans, 0.0 for padding spans (negative start index).
    # Shape: (batch_size, num_spans, 1)
    span_mask = (span_starts >= 0).float()
    # IndexFields return -1 when they are used as padding. Span widths and
    # representations are computed from these indices, so they must be
    # non-negative: relu clamps every index to >= 0. This matters in the
    # edge case where pruning keeps at least as many spans as exist, because
    # then a masked span can be considered.
    span_starts = F.relu(span_starts.float()).long()
    span_ends = F.relu(span_ends.float()).long()

    # Shape: (batch_size, num_spans, embedding_size)
    span_embeddings = self._compute_span_representations(
        text_embeddings, text_mask, span_starts, span_ends)

    # Compute a score for whether each span is a mention.
    # Shape: (batch_size, num_spans, 1)
    mention_scores = self._mention_scorer(
        self._mention_feedforward(span_embeddings))
    # log(1) = 0 leaves real spans unchanged; log(0) = -inf sinks padding
    # spans so pruning never prefers them.
    mention_scores += span_mask.log()

    # Prune based on mention scores: keep spans_per_word * document_length.
    num_spans_to_keep = int(
        math.floor(self._spans_per_word * document_length))

    # Indices (with values between 0 and num_spans) into span_embeddings.
    # Shape: (batch_size, num_spans_to_keep)
    top_span_indices = self._prune_and_sort_spans(mention_scores,
                                                  num_spans_to_keep)

    # From here on, variables are re-expressed in terms of num_spans_to_keep
    # instead of num_spans, so no computation is wasted on discarded spans.

    # torch.index_select only accepts 1D indices, so the per-batch indices
    # are shifted by their batch offset and flattened once here. This makes
    # the repeated util.batched_index_select calls below cheaper.
    # Shape: (batch_size * num_spans_to_keep)
    flat_top_span_indices = util.flatten_and_batch_shift_indices(
        top_span_indices, num_spans)

    # Shape: (batch_size, num_spans_to_keep, embedding_size)
    top_span_embeddings = util.batched_index_select(
        span_embeddings, top_span_indices, flat_top_span_indices)

    # Shape: (batch_size, num_spans_to_keep, 1)
    # TODO(Mark): If we parameterised the mention scorer to score things in
    # (0, inf) I think we could get rid of the need for this mask entirely.
    top_span_mask = util.batched_index_select(span_mask, top_span_indices,
                                              flat_top_span_indices)
    top_span_mention_scores = util.batched_index_select(
        mention_scores, top_span_indices, flat_top_span_indices)
    top_span_starts = util.batched_index_select(span_starts, top_span_indices,
                                                flat_top_span_indices)
    top_span_ends = util.batched_index_select(span_ends, top_span_indices,
                                              flat_top_span_indices)

    # Never consider more antecedents than there are kept spans.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Each span may only take prior spans as antecedents, up to
    # max_antecedents of them. The span -> allowed-antecedent index matrix
    # depends only on a span's position in top_spans (which are in document
    # order), not on the batch. It is used to gather antecedent variables
    # shaped like (batch_size, num_spans_to_keep, max_antecedents, ...).
    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep,
                                         max_antecedents,
                                         text_mask.is_cuda)
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(
        top_span_embeddings, valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(
        top_span_embeddings, candidate_antecedent_embeddings,
        valid_antecedent_offsets)
    # Column 0 is the "no antecedent" option.
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings, top_span_mention_scores,
        candidate_antecedent_mention_scores, valid_antecedent_log_mask)

    # Compute final predictions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = torch.cat([top_span_starts, top_span_ends], -1)

    # Each kept span gets a predicted antecedent; chaining these links
    # implies a clustering of mentions.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Index 0 is the "no antecedent" class. Subtracting one aligns the
    # indices with antecedent_indices and turns "no antecedent" into -1.
    predicted_antecedents -= 1

    output_dict = {
        "top_spans": top_spans,
        "antecedent_indices": valid_antecedent_indices,
        "predicted_antecedents": predicted_antecedents
    }
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices)

        antecedent_labels = util.flattened_index_select(
            pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
        # NOTE(review): presumably the log mask is 0 for valid and -inf for
        # invalid antecedents, so this shifts invalid antecedents' labels to
        # a huge negative id that matches no gold cluster. Casting -inf to
        # long is platform-defined; confirm against
        # _generate_valid_antecedents.
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(
            pruned_gold_labels, antecedent_labels)

        # Loss is the negative marginal log-likelihood: the log of the summed
        # probability of every antecedent in the same gold cluster as the
        # span. A span predicts a single antecedent, but any prior mention of
        # its cluster is acceptable, so all of them are marginalised over.
        coreference_log_probs = util.last_dim_log_softmax(
            coreference_scores, top_span_mask)
        # Adding log(0/1 gold indicators) keeps only correct antecedents.
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log(
        )
        negative_marginal_log_likelihood = -util.logsumexp(
            correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                 predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood
    return output_dict
def _forward_loop( self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] = None ) -> Dict[str, torch.Tensor]: """ Make forward pass during training or do greedy search during prediction. Notes ----- We really only use the predictions from the method to test that beam search with a beam size of 1 gives the same results. """ # shape: (batch_size, max_input_sequence_length) source_mask = state["source_mask"] batch_size = source_mask.size()[0] if target_tokens: # shape: (batch_size, max_target_sequence_length) targets = target_tokens["tokens"] _, target_sequence_length = targets.size() # The last input from the target is either padding or the end symbol. # Either way, we don't have to process it. num_decoding_steps = target_sequence_length - 1 else: num_decoding_steps = self._max_decoding_steps # Initialize target predictions with the start index. # shape: (batch_size,) last_predictions = source_mask.new_full((batch_size, ), fill_value=self._start_index) step_logits: List[torch.Tensor] = [] step_predictions: List[torch.Tensor] = [] for timestep in range(num_decoding_steps): if self.training and torch.rand( 1).item() < self._scheduled_sampling_ratio: # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio # during training. 
# shape: (batch_size,) input_choices = last_predictions elif not target_tokens: # shape: (batch_size,) input_choices = last_predictions else: # shape: (batch_size,) input_choices = targets[:, timestep] # shape: (batch_size, num_classes) output_projections, state = self._prepare_output_projections( input_choices, state) # list of tensors, shape: (batch_size, 1, num_classes) step_logits.append(output_projections.unsqueeze(1)) # shape: (batch_size, num_classes) class_probabilities = F.softmax(output_projections, dim=-1) # shape (predicted_classes): (batch_size,) _, predicted_classes = torch.max(class_probabilities, 1) # shape (predicted_classes): (batch_size,) last_predictions = predicted_classes step_predictions.append(last_predictions.unsqueeze(1)) # shape: (batch_size, num_decoding_steps) predictions = torch.cat(step_predictions, 1) output_dict = {"predictions": predictions} if target_tokens: # shape: (batch_size, num_decoding_steps, num_classes) logits = torch.cat(step_logits, 1) # Compute loss. target_mask = util.get_text_field_mask(target_tokens) loss = self._get_loss(logits, targets, target_mask) output_dict["loss"] = loss return output_dict
def forward(self, text, spans, metadata, ner_labels=None, coref_labels=None,
            relation_labels=None, trigger_labels=None, argument_labels=None):
    """
    Encode a single document's spans and run every enabled task module on them.

    Parameters
    ----------
    text :
        Text-field tensors carrying an extra leading document dimension,
        unwrapped with ``num_wrapping_dims=1``.
    spans :
        Candidate span endpoints. After debatching they are
        (n_sents, max_n_spans, 2); a negative start index marks padding.
    metadata :
        A one-element list holding the document's metadata object, which
        must expose a ``weight`` attribute. Longer lists raise
        ``NotImplementedError``.
    ner_labels, coref_labels, relation_labels, trigger_labels, argument_labels :
        Optional gold labels, debatched like ``spans``. ``relation_labels``
        and ``argument_labels`` are cast to long first.

    Returns
    -------
    dict
        The ``coref``, ``relation``, ``ner`` and ``events`` module outputs,
        the document-weighted ``loss``, and the unwrapped ``metadata``.
    """
    # AllenNLP hands AdjacencyFields over as floats; the modules need ints.
    relation_labels = None if relation_labels is None else relation_labels.long()
    argument_labels = None if argument_labels is None else argument_labels.long()

    # TODO(dwadden) Multi-document minibatching is not supported yet, so the
    # extra document dimension is stripped from every input tensor.
    if len(metadata) > 1:
        raise NotImplementedError("Multi-document minibatching not yet supported.")
    metadata = metadata[0]

    # After debatching: spans (n_sents, max_n_spans, 2); ner/coref labels
    # (n_sents, max_n_spans); relation labels (n_sents, max_n_spans, max_n_spans).
    # TODO(dwadden) document the trigger/argument label shapes.
    (spans, ner_labels, coref_labels, relation_labels,
     trigger_labels, argument_labels) = (
         self._debatch(tensor) for tensor in (
             spans, ner_labels, coref_labels, relation_labels,
             trigger_labels, argument_labels))

    # Embed the text; `num_wrapping_dims=1` unwraps the document dimension,
    # then debatching yields (n_sents, max_n_wordpieces, embedding_dim).
    # TODO(dwadden) Deal with the case where the input is longer than 512.
    text_embeddings = self._debatch(self._embedder(text, num_wrapping_dims=1))
    # (n_sents, max_sentence_length)
    text_mask = self._debatch(
        util.get_text_field_mask(text, num_wrapping_dims=1).float())
    # (n_sents,)
    sentence_lengths = text_mask.sum(dim=1).long()

    # SpanFields use -1 for padding. Record real spans in a mask, then clamp
    # every index to >= 0 so that span-width comparisons and
    # representations built from these indices stay valid. Masked spans can
    # still be considered when pruning keeps at least as many spans as
    # exist.
    # (n_sents, max_n_spans)
    span_mask = (spans[:, :, 0] >= 0).float()
    # (n_sents, max_n_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    span_embeddings = self._endpoint_span_extractor(text_embeddings, spans)

    # Defaults for modules that are switched off.
    output_coref = {'loss': 0}
    output_ner = {'loss': 0}
    output_relation = {'loss': 0}
    output_events = {'loss': 0}

    weights = self._loss_weights

    # Coreference pruning is needed for either the coref loss or propagation.
    if weights["coref"] > 0 or self._coref.coref_prop > 0:
        output_coref, coref_indices = self._coref.compute_representations(
            spans, span_mask, span_embeddings, sentence_lengths,
            coref_labels, metadata)

    # Propagate document-level coreference information into the spans.
    if self._coref.coref_prop > 0:
        output_coref = self._coref.coref_propagation(output_coref)
        span_embeddings = self._coref.update_spans(
            output_coref, span_embeddings, coref_indices)

    # Predictions and losses for each enabled module.
    if weights['ner'] > 0:
        output_ner = self._ner(spans, span_mask, span_embeddings,
                               sentence_lengths, ner_labels, metadata)
    if weights['coref'] > 0:
        output_coref = self._coref.predict_labels(output_coref, metadata)
    if weights['relation'] > 0:
        output_relation = self._relation(spans, span_mask, span_embeddings,
                                         sentence_lengths, relation_labels,
                                         metadata)
    if weights['events'] > 0:
        # Token embeddings double as event-trigger representations.
        output_events = self._events(
            text_mask, text_embeddings, spans, span_mask, span_embeddings,
            sentence_lengths, trigger_labels, argument_labels, ner_labels,
            metadata)

    output_dict = dict(coref=output_coref, relation=output_relation,
                       ner=output_ner, events=output_events)

    # Outputs may lack a "loss" (e.g. at prediction time), hence `.get`.
    loss = sum(weights[name] * output.get("loss", 0)
               for name, output in (("coref", output_coref),
                                    ("ner", output_ner),
                                    ("relation", output_relation),
                                    ("events", output_events)))

    # Scale by the document's weight multiplier (1.0 when unset).
    weight = 1.0 if metadata.weight is None else metadata.weight
    loss *= torch.tensor(weight)

    output_dict['loss'] = loss
    output_dict["metadata"] = metadata
    return output_dict
def forward(self,  # type: ignore
            source_tokens: Dict[str, torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decoder logic for producing the entire target sequence.

    Parameters
    ----------
    source_tokens : Dict[str, torch.LongTensor]
       The output of ``TextField.as_array()`` applied on the source ``TextField``. This will be
       passed through a ``TextFieldEmbedder`` and then through an encoder.
    target_tokens : Dict[str, torch.LongTensor], optional (default = None)
       Output of ``Textfield.as_array()`` applied on target ``TextField``. We assume that the
       target tokens are also represented as a ``TextField``. When given, decoding runs for
       ``target_length - 1`` steps and a loss is computed; otherwise it runs for
       ``self._max_decoding_steps`` steps.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``logits`` and ``class_probabilities`` of shape
        (batch_size, num_decoding_steps, num_classes), ``predictions`` of shape
        (batch_size, num_decoding_steps), and ``loss`` when ``target_tokens`` is given.
    """
    # (batch_size, input_sequence_length, embedding_dim)
    embedded_input = self._source_embedder(source_tokens)
    batch_size, _, _ = embedded_input.size()
    source_mask = get_text_field_mask(source_tokens)
    # (batch_size, input_sequence_length, encoder_output_dim)
    encoder_outputs = self._encoder(embedded_input, source_mask)

    # Use the encoder output at each sequence's last *unpadded* position.
    # `encoder_outputs[:, -1]` picked a padding position for every shorter
    # sequence in a padded batch. For unpadded input the result is unchanged.
    last_indices = (source_mask.sum(dim=1).long() - 1).clamp(min=0)
    batch_indices = torch.arange(batch_size, device=encoder_outputs.device)
    # (batch_size, encoder_output_dim)
    final_encoder_output = encoder_outputs[batch_indices,
                                           last_indices.to(encoder_outputs.device)]

    if target_tokens:
        targets = target_tokens["tokens"]
        target_sequence_length = targets.size()[1]
        # The last input from the target is either padding or the end symbol. Either way, we
        # don't have to process it.
        num_decoding_steps = target_sequence_length - 1
    else:
        num_decoding_steps = self._max_decoding_steps

    decoder_hidden = final_encoder_output
    decoder_context = encoder_outputs.new_zeros(batch_size, self._decoder_output_dim)
    last_predictions = None
    step_logits = []
    step_probabilities = []
    step_predictions = []
    for timestep in range(num_decoding_steps):
        # Teacher forcing: while training, feed the gold token with probability
        # 1 - scheduled_sampling_ratio. Gold tokens exist only when target_tokens
        # was passed; otherwise fall back to the model's own predictions instead
        # of failing on the undefined `targets`.
        if (self.training and target_tokens
                and torch.rand(1).item() >= self._scheduled_sampling_ratio):
            input_choices = targets[:, timestep]
        elif timestep == 0:
            # The first input of every sequence is the start symbol.
            # (batch_size,)
            input_choices = source_mask.new_full((batch_size,), fill_value=self._start_index)
        else:
            input_choices = last_predictions
        decoder_input = self._prepare_decode_step_input(input_choices, decoder_hidden,
                                                        encoder_outputs, source_mask)
        decoder_hidden, decoder_context = self._decoder_cell(decoder_input,
                                                             (decoder_hidden, decoder_context))
        # (batch_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)
        # list of (batch_size, 1, num_classes)
        step_logits.append(output_projections.unsqueeze(1))
        class_probabilities = F.softmax(output_projections, dim=-1)
        _, predicted_classes = torch.max(class_probabilities, 1)
        step_probabilities.append(class_probabilities.unsqueeze(1))
        last_predictions = predicted_classes
        # (batch_size, 1)
        step_predictions.append(last_predictions.unsqueeze(1))

    # (batch_size, num_decoding_steps, num_classes)
    logits = torch.cat(step_logits, 1)
    class_probabilities = torch.cat(step_probabilities, 1)
    # (batch_size, num_decoding_steps)
    all_predictions = torch.cat(step_predictions, 1)
    output_dict = {"logits": logits,
                   "class_probabilities": class_probabilities,
                   "predictions": all_predictions}
    if target_tokens:
        target_mask = get_text_field_mask(target_tokens)
        loss = self._get_loss(logits, targets, target_mask)
        output_dict["loss"] = loss
        # TODO: Define metrics
    return output_dict
def forward(self, question, passage, span_start=None, span_end=None,
            metadata=None):
    """
    BiDAF forward pass: predict an answer span in ``passage`` for ``question``.

    Parameters
    ----------
    question, passage :
        Text-field tensor dicts fed to ``self._text_field_embedder``.
    span_start, span_end : optional
        Gold start/end token indices. When ``span_start`` is given, a loss is
        computed and the span accuracy metrics are updated. Presumably shaped
        ``(batch_size, 1)`` given the ``squeeze(-1)`` calls below; TODO
        confirm.
    metadata : list of dict, optional
        Per-instance dicts with ``question_tokens``, ``passage_tokens``,
        ``original_passage``, ``token_offsets`` and, optionally,
        ``answer_texts``. ``answer_texts`` is used to update the SQuAD
        metrics.

    Returns
    -------
    dict
        Always contains ``passage_question_attention``,
        ``span_start_logits``, ``span_start_probs``, ``span_end_logits``,
        ``span_end_probs`` and ``best_span``. Adds ``loss`` when
        ``span_start`` is given. Adds ``best_span_str``,
        ``question_tokens`` and ``passage_tokens`` when ``metadata`` is
        given.
    """
    # 1/2. Embedding layer: embed tokens, then pass them through the
    # highway network.
    embedded_question = self._highway_layer(
        self._text_field_embedder(question))
    embedded_passage = self._highway_layer(
        self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    # 3. Contextual embedding layer: encode inputs into H (passage) and
    # U (question) with a shared encoder.
    # Shape: (batch_size, question_length, encoding_dim)
    encoded_question = self._dropout(
        self._phrase_layer(embedded_question, question_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_passage = self._dropout(
        self._phrase_layer(embedded_passage, passage_lstm_mask))
    # The per-token encoding size is the last dimension.
    encoding_dim = encoded_question.size(-1)

    # 4. Attention flow layer.
    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(
        encoded_passage, encoded_question)
    # Context-to-query (C2Q) attention, 'a'.
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.last_dim_softmax(
        passage_question_similarity, question_mask)
    # C2Q-weighted question vectors: \hat{U}_{:t} = \sum_j a_{tj} U_{:j}
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(
        encoded_question, passage_question_attention)

    # Query-to-context (Q2C) attention, 'b'. Masked question positions are
    # replaced by a large negative value so they cannot affect the max.
    masked_similarity = util.replace_masked_values(
        passage_question_similarity, question_mask.unsqueeze(1), -1e7)
    # `max(dim=-1)` already drops the reduced dimension. An extra
    # `.squeeze(-1)` would wrongly collapse the passage dimension when
    # passage_length == 1.
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0]
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(
        question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(
        encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(
        1).expand(batch_size, passage_length, encoding_dim)

    # Merge attention vectors into G.
    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([
        encoded_passage,
        passage_question_vectors,
        encoded_passage * passage_question_vectors,
        encoded_passage * tiled_question_passage_vector
    ], dim=-1)

    # 5. Modeling layer: query-aware passage representation M.
    modeled_passage = self._dropout(
        self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # 6. Output layer.
    # Start distribution from [G; M].
    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout(
        torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(
        span_start_input).squeeze(-1)
    # Softmax over valid passage positions (eq. 3 in the BiDAF paper).
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # End distribution, conditioned on the expected start representation.
    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage,
                                                  span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(
        1).expand(batch_size, passage_length, modeling_dim)
    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([
        final_merged_passage,
        modeled_passage,
        tiled_start_representation,
        modeled_passage * tiled_start_representation
    ], dim=-1)
    # New modeling representation M2 based on the start probabilities. Its
    # last dimension is the span-end encoder's output size.
    encoded_span_end = self._dropout(
        self._span_end_encoder(span_end_representation, passage_lstm_mask))
    # Concatenate attention-flow output G with M2.
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(
        torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    # Shape: (batch_size, passage_length)
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)

    # Replace masked positions with a very small value so they cannot
    # influence best-span search or the losses.
    span_start_logits = util.replace_masked_values(span_start_logits,
                                                   passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits,
                                                 passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    if span_start is not None:
        # Loss: sum of the negative log-probabilities of the gold start and
        # end indices under the predicted distributions, averaged over the
        # batch by nll_loss.
        loss = nll_loss(
            util.masked_log_softmax(span_start_logits, passage_mask),
            span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(
            util.masked_log_softmax(span_end_logits, passage_mask),
            span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        # Pair gold start/end along the last dimension so they can be
        # compared with `best_span`. This presumably relies on the metric
        # flattening per instance; TODO confirm.
        self._span_accuracy(best_span,
                            torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].data.cpu().numpy())
            # Map token indices back to character offsets in the raw passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain, for each instance in the batch, the keys
        ``question_tokens``, ``passage_tokens``, ``original_passage`` and ``token_offsets``
        (all four are read unconditionally), and optionally ``answer_texts``.  When
        ``answer_texts`` is non-empty, the predicted span string is scored against it with the
        SQuAD metrics.  The length of this list should be the batch size.

    Returns
    -------
    An output dictionary consisting of:
    passage_question_attention : torch.FloatTensor
        The passage-to-question attention, shape ``(batch_size, passage_length,
        question_length)``.
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
        and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.  Present when ``span_start`` is given.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch, we also return the
        string from the original passage that the model thinks is the best answer to the
        question.
    question_tokens, passage_tokens : List
        Copied from ``metadata`` when it is given.
    """
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None

    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    # ``max(dim=-1)`` already removes the question dimension.  No extra ``squeeze(-1)``: it
    # would also remove the passage dimension whenever passage_length == 1.
    question_passage_similarity = masked_similarity.max(dim=-1)[0]
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors,
                                      encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim))
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(batch_size,
                                                                                passage_length,
                                                                                modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    # Push padded positions to a very negative value so best-span search never selects them.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = self.get_best_span(span_start_logits, span_end_logits)

    output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
            }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                        span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                         span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # Map the inclusive token span back to character offsets in the raw passage.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(self, tokens: Dict[str, torch.Tensor], id: Any, answerers: Any, date: Any,
            accept_usr: Any) -> Dict[str, torch.Tensor]:
    # pylint: disable=redefined-builtin,unused-argument
    """
    Score authors against each input text and compute the ranking loss.

    Parameters
    ----------
    tokens : Dict[str, torch.Tensor]
        Text-field tensors for the input text; embedded with ``self.word_embeddings`` and
        masked with ``get_text_field_mask``.
    id : Any
        Not used by this method; kept so the signature matches the incoming fields.
    answerers : Any
        Passed to ``gen_time_encoding`` and to ``self.rank_loss``.  Presumably a per-instance
        sequence of answerer records — TODO confirm the exact structure.
    date : Any
        Passed to ``gen_time_encoding`` to build the time embedding.
    accept_usr : Any
        Passed to ``self.rank_loss`` and used as the target of ``self.mrr``.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``"loss"`` and ``"coherence"`` as returned by ``self.rank_loss``.  As a side effect,
        ``coherence`` is sorted in descending order along axis 1 and fed to ``self.mrr``.
    """
    # Dimension legend for the shape comments below:
    #   n -- batch number
    #   m -- author number
    #   d -- hidden dimension
    #   k -- skill number
    #   l -- text length
    #   p -- pos/neg author number in one batch
    mask = get_text_field_mask(tokens)
    embeddings = self.word_embeddings(tokens)  # (n, l, d)
    # NOTE(review): if self.encoder returns (n, l, d), this transpose yields (n, d, l) and the
    # mean over dim 1 below averages over the hidden dimension rather than over tokens.
    # Confirm the encoder's output layout.
    token_hidden = self.encoder(embeddings, mask).transpose(-1, -2)
    token_embed = torch.mean(token_hidden, 1).squeeze(1)  # intended (n, d)

    author_ctx_embed = self.ctx_attention(token_embed, self.author_embeddings,
                                          self.author_embeddings)  # (n, m, d)
    # Layer norm on the author context embedding.
    author_ctx_embed = self.ctx_layer_norm(author_ctx_embed)  # (n, m, d)

    # Transfer the date into a time embedding.
    # TODO: use answer date for time embedding
    time_embed = gen_time_encoding(self.time_encoder, answerers, date, embeddings.size(2),
                                   self.num_authors, train_mode=self.training)
    # Time embedding queries the (normalised) author context embeddings.
    author_tctx_embed = self.temp_ctx_attention(time_embed, author_ctx_embed,
                                                author_ctx_embed)  # (n, m, d)

    loss, coherence = self.rank_loss(token_embed, author_tctx_embed, answerers, accept_usr)
    output = {"loss": loss, "coherence": coherence}

    # Rank candidates by descending coherence for the MRR metric.
    predict = np.argsort(-coherence.detach().cpu().numpy(), axis=1)
    self.mrr(predict, accept_usr)
    return output
def forward(self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    text : ``Dict[str, torch.LongTensor]``, required.
        The output of a ``TextField`` representing the text of
        the document.
    spans : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
        indices into the text of the document.
    span_labels : ``torch.IntTensor``, optional (default = None)
        A tensor of shape (batch_size, num_spans), representing the cluster ids
        of each span, or -1 for those which do not appear in any clusters.
    metadata : ``List[Dict[str, Any]]``, optional (default = None)
        Per-instance metadata.  When ``span_labels`` is given it is passed to the
        mention-recall and CoNLL coreference metrics.  When present, each entry's
        ``original_text`` is copied into the output under ``document``.

    Returns
    -------
    An output dictionary consisting of:
    top_spans : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
        the start and end word indices of the top spans that survived the pruning stage.
    antecedent_indices : ``torch.IntTensor``
        A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : ``torch.IntTensor``
        A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
        index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
        was no predicted link.
    loss : ``torch.FloatTensor``, optional
        A scalar loss to be optimised.  Present when ``span_labels`` is given.
    document : ``list``, optional
        The ``original_text`` of each instance; present when ``metadata`` is given.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    # ``spans[:, :, 0]`` is already 2-D.  No ``squeeze(-1)``: it would drop the span dimension
    # whenever num_spans == 1.
    span_mask = (spans[:, :, 0] >= 0).float()
    # SpanFields return -1 when they are used as padding. As we do
    # some comparisons based on span widths when we attend over the
    # span representations that we generate from these indices, we
    # need them to be >= 0. This is only relevant in edge cases where
    # the number of spans we consider after the pruning stage is >= the
    # total number of spans, because in this case, it is possible we might
    # consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    (top_span_embeddings, top_span_mask,
     top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                       span_mask,
                                                                       num_spans_to_keep)
    top_span_mask = top_span_mask.unsqueeze(-1)
    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select spans for each element in the batch.
    # This reformats the indices to take into account their
    # index into the batch. We precompute this here to make
    # the multiple calls to util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans,
                                          top_span_indices,
                                          flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to
    # compare span pairs to decide each span's antecedent. Each span can only
    # have prior spans as antecedents, and we only consider up to max_antecedents
    # prior spans. So the first thing we do is construct a matrix mapping a span's
    # index to the indices of its allowed antecedents. Note that this is independent
    # of the batch dimension - it's just a function of the span's position in
    # top_spans. The spans are in document order, so we can just use the relative
    # index of the spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings
    # for all valid antecedents for each span. This gives us variables with shapes
    # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
    # we can use to make coreference decisions between valid span pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
        self._generate_valid_antecedents(num_spans_to_keep,
                                         max_antecedents,
                                         util.get_device_of(text_mask))
    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                  valid_antecedent_indices)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                      valid_antecedent_indices).squeeze(-1)
    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                              candidate_antecedent_embeddings,
                                                              valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                          top_span_mention_scores,
                                                          candidate_antecedent_mention_scores,
                                                          valid_antecedent_log_mask)

    # We now have, for each span which survived the pruning stage,
    # a predicted antecedent. This implies a clustering if we group
    # mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class,
    # so this makes the indices line up with actual spans if the prediction
    # is greater than -1.
    predicted_antecedents -= 1

    output_dict = {"top_spans": top_spans,
                   "antecedent_indices": valid_antecedent_indices,
                   "predicted_antecedents": predicted_antecedents}
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                       top_span_indices,
                                                       flat_top_span_indices)

        antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                        valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                      antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood.
        # This is equal to the log of the sum of the probabilities of all antecedent predictions
        # that would be consistent with the data, in the sense that we are minimising, for a
        # given span, the negative marginal log likelihood of all antecedents which are in the
        # same gold cluster as the span we are currently considering. Each span i predicts a
        # single antecedent j, but there might be several prior mentions k in the same
        # coreference cluster that would be valid antecedents. Our loss is the sum of the
        # probability assigned to all valid antecedents. This is a valid objective for
        # clustering as we don't mind which antecedent is predicted, so long as they are in
        # the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            p1_answer_marker: torch.IntTensor = None,
            p2_answer_marker: torch.IntTensor = None,
            p3_answer_marker: torch.IntTensor = None,
            yesno_list: torch.IntTensor = None,
            followup_list: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.  Must contain a ``token_characters`` entry; its size supplies
        batch_size, max_qa_count and the maximum question length.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the answer to the
        question, and predicts the beginning and ending positions of the answer within the
        passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        beginning position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict - the
        ending position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    p1_answer_marker : ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 0.
        This is a tensor that has a shape [batch_size, max_qa_count, max_passage_length].
        Most passage tokens are assigned 'O', except the passage tokens that belong to the
        previous answer in the dialog, which are assigned labels such as <1_start>, <1_in>,
        <1_end>.  For more details, look into
        dataset_readers/util/make_reading_comprehension_instance_quac
    p2_answer_marker : ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 1.
        It is similar to p1_answer_marker, but marks the answer from two turns back.
    p3_answer_marker : ``torch.IntTensor``, optional
        This is one of the inputs, but only when num_context_answers > 2.
        It is similar to p1_answer_marker, but marks the answer from three turns back.
    yesno_list : ``torch.IntTensor``, optional
        This is one of the outputs that we are trying to predict.
        Three way classification (the yes/no/not a yes no question).
    followup_list : ``torch.IntTensor``, optional
        This is one of the outputs that we are trying to predict.
        Three way classification (followup / maybe followup / don't followup).
        Despite the default, it must be provided: entries ``>= 0`` define the mask of real
        (non-padded) QA pairs.
    metadata : ``List[Dict[str, Any]]``, optional
        Despite the default, it must be provided: the output strings are built from it.
        Each entry must have ``original_passage``, ``token_offsets``, ``instance_id``
        (one id per question in the dialog) and ``answer_texts_list`` (reference answers per
        question).  The length of this list should be the batch size.

    Returns
    -------
    An output dictionary consisting of the following.  Each entry is a nested list: the outer
    list iterates over dialogs, the inner one over the questions in a dialog.
    qid : List[List[str]]
        A list of list, consisting of question ids.
    followup : List[List[int]]
        A list of list, consisting of continuation marker prediction index.
        (y :yes, m: maybe follow up, n: don't follow up)
    yesno : List[List[int]]
        A list of list, consisting of affirmation marker prediction index.
        (y :yes, x: not a yes/no question, n: no)
    best_span_str : List[List[str]]
        If sufficient metadata was provided for the instances in the batch, we also return the
        string from the original passage that the model thinks is the best answer to the
        question.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    """
    batch_size, max_qa_count, max_q_len, _ = question['token_characters'].size()
    total_qa_count = batch_size * max_qa_count
    qa_mask = torch.ge(followup_list, 0).view(total_qa_count)
    embedded_question = self._text_field_embedder(question, num_wrapping_dims=1)
    embedded_question = embedded_question.reshape(total_qa_count, max_q_len,
                                                  self._text_field_embedder.get_output_dim())
    embedded_question = self._variational_dropout(embedded_question)
    embedded_passage = self._variational_dropout(self._text_field_embedder(passage))
    passage_length = embedded_passage.size(1)

    question_mask = util.get_text_field_mask(question, num_wrapping_dims=1).float()
    question_mask = question_mask.reshape(total_qa_count, max_q_len)
    passage_mask = util.get_text_field_mask(passage).float()

    # Every QA pair in a dialog shares the same passage.
    repeated_passage_mask = passage_mask.unsqueeze(1).repeat(1, max_qa_count, 1)
    repeated_passage_mask = repeated_passage_mask.view(total_qa_count, passage_length)

    if self._num_context_answers > 0:
        # Encode question turn number inside the dialog into question embedding.
        question_num_ind = util.get_range_vector(max_qa_count, util.get_device_of(embedded_question))
        question_num_ind = question_num_ind.unsqueeze(-1).repeat(1, max_q_len)
        question_num_ind = question_num_ind.unsqueeze(0).repeat(batch_size, 1, 1)
        question_num_ind = question_num_ind.reshape(total_qa_count, max_q_len)
        question_num_marker_emb = self._question_num_marker(question_num_ind)
        embedded_question = torch.cat([embedded_question, question_num_marker_emb], dim=-1)

        # Encode the previous answers in passage embedding.
        # Shape: (batch_size * max_qa_count, passage_length, word_embed_dim)
        repeated_embedded_passage = embedded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1). \
            view(total_qa_count, passage_length, self._text_field_embedder.get_output_dim())
        p1_answer_marker = p1_answer_marker.view(total_qa_count, passage_length)
        p1_answer_marker_emb = self._prev_ans_marker(p1_answer_marker)
        repeated_embedded_passage = torch.cat([repeated_embedded_passage, p1_answer_marker_emb], dim=-1)
        if self._num_context_answers > 1:
            p2_answer_marker = p2_answer_marker.view(total_qa_count, passage_length)
            p2_answer_marker_emb = self._prev_ans_marker(p2_answer_marker)
            repeated_embedded_passage = torch.cat([repeated_embedded_passage, p2_answer_marker_emb],
                                                  dim=-1)
            if self._num_context_answers > 2:
                p3_answer_marker = p3_answer_marker.view(total_qa_count, passage_length)
                p3_answer_marker_emb = self._prev_ans_marker(p3_answer_marker)
                repeated_embedded_passage = torch.cat([repeated_embedded_passage, p3_answer_marker_emb],
                                                      dim=-1)

        repeated_encoded_passage = self._variational_dropout(self._phrase_layer(repeated_embedded_passage,
                                                                                repeated_passage_mask))
    else:
        # No per-turn passage features: encode the passage once, then tile it per QA pair.
        encoded_passage = self._variational_dropout(self._phrase_layer(embedded_passage, passage_mask))
        repeated_encoded_passage = encoded_passage.unsqueeze(1).repeat(1, max_qa_count, 1, 1)
        repeated_encoded_passage = repeated_encoded_passage.view(total_qa_count,
                                                                 passage_length,
                                                                 self._encoding_dim)

    encoded_question = self._variational_dropout(self._phrase_layer(embedded_question, question_mask))

    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(repeated_encoded_passage, encoded_question)
    # Shape: (batch_size * max_qa_count, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)

    # Shape: (batch_size * max_qa_count, passage_length)
    # ``max(dim=-1)`` already removes the question dimension.  No extra ``squeeze(-1)``: it
    # would also remove the passage dimension whenever passage_length == 1.
    question_passage_similarity = masked_similarity.max(dim=-1)[0]
    question_passage_attention = util.masked_softmax(question_passage_similarity, repeated_passage_mask)
    # Shape: (batch_size * max_qa_count, encoding_dim)
    question_passage_vector = util.weighted_sum(repeated_encoded_passage, question_passage_attention)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(total_qa_count,
                                                                                passage_length,
                                                                                self._encoding_dim)

    # Shape: (batch_size * max_qa_count, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([repeated_encoded_passage,
                                      passage_question_vectors,
                                      repeated_encoded_passage * passage_question_vectors,
                                      repeated_encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    final_merged_passage = F.relu(self._merge_atten(final_merged_passage))

    residual_layer = self._variational_dropout(self._residual_encoder(final_merged_passage,
                                                                      repeated_passage_mask))
    self_attention_matrix = self._self_attention(residual_layer, residual_layer)

    # Pairwise passage mask with the diagonal removed, so no token attends to itself.
    mask = repeated_passage_mask.reshape(total_qa_count, passage_length, 1) \
        * repeated_passage_mask.reshape(total_qa_count, 1, passage_length)
    self_mask = torch.eye(passage_length, passage_length, device=self_attention_matrix.device)
    self_mask = self_mask.reshape(1, passage_length, passage_length)
    mask = mask * (1 - self_mask)

    self_attention_probs = util.masked_softmax(self_attention_matrix, mask)

    # (batch, passage_len, passage_len) * (batch, passage_len, dim) -> (batch, passage_len, dim)
    self_attention_vecs = torch.matmul(self_attention_probs, residual_layer)
    self_attention_vecs = torch.cat([self_attention_vecs, residual_layer,
                                     residual_layer * self_attention_vecs],
                                    dim=-1)
    residual_layer = F.relu(self._merge_self_attention(self_attention_vecs))

    final_merged_passage = final_merged_passage + residual_layer
    # batch_size * maxqa_pair_len * max_passage_len * 200
    final_merged_passage = self._variational_dropout(final_merged_passage)
    start_rep = self._span_start_encoder(final_merged_passage, repeated_passage_mask)
    span_start_logits = self._span_start_predictor(start_rep).squeeze(-1)

    end_rep = self._span_end_encoder(torch.cat([final_merged_passage, start_rep], dim=-1),
                                     repeated_passage_mask)
    span_end_logits = self._span_end_predictor(end_rep).squeeze(-1)

    span_yesno_logits = self._span_yesno_predictor(end_rep).squeeze(-1)
    span_followup_logits = self._span_followup_predictor(end_rep).squeeze(-1)

    # Shape: (batch_size * max_qa_count, passage_length)
    span_start_logits = util.replace_masked_values(span_start_logits, repeated_passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, repeated_passage_mask, -1e7)

    best_span = self._get_best_span_yesno_followup(span_start_logits, span_end_logits,
                                                   span_yesno_logits, span_followup_logits,
                                                   self._max_span_length)

    output_dict: Dict[str, Any] = {}

    # Compute the loss.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, repeated_passage_mask),
                        span_start.view(-1),
                        ignore_index=-1)
        self._span_start_accuracy(span_start_logits, span_start.view(-1), mask=qa_mask)
        loss += nll_loss(util.masked_log_softmax(span_end_logits, repeated_passage_mask),
                         span_end.view(-1),
                         ignore_index=-1)
        self._span_end_accuracy(span_end_logits, span_end.view(-1), mask=qa_mask)
        self._span_accuracy(best_span[:, 0:2],
                            torch.stack([span_start, span_end], -1).view(total_qa_count, 2),
                            mask=qa_mask.unsqueeze(1).expand(-1, 2).long())

        # The yes/no and follow-up logits are read at an answer's end token.  Flattened with
        # view(-1), the 3 logits for token t of QA pair q sit at q * passage_length * 3 + t * 3
        # (+0, +1, +2).  Locations are clamped at 0 so a -1 (padding) end index cannot give a
        # negative index.  This is vectorised: it avoids per-element Python loops over device
        # tensors and works when total_qa_count == 1.
        qa_base = torch.arange(total_qa_count, device=span_yesno_logits.device) * passage_length * 3
        class_offsets = torch.arange(3, device=span_yesno_logits.device)

        def _end_token_logit_locations(end_indices: torch.Tensor) -> torch.Tensor:
            # Flat indices (length total_qa_count * 3) of the 3 logits at each end token.
            ends = end_indices.reshape(total_qa_count).to(device=qa_base.device, dtype=torch.long)
            base = ends * 3 + qa_base
            return (base.unsqueeze(-1) + class_offsets).clamp(min=0).view(-1)

        gold_span_end_loc = _end_token_logit_locations(span_end)
        predicted_end = _end_token_logit_locations(best_span[:, 1])

        # The loss uses the logits at the gold end token.
        _yesno = span_yesno_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, gold_span_end_loc).view(-1, 3)
        loss += nll_loss(F.log_softmax(_yesno, dim=-1), yesno_list.view(-1), ignore_index=-1)
        loss += nll_loss(F.log_softmax(_followup, dim=-1), followup_list.view(-1), ignore_index=-1)

        # The accuracy metrics use the logits at the predicted end token.
        _yesno = span_yesno_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        _followup = span_followup_logits.view(-1).index_select(0, predicted_end).view(-1, 3)
        self._span_yesno_accuracy(_yesno, yesno_list.view(-1), mask=qa_mask)
        self._span_followup_accuracy(_followup, followup_list.view(-1), mask=qa_mask)
        output_dict["loss"] = loss

    # Compute F1 and prepare the output dictionary.
    output_dict['best_span_str'] = []
    output_dict['qid'] = []
    output_dict['followup'] = []
    output_dict['yesno'] = []
    best_span_cpu = best_span.detach().cpu().numpy()
    for i in range(batch_size):
        passage_str = metadata[i]['original_passage']
        offsets = metadata[i]['token_offsets']
        per_dialog_best_span_list = []
        per_dialog_yesno_list = []
        per_dialog_followup_list = []
        per_dialog_query_id_list = []
        for per_dialog_query_index, (iid, answer_texts) in enumerate(
                zip(metadata[i]["instance_id"], metadata[i]["answer_texts_list"])):
            # Columns of best_span: start token, end token, yes/no index, follow-up index.
            predicted_span = tuple(best_span_cpu[i * max_qa_count + per_dialog_query_index])

            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]

            yesno_pred = predicted_span[2]
            followup_pred = predicted_span[3]
            per_dialog_yesno_list.append(yesno_pred)
            per_dialog_followup_list.append(followup_pred)
            per_dialog_query_id_list.append(iid)

            best_span_string = passage_str[start_offset:end_offset]
            per_dialog_best_span_list.append(best_span_string)
            # Score only questions that have reference answers.  F1 is computed per question,
            # so a value from an earlier question is never recorded again.
            if answer_texts:
                if len(answer_texts) > 1:
                    # Compute F1 over N-1 human references and average the scores.
                    t_f1 = []
                    for answer_index in range(len(answer_texts)):
                        refs = [text for z, text in enumerate(answer_texts) if z != answer_index]
                        t_f1.append(squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                             best_span_string,
                                                                             refs))
                    f1_score = 1.0 * sum(t_f1) / len(t_f1)
                else:
                    f1_score = squad_eval.metric_max_over_ground_truths(squad_eval.f1_score,
                                                                        best_span_string,
                                                                        answer_texts)
                self._official_f1(100 * f1_score)
        output_dict['qid'].append(per_dialog_query_id_list)
        output_dict['best_span_str'].append(per_dialog_best_span_list)
        output_dict['yesno'].append(per_dialog_yesno_list)
        output_dict['followup'].append(per_dialog_followup_list)
    return output_dict
def forward(  # type: ignore
    self,
    tokens: TextFieldTensors,
    verb_indicator: torch.LongTensor,
    tags: torch.LongTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    # Parameters

    tokens : TextFieldTensors, required
        The output of `TextField.as_array()`, which should typically be passed directly to a
        `TextFieldEmbedder`. This output is a dictionary mapping keys to `TokenIndexer`
        tensors.  At its most basic, using a `SingleIdTokenIndexer` this is : `{"tokens":
        Tensor(batch_size, num_tokens)}`. This dictionary will have the same keys as were used
        for the `TokenIndexers` when you created the `TextField` representing your
        sequence.  The dictionary is designed to be passed directly to a `TextFieldEmbedder`,
        which knows how to combine different word representations into a single vector per
        token in your input.
    verb_indicator: torch.LongTensor, required.
        An integer `SequenceFeatureField` representation of the position of the verb
        in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
        all zeros, in the case that the sentence has no verbal predicate.
    tags : torch.LongTensor, optional (default = None)
        A torch tensor representing the sequence of integer gold class labels
        of shape `(batch_size, num_tokens)`
    metadata : `List[Dict[str, Any]]`, optional, (default = None)
        Per-instance metadata containing the original words in the sentence and the verb to
        compute the frame for, under the 'words' and 'verb' keys; these are copied into the
        output when metadata is given. When `tags` is given and the span metric is active
        (evaluation mode), the 'verb_index' and 'gold_tags' keys are also read, so metadata
        is required in that case.

    # Returns

    An output dictionary consisting of:
    logits : torch.FloatTensor
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
        unnormalised log probabilities of the tag classes.
    class_probabilities : torch.FloatTensor
        A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
        a distribution of the tag classes per word.
    mask : torch.Tensor
        The text field mask, kept so that `decode` can crop padding.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised. Present only when `tags` is given.
    words, verb : list, optional
        Copied from `metadata` when it is given.
    """
    embedded_text_input = self.embedding_dropout(self.text_field_embedder(tokens))
    mask = get_text_field_mask(tokens)
    embedded_verb_indicator = self.binary_feature_embedding(verb_indicator.long())
    # Concatenate the verb feature onto the embedded text. This now
    # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
    embedded_text_with_verb_indicator = torch.cat(
        [embedded_text_input, embedded_verb_indicator], -1
    )
    batch_size, sequence_length, _ = embedded_text_with_verb_indicator.size()

    encoded_text = self.encoder(embedded_text_with_verb_indicator, mask)

    logits = self.tag_projection_layer(encoded_text)
    reshaped_log_probs = logits.view(-1, self.num_classes)
    class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
        [batch_size, sequence_length, self.num_classes]
    )
    output_dict = {"logits": logits, "class_probabilities": class_probabilities}
    # We need to retain the mask in the output dictionary
    # so that we can crop the sequences to remove padding
    # when we do viterbi inference in self.decode.
    output_dict["mask"] = mask

    if tags is not None:
        loss = sequence_cross_entropy_with_logits(
            logits, tags, mask, label_smoothing=self._label_smoothing
        )
        if not self.ignore_span_metric and self.span_metric is not None and not self.training:
            batch_verb_indices = [
                example_metadata["verb_index"] for example_metadata in metadata
            ]
            batch_sentences = [example_metadata["words"] for example_metadata in metadata]
            # Get the BIO tags from decode()
            # TODO (nfliu): This is kind of a hack, consider splitting out part
            # of decode() to a separate function.
            batch_bio_predicted_tags = self.decode(output_dict).pop("tags")
            batch_conll_predicted_tags = [
                convert_bio_tags_to_conll_format(seq_tags)
                for seq_tags in batch_bio_predicted_tags
            ]
            batch_bio_gold_tags = [
                example_metadata["gold_tags"] for example_metadata in metadata
            ]
            batch_conll_gold_tags = [
                convert_bio_tags_to_conll_format(seq_tags) for seq_tags in batch_bio_gold_tags
            ]
            self.span_metric(
                batch_verb_indices,
                batch_sentences,
                batch_conll_predicted_tags,
                batch_conll_gold_tags,
            )
        output_dict["loss"] = loss

    # metadata is optional (default None), so only read it after the None check.
    if metadata is not None:
        output_dict["words"] = [x["words"] for x in metadata]
        output_dict["verb"] = [x["verb"] for x in metadata]
    return output_dict
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            valid_actions: List[List[ProductionRule]],
            action_sequence: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    We set up the initial state for the decoder. If a gold ``action_sequence`` is given, that
    state is passed to the ``DecoderTrainer`` to compute its output. When not training, we
    also run ``BeamSearch`` to produce predictions and update the metrics.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor]
        The output of ``TextField.as_array()`` applied on the tokens ``TextField``. This will
        be passed through a ``TextFieldEmbedder`` and then through an encoder.
    valid_actions : ``List[List[ProductionRule]]``
        A list of all possible actions for each ``World`` in the batch, indexed into a
        ``ProductionRule`` using a ``ProductionRuleField``.  We will embed all of these
        and use the embeddings to determine which action to take at each timestep in the
        decoder.
    action_sequence : torch.Tensor, optional (default=None)
        The action sequence for the correct action sequence, where each action is an index
        into the list of possible actions.  This tensor has shape
        ``(batch_size, sequence_length, 1)``. We remove the trailing dimension. Entries equal
        to ``self._action_padding_index`` are treated as padding.

    Returns
    -------
    A dict containing whatever ``self._decoder_trainer.decode`` returns (presumably including
    the loss) when ``action_sequence`` is given. When not training, it additionally contains
    ``'action_mapping'``, ``'best_action_sequence'``, ``'predicted_sql_query'``,
    ``'debug_info'`` and an (unpopulated) ``'sql_queries'`` list.
    """
    embedded_utterance = self._utterance_embedder(tokens)
    mask = util.get_text_field_mask(tokens).float()
    batch_size = embedded_utterance.size(0)

    # (batch_size, num_tokens, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(embedded_utterance, mask))
    initial_state = self._get_initial_state(encoder_outputs, mask, valid_actions)

    outputs: Dict[str, Any] = {}
    if action_sequence is not None:
        # Remove the trailing dimension (from ListField[ListField[IndexField]]).
        action_sequence = action_sequence.squeeze(-1)
        target_mask = action_sequence != self._action_padding_index
        # action_sequence is of shape (batch_size, 1, target_sequence_length)
        # here after we unsqueeze it for the MML trainer.
        loss_output = self._decoder_trainer.decode(initial_state,
                                                   self._transition_function,
                                                   (action_sequence.unsqueeze(1), target_mask.unsqueeze(1)))
        outputs.update(loss_output)

    if not self.training:
        # Per batch instance: action index -> action[0] (the rule for that index).
        action_mapping = [{action_index: action[0] for action_index, action in enumerate(batch_actions)}
                          for batch_actions in valid_actions]
        outputs['action_mapping'] = action_mapping
        # This tells the state to start keeping track of debug info, which we'll pass along in
        # our output dictionary.
        initial_state.debug_info = [[] for _ in range(batch_size)]
        best_final_states = self._beam_search.search(self._max_decoding_steps,
                                                     initial_state,
                                                     self._transition_function,
                                                     keep_final_unfinished_states=True)
        outputs['best_action_sequence'] = []
        outputs['debug_info'] = []
        outputs['predicted_sql_query'] = []
        outputs['sql_queries'] = []
        for i in range(batch_size):
            # Decoding may not have terminated with any completed valid SQL queries, if `num_steps`
            # isn't long enough (or if the model is not trained enough and gets into an
            # infinite action loop).
            if i not in best_final_states:
                # NOTE(review): only 'predicted_sql_query' gets a placeholder here, so
                # 'best_action_sequence' and 'debug_info' can end up shorter than the batch
                # and misaligned with it. These metrics are also updated even when no gold
                # action_sequence was given. Kept as-is; confirm what callers expect.
                self._exact_match(0)
                self._denotation_accuracy(0)
                self._valid_sql_query(0)
                self._action_similarity(0)
                outputs['predicted_sql_query'].append('')
                continue

            best_action_indices = best_final_states[i][0].action_history[0]
            action_strings = [action_mapping[i][action_index]
                              for action_index in best_action_indices]
            predicted_sql_query = action_sequence_to_sql(action_strings)

            if action_sequence is not None:
                # Use a Tensor, not a Variable, to avoid a memory leak.
                targets = action_sequence[i].data
                sequence_in_targets = self._action_history_match(best_action_indices, targets)
                self._exact_match(sequence_in_targets)

                similarity = difflib.SequenceMatcher(None, best_action_indices, targets)
                self._action_similarity(similarity.ratio())

            outputs['best_action_sequence'].append(action_strings)
            outputs['predicted_sql_query'].append(sqlparse.format(predicted_sql_query, reindent=True))
            outputs['debug_info'].append(best_final_states[i][0].debug_info[0])  # type: ignore
    return outputs
def _get_initial_rnn_and_grammar_state(self,
                                       question: Dict[str, torch.LongTensor],
                                       table: Dict[str, torch.LongTensor],
                                       world: List[WikiTablesWorld],
                                       actions: List[List[ProductionRule]],
                                       outputs: Dict[str, Any]) -> Tuple[List[RnnStatelet],
                                                                         List[LambdaGrammarStatelet]]:
    """
    Encodes the question and table, computes a linking between the two, and constructs an
    initial RnnStatelet and LambdaGrammarStatelet for each batch instance to pass to the
    decoder.

    We take ``outputs`` as a parameter here and `modify` it, adding things that we want to
    visualize in a demo.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        Text field tensors for the question; embedded with ``self._question_embedder``.
    table : Dict[str, torch.LongTensor]
        Must contain ``'text'`` (entity text, embedded with one wrapping dimension) and
        ``'linking'`` (linking features, only consumed when ``self._linking_params`` is set).
    world : List[WikiTablesWorld]
        One world per batch instance. Used for neighbor indices, entity types, linking
        probabilities and grammar-state construction.
    actions : List[List[ProductionRule]]
        Per-instance production rules, passed to ``self._create_grammar_state``.
    outputs : Dict[str, Any]
        Mutated in place when ``not self.training``. It receives ``'linking_scores'`` and
        ``'similarity_scores'``, plus ``'feature_scores'`` when those were computed.

    Returns
    -------
    Tuple[List[RnnStatelet], List[LambdaGrammarStatelet]]
        One RnnStatelet and one LambdaGrammarStatelet per batch instance.
    """
    table_text = table['text']

    # (batch_size, question_length, embedding_dim)
    embedded_question = self._question_embedder(question)
    question_mask = util.get_text_field_mask(question).float()
    # (batch_size, num_entities, num_entity_tokens, embedding_dim)
    embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
    table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()

    batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
    num_question_tokens = embedded_question.size(1)

    # (batch_size, num_entities, embedding_dim)
    encoded_table = self._entity_encoder(embedded_table, table_mask)
    # (batch_size, num_entities, num_neighbors)
    neighbor_indices = self._get_neighbor_indices(world, num_entities, encoded_table)

    # Neighbor_indices is padded with -1 since 0 is a potential neighbor index.
    # Thus, the absolute value needs to be taken in the index_select, and 1 needs to
    # be added for the mask since that method expects 0 for padding.
    # (batch_size, num_entities, num_neighbors, embedding_dim)
    embedded_neighbors = util.batched_index_select(encoded_table, torch.abs(neighbor_indices))

    neighbor_mask = util.get_text_field_mask({'ignored': neighbor_indices + 1},
                                             num_wrapping_dims=1).float()

    # Encoder initialized to easily obtain a masked average.
    # NOTE(review): this module is rebuilt on every call; presumably BagOfEmbeddingsEncoder
    # has no trainable parameters, so nothing is lost — confirm.
    neighbor_encoder = TimeDistributed(BagOfEmbeddingsEncoder(self._embedding_dim, averaged=True))
    # (batch_size, num_entities, embedding_dim)
    embedded_neighbors = neighbor_encoder(embedded_neighbors, neighbor_mask)

    # entity_types: tensor with shape (batch_size, num_entities), where each entry is the
    # entity's type id.
    # entity_type_dict: Dict[int, int], mapping flattened_entity_index -> type_index
    # These encode the same information, but for efficiency reasons later it's nice
    # to have one version as a tensor and one that's accessible on the cpu.
    entity_types, entity_type_dict = self._get_type_vector(world, num_entities, encoded_table)

    entity_type_embeddings = self._entity_type_encoder_embedding(entity_types)
    projected_neighbor_embeddings = self._neighbor_params(embedded_neighbors.float())
    # (batch_size, num_entities, embedding_dim)
    entity_embeddings = torch.tanh(entity_type_embeddings + projected_neighbor_embeddings)

    # Compute entity and question word similarity.  We tried using cosine distance here, but
    # because this similarity is the main mechanism that the model can use to push apart logit
    # scores for certain actions (like "n -> 1" and "n -> -1"), this needs to have a larger
    # output range than [-1, 1].
    # Flatten entities x entity tokens so one bmm scores every entity token against every
    # question token, then restore the 4-D layout.
    question_entity_similarity = torch.bmm(embedded_table.view(batch_size,
                                                               num_entities * num_entity_tokens,
                                                               self._embedding_dim),
                                           torch.transpose(embedded_question, 1, 2))

    question_entity_similarity = question_entity_similarity.view(batch_size,
                                                                 num_entities,
                                                                 num_entity_tokens,
                                                                 num_question_tokens)

    # (batch_size, num_entities, num_question_tokens)
    question_entity_similarity_max_score, _ = torch.max(question_entity_similarity, 2)

    # (batch_size, num_entities, num_question_tokens, num_features)
    linking_features = table['linking']

    linking_scores = question_entity_similarity_max_score

    if self._use_neighbor_similarity_for_linking:
        # The linking score is computed as a linear projection of two terms. The first is the
        # maximum similarity score over the entity's words and the question token. The second
        # is the maximum similarity over the words in the entity's neighbors and the question
        # token.
        #
        # The second term, projected_question_neighbor_similarity, is useful when a column
        # needs to be selected. For example, the question token might have no similarity with
        # the column name, but is similar with the cells in the column.
        #
        # Note that projected_question_neighbor_similarity is intended to capture the same
        # information as the related_column feature.
        #
        # Also note that this block needs to be _before_ the `linking_params` block, because
        # we're overwriting `linking_scores`, not adding to it.

        # (batch_size, num_entities, num_neighbors, num_question_tokens)
        # NOTE(review): padded neighbor slots (-1, turned into index 1 by abs) are not masked
        # before the max below, so they contribute entity 1's similarity — confirm intended.
        question_neighbor_similarity = util.batched_index_select(question_entity_similarity_max_score,
                                                                 torch.abs(neighbor_indices))
        # (batch_size, num_entities, num_question_tokens)
        question_neighbor_similarity_max_score, _ = torch.max(question_neighbor_similarity, 2)

        projected_question_entity_similarity = self._question_entity_params(
                question_entity_similarity_max_score.unsqueeze(-1)).squeeze(-1)
        projected_question_neighbor_similarity = self._question_neighbor_params(
                question_neighbor_similarity_max_score.unsqueeze(-1)).squeeze(-1)

        linking_scores = projected_question_entity_similarity + projected_question_neighbor_similarity

    # Stays None unless a feature projection exists; also read at the end for `outputs`.
    feature_scores = None
    if self._linking_params is not None:
        feature_scores = self._linking_params(linking_features).squeeze(3)
        linking_scores = linking_scores + feature_scores

    # (batch_size, num_question_tokens, num_entities)
    linking_probabilities = self._get_linking_probabilities(world, linking_scores.transpose(1, 2),
                                                            question_mask, entity_type_dict)

    # (batch_size, num_question_tokens, embedding_dim)
    link_embedding = util.weighted_sum(entity_embeddings, linking_probabilities)
    encoder_input = torch.cat([link_embedding, embedded_question], 2)

    # (batch_size, question_length, encoder_output_dim)
    encoder_outputs = self._dropout(self._encoder(encoder_input, question_mask))

    # This will be our initial hidden state and memory cell for the decoder LSTM.
    final_encoder_output = util.get_final_encoder_states(encoder_outputs,
                                                         question_mask,
                                                         self._encoder.is_bidirectional())
    memory_cell = encoder_outputs.new_zeros(batch_size, self._encoder.get_output_dim())

    # To make grouping states together in the decoder easier, we convert the batch dimension in
    # all of our tensors into an outer list.  For instance, the encoder outputs have shape
    # `(batch_size, question_length, encoder_output_dim)`.  We need to convert this into a list
    # of `batch_size` tensors, each of shape `(question_length, encoder_output_dim)`.  Then we
    # won't have to do any index selects, or anything, we'll just do some `torch.cat()`s.
    encoder_output_list = [encoder_outputs[i] for i in range(batch_size)]
    question_mask_list = [question_mask[i] for i in range(batch_size)]
    initial_rnn_state = []
    for i in range(batch_size):
        # Every statelet receives the full per-instance lists (not just entry i).
        initial_rnn_state.append(RnnStatelet(final_encoder_output[i],
                                             memory_cell[i],
                                             self._first_action_embedding,
                                             self._first_attended_question,
                                             encoder_output_list,
                                             question_mask_list))

    initial_grammar_state = [self._create_grammar_state(world[i],
                                                        actions[i],
                                                        linking_scores[i],
                                                        entity_types[i])
                             for i in range(batch_size)]
    if not self.training:
        # We add a few things to the outputs that will be returned from `forward` at evaluation
        # time, for visualization in a demo.
        outputs['linking_scores'] = linking_scores
        if feature_scores is not None:
            outputs['feature_scores'] = feature_scores
        outputs['similarity_scores'] = question_entity_similarity_max_score

    return initial_rnn_state, initial_grammar_state
def forward(self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None  # pylint:disable=unused-argument
           ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Classify the entailment relation between a premise and a hypothesis.

    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        Tensors from the premise ``TextField``.
    hypothesis : Dict[str, torch.LongTensor]
        Tensors from the hypothesis ``TextField``.
    label : torch.IntTensor, optional (default = None)
        Gold label from a ``LabelField``; when given, the loss and accuracy are computed.
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Accepted for interface compatibility (original tokenizations under
        'premise_tokens' / 'hypothesis_tokens'); not used here.

    Returns
    -------
    Dict with ``label_logits`` and ``label_probs`` (each ``(batch_size, num_labels)``)
    and, when ``label`` is given, a scalar ``loss``.
    """
    def _enhance(encoded, attended):
        # Enhancement features: [a; b; a - b; a * b] along the last axis.
        return torch.cat([encoded, attended, encoded - attended, encoded * attended], dim=-1)

    def _pool(values, seq_mask):
        # Masked average and masked max over the sequence axis.
        expanded = seq_mask.unsqueeze(-1)
        maxed, _ = replace_masked_values(values, expanded, -1e7).max(dim=1)
        averaged = torch.sum(values * expanded, dim=1) / torch.sum(seq_mask, 1, keepdim=True)
        return averaged, maxed

    premise_emb = self._text_field_embedder(premise)
    hypothesis_emb = self._text_field_embedder(hypothesis)
    premise_mask = get_text_field_mask(premise).float()
    hypothesis_mask = get_text_field_mask(hypothesis).float()

    # Dropout before the input LSTM, when configured.
    if self.rnn_input_dropout:
        premise_emb = self.rnn_input_dropout(premise_emb)
        hypothesis_emb = self.rnn_input_dropout(hypothesis_emb)

    premise_enc = self._encoder(premise_emb, premise_mask)
    hypothesis_enc = self._encoder(hypothesis_emb, hypothesis_mask)

    # (batch_size, premise_length, hypothesis_length)
    similarity = self._matrix_attention(premise_enc, hypothesis_enc)
    # Each premise token attends over the hypothesis, and vice versa.
    attended_hyp = weighted_sum(hypothesis_enc, last_dim_softmax(similarity, hypothesis_mask))
    attended_prem = weighted_sum(
        premise_enc, last_dim_softmax(similarity.transpose(1, 2).contiguous(), premise_mask))

    # Project down to the model dimension (no dropout before this projection).
    premise_proj = self._projection_feedforward(_enhance(premise_enc, attended_hyp))
    hypothesis_proj = self._projection_feedforward(_enhance(hypothesis_enc, attended_prem))

    # Inference layer.
    if self.rnn_input_dropout:
        premise_proj = self.rnn_input_dropout(premise_proj)
        hypothesis_proj = self.rnn_input_dropout(hypothesis_proj)
    premise_inf = self._inference_encoder(premise_proj, premise_mask)
    hypothesis_inf = self._inference_encoder(hypothesis_proj, hypothesis_mask)

    premise_avg, premise_max = _pool(premise_inf, premise_mask)
    hypothesis_avg, hypothesis_max = _pool(hypothesis_inf, hypothesis_mask)
    # (batch_size, model_dim * 2 * 4)
    pooled = torch.cat([premise_avg, premise_max, hypothesis_avg, hypothesis_max], dim=1)

    # Final MLP: dropout on its input; the MLP itself handles hidden/output dropout.
    if self.dropout:
        pooled = self.dropout(pooled)
    label_logits = self._output_logit(self._output_feedforward(pooled))
    label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

    output_dict = {"label_logits": label_logits, "label_probs": label_probs}
    if label is None:
        return output_dict

    output_dict["loss"] = self._loss(label_logits, label.long().view(-1))
    self._accuracy(label_logits, label)
    return output_dict
def forward(self, text, spans, ner_labels, coref_labels, relation_labels, trigger_labels, argument_labels, metadata, span_labels):
    """
    Embed and contextualize ``text``, build endpoint span representations, and run the
    span and NER sub-modules whose loss weights are positive.

    Only ``text``, ``spans``, ``ner_labels``, ``span_labels`` and ``metadata`` are used;
    the coref/relation/trigger/argument label arguments are accepted but ignored here.
    Each ``metadata[i]`` must provide ``"start_ix"`` and ``"end_ix"``; their difference is
    used as instance i's sentence length.

    Returns a dict with the sub-module outputs under ``'ner'`` and ``'span'`` (``{'loss': 0}``
    for a disabled module) and the loss-weighted sum of their losses under ``'loss'``.
    """
    # (batch_size, max_sentence_length, bert_size)
    embedded = self._lexical_dropout(self._text_field_embedder(text))
    # (batch_size, max_sentence_length)
    text_mask = util.get_text_field_mask(text).float()

    # One integer length per instance, taken from metadata rather than from the mask.
    sentence_lengths = text_mask.new_zeros(text_mask.shape[:1]).long()
    for row, meta in enumerate(metadata):
        sentence_lengths[row] = meta["end_ix"] - meta["start_ix"]

    # (batch_size, max_sentence_length, encoding_dim)
    contextualized = self._lstm_dropout(self._context_layer(embedded, text_mask))
    assert spans.max() < contextualized.shape[1]

    # Padding spans carry -1 indices; derive the mask before clamping them.
    span_mask = (spans[:, :, 0] >= 0).float()
    # Clamp padding indices up to 0 so later span-width comparisons never see negatives
    # (matters only when pruning keeps every span, including masked ones).
    # (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # (batch_size, num_spans, 2 * encoding_dim + feature_size)
    span_embeddings = self._endpoint_span_extractor(contextualized, spans, text_mask)

    weights = self._loss_weights
    # The span module runs first because the NER module consumes its output.
    output_span = (
        self._span(spans, span_mask, span_embeddings, sentence_lengths, span_labels, metadata)
        if weights['span'] > 0 else {'loss': 0}
    )
    output_ner = (
        self._ner(spans, span_mask, span_embeddings, sentence_lengths, ner_labels, metadata, output_span)
        if weights['ner'] > 0 else {'loss': 0}
    )

    output_dict = dict(ner=output_ner, span=output_span)
    output_dict['loss'] = weights['ner'] * output_ner['loss'] + weights['span'] * output_span['loss']
    return output_dict
def forward(
        self,
        sentence: Dict[str, torch.Tensor],
        labels: torch.Tensor = None,
        labels_aspect: torch.Tensor = None,
        domain: torch.Tensor = None,
        sample_weight: torch.Tensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Encode ``sentence`` and score the motive labels and, optionally, per-domain aspect labels.

    Parameters
    ----------
    sentence : Dict[str, torch.Tensor]
        Text field tensors passed to ``self.word_embedder``.
    labels : torch.Tensor, optional
        Motive targets. Only the first ``NUM_MOTIVES`` columns are used, which drops the
        trailing ``NULL`` column.
    labels_aspect : torch.Tensor, optional
        Aspect targets. Rows are routed by ``domain`` (0 or 1) to the matching aspect
        classifier and truncated to that classifier's output width.
    domain : torch.Tensor, optional
        Per-instance domain index (0 or 1). Required when ``labels_aspect`` is given.
    sample_weight : torch.Tensor, optional
        Per-instance weights forwarded as ``weight=`` to ``self._loss_func``.
    metadata : List[Dict[str, Any]], optional
        If given, ``metadata[0]['label_prior']`` supplies ``pos_weight`` for the motive
        loss. ``'label_prior_aspect'``, taken from the first instance of each domain,
        supplies ``pos_weight`` for that domain's aspect loss.

    Returns
    -------
    Dict with ``'label_logits'`` and, when labels are given, ``'loss'``. The loss is the
    motive loss plus the aspect loss of every domain present in the batch.
    """
    pos_weight = None
    # One aspect prior per domain. It stays None when there is no prior or the batch has
    # no instance of that domain.
    pos_weight_aspect = [None, None]
    if metadata:
        if 'label_prior' in metadata[0]:  # prior on labels
            pos_weight = torch.from_numpy(metadata[0]['label_prior']).float()
        if 'label_prior_aspect' in metadata[0]:  # prior on aspect labels
            for dom_idx in range(2):
                try:
                    # Use the first instance of this domain as the source of its prior.
                    idx = int((domain == dom_idx).nonzero()[0])
                except IndexError:
                    continue
                w = metadata[idx]['label_prior_aspect']
                pos_weight_aspect[dom_idx] = torch.from_numpy(w).float()

    embeddings = self.word_embedder(sentence)
    mask = get_text_field_mask(sentence)
    encoder_out = self.encoder(embeddings, mask)
    label_logits = self.classifier_feedforward(encoder_out)
    output = {'label_logits': label_logits}

    if labels is not None:
        labels = labels[:, :NUM_MOTIVES]  # Delete ``NULL'' column
        # Positive logit -> 1, otherwise 0. The old `(sign + 1) / 2` yields float 0.5 for
        # zero logits under PyTorch's true division.
        label_ids = (label_logits > 0).long()
        self._macro_f1(label_ids, labels)
        output['loss'] = self._loss_func(label_logits, labels.float(),
                                         weight=sample_weight, pos_weight=pos_weight)

    if labels_aspect is not None:
        components = [(self.classifier_aspect_res, self._f1_aspect_res),
                      (self.classifier_aspect_lap, self._f1_aspect_lap)]
        for dom_idx, (clf, f1) in enumerate(components):
            in_domain = domain == dom_idx
            h = encoder_out[in_domain]  # encoded sentences of this domain
            if h.shape[0] == 0:  # no instance
                continue
            aspect_logits = clf(h)
            aspects_pred = (aspect_logits > 0).long()
            aspects = labels_aspect[in_domain][:, :aspect_logits.shape[1]]
            f1(aspects_pred, aspects)
            dom_weight = None if sample_weight is None else sample_weight[in_domain]
            aspect_loss = self._loss_func(aspect_logits, aspects.float(),
                                          weight=dom_weight,
                                          pos_weight=pos_weight_aspect[dom_idx])
            if 'loss' in output:
                output['loss'] = output['loss'] + aspect_loss
            else:
                output['loss'] = aspect_loss
    return output
def forward(
    self,
    messages: Dict[str, torch.Tensor],
    # (batch_size, n_turns, n_facts, n_words)
    facts: Dict[str, torch.Tensor],
    # (batch_size, n_turns)
    senders: torch.Tensor,
    # (batch_size, n_turns, n_acts)
    dialog_acts: torch.Tensor,
    # (batch_size, n_turns)
    dialog_acts_mask: torch.Tensor,
    # (batch_size, n_entities)
    known_entities: Dict[str, torch.Tensor],
    # (batch_size, 1)
    focus_entity: Dict[str, torch.Tensor],
    # (batch_size, n_turns, n_facts)
    fact_labels: Optional[torch.Tensor] = None,
    # presumably (batch_size, n_turns): it is multiplied elementwise by the turn mask below
    likes: Optional[torch.Tensor] = None,
    metadata: Optional[Dict] = None,
):
    """
    Encode a batch of dialogs and compute the losses of the enabled heads.

    The dialog-act/policy, fact-ranking and like-prediction heads can each be disabled via
    ``self._disable_dialog_acts``, ``self._disable_facts`` and ``self._disable_likes``.
    ``metadata`` is accepted but not used here.

    Returns
    -------
    Dict that may contain ``'fact_logits'`` and ``'like_logits'`` (the dialog-act and policy
    loss helpers also receive it and may add entries). ``'loss'`` is present only when at
    least one enabled head had its supervision: dialog acts always do, the fact head needs
    ``fact_labels``, and the like head needs ``likes``.
    """
    output = {}

    # Take care of the easy stuff first

    # (batch_size, n_entities)
    known_entities_mask = get_text_field_mask(known_entities)

    # (batch_size, n_turns, sender_emb_size)
    sender_emb = self._sender_emb(senders)
    known_emb = self._mention_embedder(known_entities)

    # TODO: This could instead of averaged, be attended
    known_vec = self._known_net(
        masked_mean(known_emb, known_entities_mask.unsqueeze(-1), dim=1))

    # There is always exactly one entity
    focus_emb = self._focus_net(
        self._mention_embedder(focus_entity)[:, 0, :])

    if self._use_bert:
        # (batch_size, n_turns, n_words, emb_dim)
        context, utter_mask = self._bert_encoder(messages)
        context = self._dropout(context)
    else:
        # (batch_size, n_turns)
        # This is the mask since not all dialogs have same number
        # of turns
        utter_mask = get_text_field_mask(messages)

        # (batch_size, n_turns, n_words)
        # Mask since not all utterances have same number of words
        # Wrapping dim skips over n_messages dim
        text_mask = get_text_field_mask(messages, num_wrapping_dims=1)
        # (batch_size, n_turns, n_words, emb_dim)
        embed = self._dropout(self._utter_embedder(messages))
        # (batch_size, n_turns, hidden_dim)
        context = self._dist_utter_context(embed, text_mask)

    # (batch_size, n_turns, act_emb_size)
    act_emb = self._act_embedder(dialog_acts.float())
    act_emb = self._clamp_dialog_acts(act_emb)

    # (batch_size, n_turns, hidden_dim + known_dim + focus_dim + sender_dim + act_dim)
    n_turns = context.shape[1]
    full_context = torch.cat(
        (
            context,
            sender_emb,
            act_emb,
            focus_emb[:, None, :].repeat_interleave(n_turns, 1),
            known_vec[:, None, :].repeat_interleave(n_turns, 1),
        ),
        dim=-1,
    )

    # (batch_size, n_turns, hidden_dim)
    # This assumes dialog_context does not peek into future
    dialog_context = self._dialog_context(full_context, utter_mask)

    # shift context one right, pad with zeros at front
    # This makes it so that utter_t is paired with context_t-1
    # which is what we want
    # This is useful in a few different places, so compute it here once
    shape = dialog_context.shape
    shifted_context = torch.cat(
        (
            dialog_context.new_zeros([shape[0], 1, shape[2]]),
            dialog_context[:, :-1, :],
        ),
        dim=1,
    )

    has_loss = False
    if self._disable_dialog_acts:
        da_loss = 0
        policy_loss = 0
    else:
        # Dialog act per utter loss
        has_loss = True
        da_loss = self._compute_da_loss(
            output,
            context,
            shifted_context,
            utter_mask,
            dialog_acts,
            dialog_acts_mask,
        )

        # Policy loss
        policy_loss = self._compute_policy_loss(output, shifted_context, utter_mask,
                                                dialog_acts, dialog_acts_mask)

    if self._disable_facts:
        # If facts are disabled, don't output anything related
        # to them
        fact_loss = 0
    else:
        if self._use_bert:
            # (batch_size, n_turns, n_words, emb_dim)
            fact_repr, fact_mask = self._bert_encoder(facts)
            fact_repr = self._dropout(fact_repr)
            fact_mask[:, ::2] = 0
        else:
            # (batch_size, n_turns, n_facts)
            # Wrapping dim skips over n_messages
            fact_mask = get_text_field_mask(facts, num_wrapping_dims=1)
            # In addition to masking padded facts, also explicitly mask
            # user turns just in case
            fact_mask[:, ::2] = 0

            # (batch_size, n_turns, n_facts, n_words)
            # Wrapping dim skips over n_turns and n_facts
            fact_text_mask = get_text_field_mask(facts, num_wrapping_dims=2)
            # (batch_size, n_turns, n_facts, n_words, emb_dim)
            # Share encoder with utter encoder
            # Again, stupid dimensions
            fact_embed = self._dropout(self._utter_embedder(facts))
            shape = fact_embed.shape
            word_dim = shape[-2]
            emb_dim = shape[-1]
            # Collapse every leading dim so the utterance encoder sees (N, n_words, emb_dim).
            reshaped_facts = fact_embed.view(-1, word_dim, emb_dim)
            reshaped_fact_text_mask = fact_text_mask.view(-1, word_dim)
            reshaped_fact_repr = self._utter_context(
                reshaped_facts, reshaped_fact_text_mask)
            # No more emb dimension or word/seq dim
            fact_repr = reshaped_fact_repr.view(shape[:-2] + (-1, ))

        fact_logits = self._fact_ranker(
            shifted_context,
            fact_repr,
        )
        output["fact_logits"] = fact_logits
        if fact_labels is not None:
            has_loss = True
            fact_loss = self._compute_fact_loss(fact_logits, fact_labels, fact_mask)
            self._fact_loss_metric(fact_loss.item())
            self._fact_mrr(fact_logits, fact_labels, mask=fact_mask)
        else:
            fact_loss = 0

    if self._disable_likes:
        like_loss = 0
    else:
        # (batch_size, n_turns, 2)
        like_logits = self._like_classifier(dialog_context)
        output["like_logits"] = like_logits

        if likes is not None:
            # There are several masks here to get the loss/metrics correct
            # - utter_mask: mask out positions that do not have an utterance
            # - user_mask: mask out positions that have a user utterances
            #   since their turns are never liked
            # Using new_ones() preserves the type of the tensor
            user_mask = utter_mask.new_ones(utter_mask.shape)
            # Since the user is always even, this masks out user positions
            user_mask[:, ::2] = 0
            final_mask = utter_mask * user_mask
            # Only reachable with real labels; `likes` is Optional.
            masked_likes = likes * final_mask
            has_loss = True
            like_loss = sequence_cross_entropy_with_logits(
                like_logits, masked_likes, final_mask)
            self._like_accuracy(like_logits, masked_likes, final_mask)
            self._like_loss_metric(like_loss.item())
        else:
            like_loss = 0

    if has_loss:
        output["loss"] = (self._fact_loss_weight * fact_loss
                          + like_loss + da_loss + policy_loss)

    return output
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            spans: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            pos_tags: Dict[str, torch.LongTensor] = None,
            span_labels: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict a label distribution for every candidate span of each sentence.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor], required
        ``TextField.as_array()`` output, passed straight to ``self.text_field_embedder``.
    spans : ``torch.LongTensor``, required
        Shape ``(batch_size, num_spans, 2)``: inclusive start/end indices of every candidate
        span. A negative start index marks a padding span.
    metadata : List[Dict[str, Any]], required
        One dict per instance. ``'tokens'`` (``List[str]``) is required. ``'pos_tags'``
        (``List[str]``) and ``'gold_tree'`` (``nltk.Tree``) are optional. Gold trees are
        only used for EVALB scoring outside training.
    pos_tags : ``torch.LongTensor``, optional (default = None)
        POS tag ids from a ``SequenceLabelField``. Required when the model has a POS
        embedding; ignored otherwise.
    span_labels : ``torch.LongTensor``, optional (default = None)
        Gold label ids for each span, shape ``(batch_size, num_spans)``.

    Returns
    -------
    The returned dict contains:

    - ``'class_probabilities'``, of shape ``(batch_size, num_spans, span_label_vocab_size)``;
    - the input ``'spans'``;
    - per-instance ``'tokens'`` and ``'pos_tags'`` lists;
    - ``'num_spans'``, the number of non-padding spans per instance;
    - ``'loss'``, only when ``span_labels`` is given.

    Raises
    ------
    ConfigurationError
        If the model has a POS embedding but ``pos_tags`` is None.
    """
    embedded = self.text_field_embedder(tokens)
    if self.pos_tag_embedding is not None:
        if pos_tags is None:
            raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")
        embedded = torch.cat([embedded, self.pos_tag_embedding(pos_tags)], -1)

    mask = get_text_field_mask(tokens)

    # A span is padding iff its start index is negative. Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1).long()
    # The squeeze collapses the span axis when num_spans == 1 (e.g. batch size 1 with a
    # length-1 PTB sentence); restore it so the mask is always 2-D.
    if span_mask.dim() == 1:
        span_mask = span_mask.unsqueeze(-1)
    if span_labels is not None and span_labels.dim() == 1:
        span_labels = span_labels.unsqueeze(-1)

    num_spans = get_lengths_from_binary_sequence_mask(span_mask)

    span_repr = self.span_extractor(self.encoder(embedded, mask), spans, mask, span_mask)
    if self.feedforward_layer is not None:
        span_repr = self.feedforward_layer(span_repr)

    logits = self.tag_projection_layer(span_repr)
    class_probabilities = last_dim_softmax(logits, span_mask.unsqueeze(-1))

    output_dict = {
        "class_probabilities": class_probabilities,
        "spans": spans,
        "tokens": [meta["tokens"] for meta in metadata],
        "pos_tags": [meta.get("pos_tags") for meta in metadata],
        "num_spans": num_spans,
    }
    if span_labels is not None:
        output_dict["loss"] = sequence_cross_entropy_with_logits(logits, span_labels, span_mask)
        self.tag_accuracy(class_probabilities, span_labels, span_mask)

    # EVALB is expensive, so it only runs outside training and only when every instance
    # carries a gold tree.
    gold_trees = [meta.get("gold_tree") for meta in metadata]
    if all(gold_trees) and self._evalb_score is not None and not self.training:
        gold_pos_tags: List[List[str]] = [list(zip(*tree.pos()))[1] for tree in gold_trees]
        predicted_trees = self.construct_trees(class_probabilities.cpu().data,
                                               spans.cpu().data,
                                               num_spans.data,
                                               output_dict["tokens"],
                                               gold_pos_tags)
        self._evalb_score(predicted_trees, gold_trees)

    return output_dict
def forward(self, tokens: Dict[str, torch.Tensor], label: torch.Tensor = None, task_id=None) -> Dict[str, torch.Tensor]:
    """
    Run the task-specific classifier on a batch of tokens.

    Parameters
    ----------
    tokens : Dict[str, torch.Tensor]
        Indexed text-field tensors. They are moved to the current CUDA device
        when one is available, then embedded with ``self.word_embeddings``.
    label : torch.Tensor, optional (default = None)
        Gold labels. When given, metrics are updated and ``"loss"`` is added to
        the output.
    task_id : optional (default = None)
        Key into ``self.classification_layers`` that selects the output head.
        Defaults to ``self.get_current_taskid()``.

    Returns
    -------
    A dictionary with ``"logits"`` and ``"encoder_output"``. ``"loss"`` is added
    when ``label`` is given. ``"ewc_loss"`` is added when ``self.args.ewc`` or
    ``self.args.oewc`` is set and the model is in training mode.

    Side effects
    ------------
    - Stores ``self.activations`` and ``self.labels``.
    - Calls ``self.average``, ``self.micro_avg`` and ``self.accuracy``.
    - In the EWC case, calls ``backward(retain_graph=True)`` on the loss and
      ``self.ewc.update_penalty``.
    """
    # Pick the classification head: the requested task, or the current one.
    if (task_id is None):
        task_id = self.get_current_taskid()
    hidden2tag = self.classification_layers[task_id]
    if torch.cuda.is_available():
        tokens = move_to_device(tokens, torch.cuda.current_device())
    mask = get_text_field_mask(tokens)
    embeddings = self.word_embeddings(tokens)
    if self.args.position_embed:
        embeddings = self.pos_embedding(embeddings)
    # Task embedder: append one extra feature channel, filled with the task id,
    # to every token embedding.
    # NOTE(review): the channel uses self.get_current_taskid(), not the `task_id`
    # argument that selected `hidden2tag` above -- confirm this is intended.
    if self.task_embedder:
        bs, seq, edi = embeddings.shape
        # Only the shape matters here: fill_ overwrites the random values at once.
        task_em = torch.randn((bs, seq, 1))
        task_em.fill_(self.get_current_taskid())
        if torch.cuda.is_available():
            task_em = move_to_device(task_em, torch.cuda.current_device())
        embeddings = torch.cat([embeddings, task_em], dim=-1)
    # Task encoding: inject the current task id with self.task_encoder
    # (transformer-style encoding, per the original note).
    if self.task_encoder:
        embeddings = self.task_encoder(embeddings, self.get_current_taskid())
    # Task projection: inject the current task id with self.task_projection
    # (a linear projection of a one-hot id, per the original note).
    if self.task_projection:
        embeddings = self.task_projection(embeddings, self.get_current_taskid())
    # Embedding-level temperature: scale by the inverse temperature when enabled.
    if (self.args.all_temp or self.args.emb_temp) and self.inv_temp:
        embeddings = self.inv_temp * embeddings
    # HashedMemoryRNN additionally receives the token tensors as memory keys.
    if type(self.encoder) == HashedMemoryRNN:
        output = self.encoder(embeddings, mask, mem_tokens=tokens)
    else:
        output = self.encoder(embeddings, mask)
    # Some encoders return a (output, activations) tuple.
    if type(output) == tuple:
        encoder_out, activations = output
        # NOTE(review): this overwrites the activations just unpacked, so the
        # second tuple element is never used -- confirm intended.
        activations = encoder_out
    else:
        encoder_out = output
        activations = output
    # Kept on the instance; presumably read by external analysis code -- verify.
    self.activations = activations
    self.labels = label
    # Encoder-level temperature.
    if (self.args.all_temp or self.args.enc_temp) and self.inv_temp:
        encoder_out = self.inv_temp * encoder_out
    if self.use_task_memory:
        encoder_out = self.task_memory(encoder_out, self.current_task)
    tag_logits = hidden2tag(encoder_out)
    # Softmax-level temperature (applied to the logits).
    if (self.args.all_temp or self.args.softmax_temp) and self.inv_temp:
        tag_logits = self.inv_temp * tag_logits
    output = {'logits': tag_logits, 'encoder_output': encoder_out}
    if label is not None:
        # Argmax over dim 1. This assumes tag_logits is (batch_size, num_classes)
        # -- TODO confirm.
        _, preds = tag_logits.max(dim=1)
        # MCC and micro-F1 are computed per batch on CPU numpy copies.
        self.average(
            matthews_corrcoef(label.data.cpu().numpy(), preds.data.cpu().numpy()))
        self.micro_avg(
            f1_score(label.data.cpu().numpy(), preds.data.cpu().numpy(), average='micro'))
        self.accuracy(tag_logits, label)
        output["loss"] = self.loss_function(tag_logits, label)
        if self.use_task_memory:
            output["loss"] += self.task_memory.get_memory_loss(
                self.current_task)
        if (self.args.ewc or self.args.oewc) and self.training:
            output["ewc_loss"] = self.args.ewc_importance * self.ewc.penalty(
                self.get_current_taskid())
            output["loss"] += output["ewc_loss"]
            # NOTE(review): backward() runs inside forward. If the trainer also
            # back-propagates output["loss"], this batch's gradients are
            # accumulated twice. Presumably done so update_penalty can read
            # gradients -- confirm.
            output["loss"].backward(retain_graph=True)
            if self._len_dataset:
                self.ewc.update_penalty(self.task2id[self.current_task], self,
                                        self._len_dataset)
    return output
def forward(self,  # type: ignore
            utterance: Dict[str, torch.LongTensor],
            logical_forms: Dict[str, torch.LongTensor],
            utterance_string: List[str],
            logical_form_strings: List[List[str]]) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Score every candidate logical form against its utterance by cosine similarity.

    The utterance is encoded and summed over its tokens. Each candidate logical
    form is embedded, summed over its tokens, and mapped into the utterance
    space by ``self.translation_layer``. The ranking and accuracy bookkeeping
    treats the candidate at index 0 as the reference.

    Parameters
    ----------
    utterance : Dict[str, torch.LongTensor]
        Text-field tensors for the utterances.
    logical_forms : Dict[str, torch.LongTensor]
        Text-field tensors for the candidates, with one extra wrapping
        dimension (a list of candidates per instance).
    utterance_string : List[str]
        Raw utterances, echoed back in the output.
    logical_form_strings : List[List[str]]
        Raw candidate strings per instance, used to look up the winner.

    Returns
    -------
    A dictionary with:

    - ``"loss"``: the sum over the batch of ``1 - best similarity``.
    - ``"most_similar"``: the best-scoring candidate string per instance.
    - ``"utterance"``: the input utterances.
    - ``"all_similarities"``: the similarity scores with padding masked.

    Side effects: accumulates ``self.hits3``, ``self.hits5``, ``self.hits10``,
    ``self.mean_ranks``, ``self.batches`` and ``self.accuracy``.
    """
    # Utterance side. Everything is summed at the end anyway, so the token axis
    # is collapsed up front to save work.
    # Shape: (batch_size, utterance_embedding_dim)
    utterance_mask = util.get_text_field_mask(utterance)
    utterance_vector = self.utterance_encoder(self.utterance_embedder(utterance),
                                              utterance_mask).sum(dim=1)

    # Candidate side.
    # Shape: (batch_size, num_logical_forms, num_lf_tokens, lf_embedding_dim)
    lf_embedded = self.logical_form_embedder(logical_forms, num_wrapping_dims=1)
    # Shape: (batch_size, num_logical_forms, num_lf_tokens)
    lf_token_mask = util.get_text_field_mask(logical_forms, num_wrapping_dims=1)
    # A candidate slot is real if it has at least one unmasked token.
    # Shape: (batch_size, num_logical_forms)
    lf_mask = lf_token_mask.sum(dim=-1).clamp(max=1)
    # Shape: (batch_size, num_logical_forms, utterance_embedding_dim)
    projected = self.translation_layer(lf_embedded.sum(dim=2))

    # Shape: (batch_size, num_logical_forms)
    scores = torch.nn.functional.cosine_similarity(
        projected, utterance_vector.unsqueeze(1), dim=2)
    if self.normalize_by_len:
        # The +1 keeps the divisor non-zero for empty (padding) candidates.
        scores = scores / (1.0 + torch.sum(lf_token_mask, dim=-1, dtype=scores.dtype))
    # Give padding candidates a huge negative score so they never win the max.
    scores = util.replace_masked_values(scores, lf_mask, -1e7)

    # Rank of the reference candidate (index 0) = number of candidates that
    # score strictly higher. Shape: (batch_size,)
    outranks_reference = scores[:, 0].unsqueeze(1) < scores
    reference_rank = outranks_reference.sum(dim=-1)
    hits3, hits5, hits10 = ((reference_rank < k).sum().cpu().data.numpy()
                            for k in (3, 5, 10))
    self.hits3 += hits3
    self.hits5 += hits5
    self.hits10 += hits10
    self.mean_ranks += reference_rank.sum(dim=0).cpu().data.numpy()
    self.batches += outranks_reference.shape[0]

    best_score, best_index = scores.max(dim=-1)
    # NOTE(review): the loss rewards whichever candidate currently scores
    # highest, not the reference at index 0 -- confirm this is intended.
    loss = (1 - best_score).sum()
    self.accuracy += (best_index == 0).sum().cpu().data.numpy()

    most_similar_strings = [candidates[winner] for winner, candidates
                            in zip(best_index.tolist(), logical_form_strings)]
    return {
        "loss": loss,
        "most_similar": most_similar_strings,
        "utterance": utterance_string,
        "all_similarities": scores
    }
def forward(
        self,  # type: ignore
        tokens: TextFieldTensors,
        tags: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
        **kwargs) -> Dict[str, torch.Tensor]:
    """
    Tag a batch of sequences with a CRF on top of an encoder.

    # Parameters

    tokens : ``TextFieldTensors``, required
        The output of ``TextField.as_array()``. It is passed unchanged to
        ``self.text_field_embedder``, so its keys must match the token indexers
        used to build the ``TextField``.
    tags : ``torch.LongTensor``, optional (default = ``None``)
        Gold tag ids of shape ``(batch_size, num_tokens)``. When given, the
        negative log-likelihood loss is computed and the metrics are updated.
    metadata : ``List[Dict[str, Any]]``, optional (default = ``None``)
        Per-instance metadata. Its ``'words'`` entries are copied to the output.

    # Returns

    An output dictionary consisting of:

    logits : ``torch.FloatTensor``
        The output of ``tag_projection_layer``.
    mask : ``torch.Tensor``
        The text-field mask for the input tokens.
    tags : ``List[List[int]]``
        The best Viterbi tag sequence for each instance.
    top_k_tags
        All Viterbi paths returned by the CRF. Present only when ``self.top_k > 1``.
    loss : ``torch.FloatTensor``, optional
        The negative CRF log-likelihood. Present only when ``tags`` is given.
    words
        The ``'words'`` entry of each metadata item. Present only when
        ``metadata`` is given.
    """

    def _maybe_dropout(tensor):
        # Dropout is optional; a falsy self.dropout means "no dropout".
        return self.dropout(tensor) if self.dropout else tensor

    mask = util.get_text_field_mask(tokens)
    embedded = _maybe_dropout(self.text_field_embedder(tokens))
    encoded = _maybe_dropout(self.encoder(embedded, mask))
    if self._feedforward is not None:
        encoded = self._feedforward(encoded)
    logits = self.tag_projection_layer(encoded)

    best_paths = self.crf.viterbi_tags(logits, mask, top_k=self.top_k)
    # Keep only the tags of the top path for each instance; drop the scores.
    predicted_tags = cast(List[List[int]], [paths[0][0] for paths in best_paths])

    output = {"logits": logits, "mask": mask, "tags": predicted_tags}
    if self.top_k > 1:
        output["top_k_tags"] = best_paths

    if tags is not None:
        output["loss"] = -self.crf(logits, tags, mask)

        # Turn the Viterbi tags into one-hot "class probabilities" so that the
        # standard tagging metrics can consume them.
        class_probabilities = logits * 0.0
        for row, sentence_tags in enumerate(predicted_tags):
            for col, tag_id in enumerate(sentence_tags):
                class_probabilities[row, col, tag_id] = 1

        float_mask = mask.float()
        for metric in self.metrics.values():
            metric(class_probabilities, tags, float_mask)
        if self.calculate_span_f1:
            self._f1_metric(class_probabilities, tags, float_mask)

    if metadata is not None:
        output["words"] = [entry["words"] for entry in metadata]
    return output
def forward(self,  # type: ignore
            premise: Dict[str, torch.LongTensor],
            hypothesis: Dict[str, torch.LongTensor],
            label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Decomposable attention: attend, compare, aggregate.

    Parameters
    ----------
    premise : Dict[str, torch.LongTensor]
        From a ``TextField``.
    hypothesis : Dict[str, torch.LongTensor]
        From a ``TextField``.
    label : torch.IntTensor, optional, (default = None)
        From a ``LabelField``. When given, the loss and accuracy are computed.
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Per-instance metadata. Its ``'premise_tokens'`` and
        ``'hypothesis_tokens'`` entries are copied to the output.

    Returns
    -------
    An output dictionary consisting of:

    label_logits : torch.FloatTensor
        Shape ``(batch_size, num_labels)``. Unnormalised label scores.
    label_probs : torch.FloatTensor
        Shape ``(batch_size, num_labels)``. The softmax of ``label_logits``.
    h2p_attention, p2h_attention : torch.FloatTensor
        The hypothesis-to-premise and premise-to-hypothesis attention weights.
    loss : torch.FloatTensor, optional
        A scalar loss. Present only when ``label`` is given.
    premise_tokens, hypothesis_tokens : optional
        Copied from ``metadata`` when it is given.
    """
    premise_mask = get_text_field_mask(premise).float()
    hypothesis_mask = get_text_field_mask(hypothesis).float()

    # Embed both sentences, and contextualise them if encoders are configured.
    premise_vectors = self._text_field_embedder(premise)
    hypothesis_vectors = self._text_field_embedder(hypothesis)
    if self._premise_encoder:
        premise_vectors = self._premise_encoder(premise_vectors, premise_mask)
    if self._hypothesis_encoder:
        hypothesis_vectors = self._hypothesis_encoder(hypothesis_vectors, hypothesis_mask)

    # Attend: soft-align each token with the tokens of the other sentence.
    # Shape: (batch_size, premise_length, hypothesis_length)
    similarity = self._matrix_attention(self._attend_feedforward(premise_vectors),
                                        self._attend_feedforward(hypothesis_vectors))
    p2h_attention = masked_softmax(similarity, hypothesis_mask)
    # Shape: (batch_size, hypothesis_length, premise_length)
    h2p_attention = masked_softmax(similarity.transpose(1, 2).contiguous(), premise_mask)

    def compare_and_sum(own_vectors, aligned_vectors, own_mask):
        # Compare each token with its aligned summary, zero out padding, then
        # sum over tokens. Shape: (batch_size, compare_dim)
        compared = self._compare_feedforward(torch.cat([own_vectors, aligned_vectors], dim=-1))
        return (compared * own_mask.unsqueeze(-1)).sum(dim=1)

    compared_premise = compare_and_sum(
        premise_vectors, weighted_sum(hypothesis_vectors, p2h_attention), premise_mask)
    compared_hypothesis = compare_and_sum(
        hypothesis_vectors, weighted_sum(premise_vectors, h2p_attention), hypothesis_mask)

    # Aggregate the two sentence summaries into label scores.
    label_logits = self._aggregate_feedforward(
        torch.cat([compared_premise, compared_hypothesis], dim=-1))

    output_dict = {
        "label_logits": label_logits,
        "label_probs": torch.nn.functional.softmax(label_logits, dim=-1),
        "h2p_attention": h2p_attention,
        "p2h_attention": p2h_attention,
    }
    if label is not None:
        output_dict["loss"] = self._loss(label_logits, label.long().view(-1))
        self._accuracy(label_logits, label)

    if metadata is not None:
        output_dict["premise_tokens"] = [item["premise_tokens"] for item in metadata]
        output_dict["hypothesis_tokens"] = [item["hypothesis_tokens"] for item in metadata]
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            passage_sem_views_q: torch.IntTensor = None,
            passage_sem_views_k: torch.IntTensor = None,
            question_sem_views_q: torch.IntTensor = None,
            question_sem_views_k: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Predict the answer span in ``passage`` for ``question`` (QANet-style).

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``. The model assumes the passage contains the answer
        and predicts the answer's start and end token positions.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``. The inclusive start token index of the gold
        answer. When given, a loss is computed.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``. The inclusive end token index of the gold answer.
    passage_sem_views_q : ``torch.IntTensor``, optional
        Passage semantic-view features for the multi-head attention query (Q).
    passage_sem_views_k : ``torch.IntTensor``, optional
        Passage semantic-view features for the multi-head attention key (K).
    question_sem_views_q : ``torch.IntTensor``, optional
        Question semantic-view features for the multi-head attention query (Q).
    question_sem_views_k : ``torch.IntTensor``, optional
        Question semantic-view features for the multi-head attention key (K).
    metadata : ``List[Dict[str, Any]]``, optional
        One dict per batch instance with ``question_tokens`` and
        ``passage_tokens``. ``answer_texts`` is optional and enables the SQuAD
        EM/F1 and ROUGE metrics.

    Returns
    -------
    An output dictionary consisting of:

    span_start_logits : torch.FloatTensor
        Shape ``(batch_size, passage_length)``. Unnormalized start scores.
    span_start_probs : torch.FloatTensor
        ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        Shape ``(batch_size, passage_length)``. Unnormalized (inclusive) end scores.
    span_end_probs : torch.FloatTensor
        ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        Shape ``(batch_size, 2)``. Token indices of the best span, from
        ``get_best_span``.
    loss : torch.FloatTensor, optional
        Present when ``span_start`` is given and the loss could be computed.
    best_span_str : List[str]
        Present with ``metadata``. The predicted span's passage tokens joined
        by single spaces.
    question_tokens, passage_tokens : List
        Present with ``metadata``.
    metrics : List[Dict], optional
        Present with ``metadata`` in evaluation mode. Per-instance EM/F1 and
        ROUGE scores.
    output_metadata : Dict, optional
        Present with ``metadata`` when ``self.return_output_metadata`` is set.
        Encoder metadata from the semantic-view layers.
    """
    return_output_metadata = self.return_output_metadata
    semantic_encoder_types = (QaNetSemanticFlatEncoder, QaNetSemanticFlatConcatEncoder)
    phrase_layer_uses_views = isinstance(self._phrase_layer, semantic_encoder_types)
    modeling_layer_uses_views = isinstance(self._modeling_layer, semantic_encoder_types)

    def _long_or_none(tensor):
        return None if tensor is None else tensor.long()

    def _cuda_or_none(tensor):
        return None if tensor is None else to_cuda(tensor, move_to_cuda=True)

    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()

    if phrase_layer_uses_views or modeling_layer_uses_views:
        # Cast the semantic-view ids to long before they reach the semantic encoders.
        passage_sem_views_q = _long_or_none(passage_sem_views_q)
        passage_sem_views_k = _long_or_none(passage_sem_views_k)
        question_sem_views_q = _long_or_none(question_sem_views_q)
        question_sem_views_k = _long_or_none(question_sem_views_k)

    if torch.cuda.is_available():
        # Masks and indices.
        question_mask = to_cuda(question_mask, move_to_cuda=True)
        passage_mask = to_cuda(passage_mask, move_to_cuda=True)
        question = {k: to_cuda(v, move_to_cuda=True) for k, v in question.items()}
        passage = {k: to_cuda(v, move_to_cuda=True) for k, v in passage.items()}
        # Gold spans.
        span_start = _cuda_or_none(span_start)
        span_end = _cuda_or_none(span_end)
        # Semantic views.
        passage_sem_views_q = _cuda_or_none(passage_sem_views_q)
        passage_sem_views_k = _cuda_or_none(passage_sem_views_k)
        question_sem_views_q = _cuda_or_none(question_sem_views_q)
        question_sem_views_k = _cuda_or_none(question_sem_views_k)

    embedded_question = self._dropout(self._text_field_embedder(question))
    embedded_passage = self._dropout(self._text_field_embedder(passage))
    embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
    embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

    batch_size = embedded_question.size(0)

    projected_embedded_question = self._encoding_proj_layer(embedded_question)
    projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

    encoded_passage_output_metadata = None
    encoded_question_output_metadata = None
    if phrase_layer_uses_views:
        if is_output_meta_supported(self._phrase_layer):
            encoded_passage, encoded_passage_output_metadata = self._phrase_layer(
                projected_embedded_passage, passage_sem_views_q, passage_sem_views_k,
                passage_mask, return_output_metadata)
            encoded_passage = self._dropout(encoded_passage)
            encoded_question, encoded_question_output_metadata = self._phrase_layer(
                projected_embedded_question, question_sem_views_q, question_sem_views_k,
                question_mask, return_output_metadata)
            encoded_question = self._dropout(encoded_question)
        else:
            encoded_passage = self._dropout(self._phrase_layer(
                projected_embedded_passage, passage_sem_views_q, passage_sem_views_k,
                passage_mask))
            encoded_question = self._dropout(self._phrase_layer(
                projected_embedded_question, question_sem_views_q, question_sem_views_k,
                question_mask))
    else:
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask))
        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask))

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = masked_softmax(
        passage_question_similarity, question_mask, memory_efficient=True)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # Shape: (batch_size, question_length, passage_length)
    question_passage_attention = masked_softmax(
        passage_question_similarity.transpose(1, 2), passage_mask, memory_efficient=True)
    # Shape: (batch_size, passage_length, passage_length)
    attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    merged_passage_attention_vectors = self._dropout(
        torch.cat([encoded_passage, passage_question_vectors,
                   encoded_passage * passage_question_vectors,
                   encoded_passage * passage_passage_vectors],
                  dim=-1))

    # The same modeling layer is applied three times. Every intermediate output
    # is kept because the span predictors read the last three of them.
    modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]
    modeled_passage_output_metadata_list = {}
    for modeling_layer_id in range(3):
        modeled_passage_output_metadata = None
        if modeling_layer_uses_views:
            if is_output_meta_supported(self._modeling_layer):
                modeled_passage, modeled_passage_output_metadata = self._modeling_layer(
                    modeled_passage_list[-1], passage_sem_views_q, passage_sem_views_k,
                    passage_mask, return_output_metadata)
            else:
                modeled_passage = self._modeling_layer(
                    modeled_passage_list[-1], passage_sem_views_q, passage_sem_views_k,
                    passage_mask)
        else:
            modeled_passage = self._modeling_layer(modeled_passage_list[-1], passage_mask)
        modeled_passage_list.append(self._dropout(modeled_passage))
        modeled_passage_output_metadata_list[
            "modeling_layer_iter_{0:03d}".format(modeling_layer_id)] = modeled_passage_output_metadata

    # Shape: (batch_size, passage_length, modeling_dim * 2)
    span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length, modeling_dim * 2)
    span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

    # Shape: (batch_size, passage_length)
    span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
    span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

    best_span = get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        # Best effort, as before: failures are logged, not raised. The loss is
        # now stored before the metric updates, so a failing metric can no
        # longer throw away an already-computed loss.
        try:
            span_start_gold = span_start.squeeze(-1)
            span_end_gold = span_end.squeeze(-1)
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                            span_start_gold)
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                             span_end_gold)
            output_dict["loss"] = loss
            self._span_start_accuracy(span_start_logits, span_start_gold)
            self._span_end_accuracy(span_end_logits, span_end_gold)
            self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        except Exception as e:  # pylint: disable=broad-except
            logging.exception(e)

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        metrics_per_item = None
        return_metrics_per_item = True
        # ROUGE inputs, kept aligned with each other: only instances that have
        # reference answers contribute, and rouge_item_indices maps each entry
        # back to its batch position.
        rouge_item_indices = []
        rouge_best_spans = []
        rouge_reference_answers = []
        if not self.training:
            metrics_per_item = [{} for _ in range(batch_size)]
        # One device-to-host copy for the whole batch instead of one per instance.
        best_span_cpu = best_span.detach().cpu().numpy()
        for i in range(batch_size):
            instance_metadata = metadata[i]
            question_tokens.append(instance_metadata['question_tokens'])
            passage_tokens.append(instance_metadata['passage_tokens'])
            start_span, end_span = best_span_cpu[i][0], best_span_cpu[i][1]
            # The answer string is rebuilt from whitespace-joined passage tokens,
            # not from character offsets into the original passage.
            best_span_string = " ".join(
                instance_metadata['passage_tokens'][start_span:end_span + 1])
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = instance_metadata.get('answer_texts', [])
            if answer_texts:
                curr_item_em, curr_item_f1 = self._squad_metrics(
                    best_span_string, answer_texts, return_score=True)
                if not self.training and return_metrics_per_item:
                    metrics_per_item[i]["em"] = curr_item_em
                    metrics_per_item[i]["f1"] = curr_item_f1
                rouge_item_indices.append(i)
                rouge_best_spans.append(best_span_string)
                rouge_reference_answers.append(answer_texts)

        if return_output_metadata:
            output_dict["output_metadata"] = {
                "encoded_passage": encoded_passage_output_metadata,
                "encoded_question": encoded_question_output_metadata,
                "modeling_layer": modeled_passage_output_metadata_list,
            }

        if not self.training and rouge_reference_answers:
            metrics_per_item_rouge = self.calculate_rouge(
                rouge_best_spans, rouge_reference_answers,
                return_metrics_per_item=return_metrics_per_item)
            for item_index, curr_metrics in zip(rouge_item_indices, metrics_per_item_rouge):
                metrics_per_item[item_index].update(curr_metrics)

        if metrics_per_item is not None:
            output_dict['metrics'] = metrics_per_item

        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            tags: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None,
            # pylint: disable=unused-argument
            **kwargs) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Tag a batch of sequences with an encoder followed by a CRF.

    Parameters
    ----------
    tokens : ``Dict[str, torch.LongTensor]``, required
        The output of ``TextField.as_array()``. It is passed unchanged to
        ``self.text_field_embedder``, so its keys must match the token indexers
        used to build the ``TextField``.
    tags : ``torch.LongTensor``, optional (default = ``None``)
        Gold tag ids of shape ``(batch_size, num_tokens)``. When given, the
        loss is computed and the metrics are updated.
    metadata : ``List[Dict[str, Any]]``, optional, (default = None)
        Per-instance metadata. Its ``'words'`` entries are copied to the output.

    Returns
    -------
    An output dictionary consisting of:

    logits : ``torch.FloatTensor``
        The output of ``tag_projection_layer``.
    mask : ``torch.LongTensor``
        The text-field mask for the input tokens.
    tags : ``List[List[int]]``
        The Viterbi-decoded tag sequence for each instance.
    loss : ``torch.FloatTensor``, optional
        The negative CRF log-likelihood. Present only when ``tags`` is given.
    words
        The ``'words'`` entry of each metadata item. Present only when
        ``metadata`` is given.
    """
    mask = util.get_text_field_mask(tokens)

    hidden = self.text_field_embedder(tokens)
    if self.dropout:
        hidden = self.dropout(hidden)
    hidden = self.encoder(hidden, mask)
    if self.dropout:
        hidden = self.dropout(hidden)
    if self._feedforward is not None:
        hidden = self._feedforward(hidden)
    logits = self.tag_projection_layer(hidden)

    # viterbi_tags yields a (tag_sequence, score) pair per instance; keep the tags.
    predicted_tags = [tag_sequence for tag_sequence, _ in self.crf.viterbi_tags(logits, mask)]

    output = {"logits": logits, "mask": mask, "tags": predicted_tags}

    if tags is not None:
        output["loss"] = -self.crf(logits, tags, mask)

        # Express the decoded tags as one-hot "class probabilities" for the metrics.
        class_probabilities = logits * 0.
        for row, sentence_tags in enumerate(predicted_tags):
            for col, tag_id in enumerate(sentence_tags):
                class_probabilities[row, col, tag_id] = 1

        for metric in self.metrics.values():
            metric(class_probabilities, tags, mask.float())
        if self.calculate_span_f1:
            self._f1_metric(class_probabilities, tags, mask.float())

    if metadata is not None:
        output["words"] = [entry["words"] for entry in metadata]
    return output
def forward(self, tokens: Dict[str, torch.Tensor],
            label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
    """
    Classify a batch of texts with self-biattention, an integrator and pooling.

    The encoded tokens attend over themselves. The integrator then runs over
    ``[x, x - attended, x * attended]``. Its outputs are summarised by four
    poolings (max, min, mean and self-attentive), which are concatenated and
    passed to ``self.output_layer``.

    Parameters
    ----------
    tokens : Dict[str, torch.Tensor]
        Indexed text-field tensors for ``self.text_field_embedder``.
    label : torch.Tensor, optional (default = None)
        Gold labels. When given, the metrics are updated and ``"loss"`` is added.

    Returns
    -------
    A dictionary with ``"logits"``, plus ``"loss"`` when ``label`` is given.
    (The previous annotation said ``torch.Tensor``, but a dict has always been
    returned.)
    """
    # Batched sequences are padded to a common length. The mask is 1 for real
    # tokens and 0 for padding, so padding can be excluded below.
    text_mask = util.get_text_field_mask(tokens).float()

    embedded_text = self.text_field_embedder(tokens)
    dropped_embedded_text = self.embedding_dropout(embedded_text)
    encoded_tokens = self.encoder(dropped_embedded_text, text_mask)

    # Biattention where both inputs are the same sequence.
    attention_logits = encoded_tokens.bmm(encoded_tokens.permute(0, 2, 1).contiguous())
    attention_weights = util.masked_softmax(attention_logits, text_mask)
    encoded_text = util.weighted_sum(encoded_tokens, attention_weights)

    # Build the input to the integrator.
    integrator_input = torch.cat([encoded_tokens,
                                  encoded_tokens - encoded_text,
                                  encoded_tokens * encoded_text], 2)
    integrated_encodings = self.integrator(integrator_input, text_mask)

    # Shape: (batch_size, sequence_length, 1), broadcast over the encoding dim.
    expanded_mask = text_mask.unsqueeze(2)

    # Simple pooling layers. Padding is pushed to an extreme value so it never
    # wins the max or min.
    max_pool = torch.max(
        util.replace_masked_values(integrated_encodings, expanded_mask, -1e7), 1)[0]
    min_pool = torch.min(
        util.replace_masked_values(integrated_encodings, expanded_mask, +1e7), 1)[0]
    # Zero the padded positions before summing. Not every integrator guarantees
    # zero outputs there, and an unmasked sum would leak padding into the mean.
    mean_pool = torch.sum(
        util.replace_masked_values(integrated_encodings, expanded_mask, 0.0), 1
    ) / torch.sum(text_mask, 1, keepdim=True)

    # Self-attentive pooling: project each position to a scalar score.
    # Shape: (batch_size, sequence_length, 1), squeezed to
    # (batch_size, sequence_length).
    self_attentive_logits = self._self_attentive_pooling_projection(
        integrated_encodings).squeeze(2)
    self_weights = util.masked_softmax(self_attentive_logits, text_mask)
    self_attentive_pool = util.weighted_sum(integrated_encodings, self_weights)

    pooled_representations = torch.cat(
        [max_pool, min_pool, mean_pool, self_attentive_pool], 1)
    pooled_representations_dropped = self.integrator_dropout(pooled_representations)

    logits = self.output_layer(pooled_representations_dropped)

    # The trainer needs a "loss" key in the output dictionary to train the model.
    output = {"logits": logits}
    if label is not None:
        self.accuracy(logits, label)
        self.f1_measure_positive(logits, label)
        self.f1_measure_negative(logits, label)
        self.f1_measure_neutral(logits, label)
        output["loss"] = self.loss_function(logits, label)

    return output
def forward(self,  # type: ignore
            words: Dict[str, torch.LongTensor],
            pos_tags: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parse a batch of sentences into dependency arcs and arc labels.

    The words (optionally concatenated with POS-tag embeddings) are encoded, a learned
    head sentinel is prepended as a symbolic ROOT at position 0, every (head, child) pair
    is scored with ``self.arc_attention``, and heads/tags are decoded greedily (training,
    or when ``use_mst_decoding_for_validation`` is off) or with MST decoding otherwise.

    Parameters
    ----------
    words : Dict[str, torch.LongTensor], required
        The output of ``TextField.as_array()``, which should typically be passed directly to a
        ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
        tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
        Tensor(batch_size, sequence_length)}``. This dictionary will have the same keys as were used
        for the ``TokenIndexers`` when you created the ``TextField`` representing your
        sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
        which knows how to combine different word representations into a single vector per
        token in your input.
    pos_tags : ``torch.LongTensor``, required.
        The output of a ``SequenceLabelField`` containing POS tags.
        POS tags are required regardless of whether they are used in the model,
        because they are used to filter the evaluation metric to only consider
        heads of words which are not punctuation. (The body does tolerate ``None``
        for the embedding step, but raises ``ConfigurationError`` if the model has a
        POS embedding and no tags are given.)
    metadata : List[Dict[str, Any]], required
        One dict per instance; the ``"words"`` and ``"pos"`` entries of each dict are
        copied into the output dictionary.
    head_tags : torch.LongTensor, optional (default = None)
        A torch tensor representing the sequence of integer gold class labels for the arcs
        in the dependency parse. Has shape ``(batch_size, sequence_length)``.
    head_indices : torch.LongTensor, optional (default = None)
        A torch tensor representing the sequence of integer indices denoting the parent of every
        word in the dependency parse. Has shape ``(batch_size, sequence_length)``.

    Returns
    -------
    An output dictionary consisting of:
    loss : ``torch.FloatTensor``
        A scalar loss to be optimised: ``arc_loss + tag_loss``. It is always present. When
        both ``head_indices`` and ``head_tags`` are given it is computed against them;
        otherwise it is computed against the model's own predicted heads and tags.
    arc_loss : ``torch.FloatTensor``
        The loss contribution from the unlabeled arcs.
    tag_loss : ``torch.FloatTensor``
        The loss contribution from predicting the dependency tags for the arcs.
    heads : ``torch.FloatTensor``
        The predicted head indices for each word. Indexed over the sentinel-extended
        sequence: position 0 is the prepended ROOT (evaluation slices it off with ``[:, 1:]``).
    head_tags : ``torch.FloatTensor``
        The predicted head types for each arc, indexed like ``heads``.
    mask : ``torch.LongTensor``
        A mask denoting the padded elements in the batch, including a leading 1 for the
        prepended ROOT position.
    words : ``List``
        ``meta["words"]`` for each instance in ``metadata``.
    pos : ``List``
        ``meta["pos"]`` for each instance in ``metadata``.
    """
    embedded_text_input = self.text_field_embedder(words)
    if pos_tags is not None and self._pos_tag_embedding is not None:
        embedded_pos_tags = self._pos_tag_embedding(pos_tags)
        embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1)
    elif self._pos_tag_embedding is not None:
        raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

    mask = get_text_field_mask(words)
    embedded_text_input = self._input_dropout(embedded_text_input)
    encoded_text = self.encoder(embedded_text_input, mask)

    batch_size, _, encoding_dim = encoded_text.size()

    head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
    # Concatenate the head sentinel onto the sentence representation.
    # The mask, gold heads and gold tags are shifted by one position to match, so index 0
    # is the symbolic ROOT everywhere from here on.
    encoded_text = torch.cat([head_sentinel, encoded_text], 1)
    mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1)
    if head_indices is not None:
        head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1)
    if head_tags is not None:
        head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1)
    float_mask = mask.float()
    encoded_text = self._dropout(encoded_text)

    # shape (batch_size, sequence_length, arc_representation_dim)
    # (sequence_length here includes the prepended sentinel position)
    head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text))
    child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text))

    # shape (batch_size, sequence_length, tag_representation_dim)
    head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text))
    child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text))
    # shape (batch_size, sequence_length, sequence_length)
    attended_arcs = self.arc_attention(head_arc_representation,
                                       child_arc_representation)

    # Add a large negative value to every score whose row or column is a padded position,
    # so arcs touching padding are effectively never selected.
    minus_inf = -1e8
    minus_mask = (1 - float_mask) * minus_inf
    attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)

    if self.training or not self.use_mst_decoding_for_validation:
        predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation,
                                                                   child_tag_representation,
                                                                   attended_arcs,
                                                                   mask)
    else:
        predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation,
                                                                child_tag_representation,
                                                                attended_arcs,
                                                                mask)
    if head_indices is not None and head_tags is not None:

        arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                child_tag_representation=child_tag_representation,
                                                attended_arcs=attended_arcs,
                                                head_indices=head_indices,
                                                head_tags=head_tags,
                                                mask=mask)
        loss = arc_nll + tag_nll

        evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
        # We calculate attachment scores for the whole sentence
        # but excluding the symbolic ROOT token at the start,
        # which is why we start from the second element in the sequence.
        self._attachment_scores(predicted_heads[:, 1:],
                                predicted_head_tags[:, 1:],
                                head_indices[:, 1:],
                                head_tags[:, 1:],
                                evaluation_mask)
    else:
        # No gold annotation: the loss is computed against the model's own predictions.
        arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
                                                child_tag_representation=child_tag_representation,
                                                attended_arcs=attended_arcs,
                                                head_indices=predicted_heads.long(),
                                                head_tags=predicted_head_tags.long(),
                                                mask=mask)
        loss = arc_nll + tag_nll

    output_dict = {
            "heads": predicted_heads,
            "head_tags": predicted_head_tags,
            "arc_loss": arc_nll,
            "tag_loss": tag_nll,
            "loss": loss,
            "mask": mask,
            "words": [meta["words"] for meta in metadata],
            "pos": [meta["pos"] for meta in metadata]
            }

    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """
    Predict answer spans with a linear start/end head on top of the passage embeddings.

    A question may be split into several instances (one per passage chunk). Instances are
    grouped by ``metadata[i]['question_id']``. For each question, the chunk whose predicted
    span has the highest start+end logit supplies the answer string.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        Not used by this method; kept for interface compatibility.
    passage : Dict[str, torch.LongTensor]
        Indexed passage; ``passage['bert']`` must be 2-d ``(batch_size, num_tokens)``.
    span_start, span_end : torch.IntTensor, optional
        Gold span boundaries. Instances with a negative ``span_start`` are treated as
        having no gold answer and are excluded from the loss.
    metadata : List[Dict[str, Any]]
        One dict per instance. It must provide ``question_id``, ``original_passage``,
        ``token_offsets`` and ``answer_texts_list``.

    Returns
    -------
    Dict[str, Any]
        ``loss``: the span NLL over instances with a gold answer. When there are none,
        it is a constant zero tensor on the logits' device.
        ``best_span_str`` / ``qid``: one entry per instance. Every chunk of a question
        reports that question's best span string and id.
    """
    # Unpacking also checks that the BERT ids are 2-d.
    batch_size, _ = passage['bert'].size()

    # BERT for QA is a fully connected linear layer on top of BERT producing 2 vectors of
    # start and end spans.
    embedded_passage = self._text_field_embedder(passage)
    passage_length = embedded_passage.size(1)
    logits = self.qa_outputs(embedded_passage)
    start_logits, end_logits = logits.split(1, dim=-1)
    span_start_logits = start_logits.squeeze(-1)
    span_end_logits = end_logits.squeeze(-1)

    # Adding some masks with numerically stable values
    passage_mask = util.get_text_field_mask(passage).float().reshape(batch_size, passage_length)
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)

    output_dict: Dict[str, Any] = {}

    # Compute the loss. In evaluation some instances (chunks) may not contain the gold
    # answer (span_start < 0), so the loss is computed only on those that do.
    # flatnonzero always yields a 1-d index array, even when exactly one instance qualifies,
    # so the indexed logits stay (k, passage_length) as nll_loss expects.
    if span_start is not None:
        inds_with_gold_answer = np.flatnonzero(span_start.view(-1).cpu().numpy() >= 0)
        if inds_with_gold_answer.size > 0:
            loss = nll_loss(util.masked_log_softmax(span_start_logits[inds_with_gold_answer],
                                                    passage_mask[inds_with_gold_answer]),
                            span_start.view(-1)[inds_with_gold_answer], ignore_index=-1)
            loss += nll_loss(util.masked_log_softmax(span_end_logits[inds_with_gold_answer],
                                                     passage_mask[inds_with_gold_answer]),
                             span_end.view(-1)[inds_with_gold_answer], ignore_index=-1)
            output_dict["loss"] = loss

    # This is a hack for cases in which gold answer is not provided so we cannot compute loss.
    # new_zeros places the placeholder on the logits' device (CPU or GPU). The old
    # torch.cuda.FloatTensor constructor failed for a CPU model on a CUDA machine.
    if "loss" not in output_dict:
        output_dict["loss"] = span_end_logits.new_zeros(1)

    # getting best span prediction for each instance
    best_span = self._get_example_predications(span_start_logits, span_end_logits,
                                               self._max_span_length)
    best_span_cpu = best_span.detach().cpu().numpy()
    span_start_logits_numpy = span_start_logits.data.cpu().numpy()
    span_end_logits_numpy = span_end_logits.data.cpu().numpy()

    # We may have multiple instances per question. Group instance indices by question id
    # in first-appearance order. (np.unique sorts the ids, so split points derived from its
    # counts were only correct when the batch happened to be sorted by question id.)
    question_groups: Dict[Any, List[int]] = {}
    for instance_ind, instance_metadata in enumerate(metadata):
        question_groups.setdefault(instance_metadata['question_id'], []).append(instance_ind)

    # Note: this is a hack, because AllenNLP, when predicting, expects a value for each
    # instance. We may have more than 1 chunk per question, so every chunk of a question
    # reports the same (best) answer. Filling by instance index keeps outputs aligned
    # with the batch.
    best_span_strings: List[Any] = [None] * batch_size
    question_ids: List[Any] = [None] * batch_size

    # Iterating over every question (which may contain multiple instances, one per chunk)
    for question_inds in question_groups.values():
        inds = np.asarray(question_inds)
        # Score of each chunk's predicted span: start logit + end logit at its boundaries.
        chunk_scores = (span_start_logits_numpy[inds, best_span_cpu[inds][:, 0]]
                        + span_end_logits_numpy[inds, best_span_cpu[inds][:, 1]])
        best_chunk = int(np.argmax(chunk_scores))
        best_span_logit = chunk_scores[best_chunk]

        best_instance = question_inds[best_chunk]
        best_metadata = metadata[best_instance]
        passage_str = best_metadata['original_passage']
        offsets = best_metadata['token_offsets']
        predicted_span = best_span_cpu[best_instance]
        start_offset = offsets[predicted_span[0]][0]
        end_offset = offsets[predicted_span[1]][1]
        best_span_string = passage_str[start_offset:end_offset]

        for instance_ind in question_inds:
            best_span_strings[instance_ind] = best_span_string
            question_ids[instance_ind] = best_metadata['question_id']

        f1_score = 0.0
        em_score = 0.0
        gold_answer_texts = best_metadata['answer_texts_list']
        if gold_answer_texts:
            f1_score = squad_eval.metric_max_over_ground_truths(
                    squad_eval.f1_score, best_span_string, gold_answer_texts)
            em_score = squad_eval.metric_max_over_ground_truths(
                    squad_eval.exact_match_score, best_span_string, gold_answer_texts)
        self._official_f1(100 * f1_score)
        self._official_EM(100 * em_score)

        # TODO move to predict
        if self._predictions_file is not None:
            with open(self._predictions_file, 'a') as f:
                f.write(json.dumps({'question_id': best_metadata['question_id'],
                                    'best_span_logit': float(best_span_logit),
                                    'f1': 100 * f1_score,
                                    'EM': 100 * em_score,
                                    'best_span_string': best_span_string,
                                    'gold_answer_texts': gold_answer_texts,
                                    'qas_used_fraction': 1.0}) + '\n')

    output_dict['best_span_str'] = best_span_strings
    output_dict['qid'] = question_ids
    return output_dict