def test_logsumexp(self):
    # First a simple example where we add probabilities in log space.
    tensor = torch.FloatTensor([[.4, .1, .2]])
    log_tensor = tensor.log()
    log_summed = util.logsumexp(log_tensor, dim=-1, keepdim=False)
    assert_almost_equal(log_summed.exp().data.numpy(), [.7])
    log_summed = util.logsumexp(log_tensor, dim=-1, keepdim=True)
    assert_almost_equal(log_summed.exp().data.numpy(), [[.7]])

    # Then some more atypical examples, making sure this will work with how we handle
    # log masks.
    tensor = torch.FloatTensor([[float('-inf'), 20.0]])
    assert_almost_equal(util.logsumexp(tensor).data.numpy(), [20.0])
    tensor = torch.FloatTensor([[-200.0, 20.0]])
    assert_almost_equal(util.logsumexp(tensor).data.numpy(), [20.0])
    tensor = torch.FloatTensor([[20.0, 20.0], [-200.0, 200.0]])
    assert_almost_equal(util.logsumexp(tensor, dim=0).data.numpy(), [20.0, 200.0])
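# A minimal sketch of the max-shift trick that a `logsumexp` utility like the one tested
# above typically relies on. `stable_logsumexp` is a hypothetical helper written here only
# for illustration, checked against `torch.logsumexp`; it is not part of the code above.
import torch

def stable_logsumexp(tensor: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
    # Subtract the per-slice max before exponentiating so that large scores (e.g. 200.0 in
    # the test above) do not overflow and -inf entries simply contribute zero probability.
    max_score, _ = tensor.max(dim, keepdim=True)
    stable = (tensor - max_score).exp().sum(dim, keepdim=True).log() + max_score
    return stable if keepdim else stable.squeeze(dim)

x = torch.tensor([[20.0, 20.0], [-200.0, 200.0]])
assert torch.allclose(stable_logsumexp(x, dim=0), torch.logsumexp(x, dim=0))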
def _input_likelihood(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Computes the (batch_size,) denominator term for the log-likelihood, which is the
    sum of the likelihoods across all possible state sequences.
    """
    batch_size, sequence_length, num_tags = logits.size()

    # Transpose batch size and sequence dimensions.
    mask = mask.float().transpose(0, 1).contiguous()
    logits = logits.transpose(0, 1).contiguous()

    # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the
    # transitions to the initial states and the logits for the first timestep.
    if self.include_start_end_transitions:
        alpha = self.start_transitions.view(1, num_tags) + logits[0]
    else:
        alpha = logits[0]

    # For each i we compute logits for the transitions from timestep i-1 to timestep i.
    # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are
    # (instance, current_tag, next_tag).
    for i in range(1, sequence_length):
        # The emit scores are for time i ("next_tag"), so we broadcast along the current_tag axis.
        emit_scores = logits[i].view(batch_size, 1, num_tags)
        # Transition scores are (current_tag, next_tag), so we broadcast along the instance axis.
        transition_scores = self.transitions.view(1, num_tags, num_tags)
        # Alpha is for the current_tag, so we broadcast along the next_tag axis.
        broadcast_alpha = alpha.view(batch_size, num_tags, 1)

        # Add all the scores together and logsumexp over the current_tag axis.
        inner = broadcast_alpha + emit_scores + transition_scores

        # In valid positions (mask == 1) we want to take the logsumexp over the current_tag
        # dimension of ``inner``. Otherwise (mask == 0) we want to retain the previous alpha.
        alpha = (util.logsumexp(inner, 1) * mask[i].view(batch_size, 1) +
                 alpha * (1 - mask[i]).view(batch_size, 1))

    # Every sequence needs to end with a transition to the stop_tag.
    if self.include_start_end_transitions:
        stops = alpha + self.end_transitions.view(1, num_tags)
    else:
        stops = alpha

    # Finally we logsumexp along the num_tags dim; the result is (batch_size,).
    return util.logsumexp(stops)
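# A minimal, self-contained sketch (hypothetical toy numbers, plain `torch.logsumexp` instead
# of `util.logsumexp`, no start/end transitions, mask assumed all ones) showing that the alpha
# recursion above computes the same log partition function as brute-force enumeration over
# every possible tag sequence.
import itertools
import torch

torch.manual_seed(0)
num_tags, seq_len = 3, 4
logits = torch.randn(seq_len, num_tags)        # emission scores for one instance
transitions = torch.randn(num_tags, num_tags)  # transitions[i, j]: score of tag i -> tag j

# Forward recursion, mirroring the alpha update above.
alpha = logits[0]
for t in range(1, seq_len):
    inner = alpha.view(num_tags, 1) + transitions + logits[t].view(1, num_tags)
    alpha = torch.logsumexp(inner, dim=0)
log_partition = torch.logsumexp(alpha, dim=0)

# Brute force: logsumexp over the score of every possible tag sequence.
scores = []
for tags in itertools.product(range(num_tags), repeat=seq_len):
    score = logits[0, tags[0]] + sum(
        transitions[tags[t - 1], tags[t]] + logits[t, tags[t]] for t in range(1, seq_len))
    scores.append(score)
assert torch.allclose(log_partition, torch.logsumexp(torch.stack(scores), dim=0))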
def decode(self,
           initial_state: State,
           transition_function: TransitionFunction,
           supervision: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, torch.Tensor]:
    targets, target_mask = supervision
    beam_search = ConstrainedBeamSearch(self._beam_size, targets, target_mask)
    finished_states: Dict[int, List[State]] = beam_search.search(initial_state,
                                                                 transition_function)

    loss = 0
    for instance_states in finished_states.values():
        scores = [state.score[0].view(-1) for state in instance_states]
        loss += -util.logsumexp(torch.cat(scores))
    return {'loss': loss / len(finished_states)}
def decode(self,
           initial_state: DecoderState,
           decode_step: DecoderStep,
           supervision: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, torch.Tensor]:
    targets, target_mask = supervision
    # If self._beam_size is not set, we use a beam size that ensures we keep all of the
    # sequences.
    beam_size = self._beam_size or targets.size(1)
    beam_search = ConstrainedBeamSearch(beam_size, targets, target_mask)
    finished_states: Dict[int, List[DecoderState]] = beam_search.search(initial_state,
                                                                        decode_step)

    loss = 0
    for instance_states in finished_states.values():
        scores = [state.score[0].view(-1) for state in instance_states]
        loss += -util.logsumexp(torch.cat(scores))
    return {'loss': loss / len(finished_states)}
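# Both `decode` methods above implement maximum marginal likelihood over the finished states:
# the loss for one instance is the negative logsumexp of the scores of all action sequences
# consistent with the supervision. A minimal sketch with hypothetical scores, assuming each
# score is a log probability.
import torch

# Log probabilities of three target-consistent action sequences for one instance.
candidate_scores = torch.tensor([-2.3, -1.6, -4.0])

# Marginal probability of producing *any* of the consistent sequences, accumulated in log
# space; the per-instance loss is its negation.
instance_loss = -torch.logsumexp(candidate_scores, dim=0)
print(instance_loss.item())  # ~1.14, i.e. -log(exp(-2.3) + exp(-1.6) + exp(-4.0))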
def _merge_output_dicts(
        self,
        candidate_output_dicts: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # TODO (pradeep): Also take the worlds and actions and reconstruct logical forms. The issue
    # is that the actions are bottom-up, and DomainLanguage instances can only handle top-down
    # sequences.
    # TODO (pradeep): These losses are batch averaged. Is that a problem?
    # (max_num_inputs,)
    losses = torch.stack([output["loss"] for output in candidate_output_dicts])
    # Losses are negative log-likelihoods. The final loss we need is the negative log of the
    # sum of all likelihoods.
    output_dict = {"loss": -util.logsumexp(-losses)}
    if "class_log_probabilities" in candidate_output_dicts[0]:
        # This means we have a k-best list of sequences.
        # List of (batch_size, k) tensors.
        candidate_log_probabilities = [
            output["class_log_probabilities"] for output in candidate_output_dicts
        ]
        # (batch_size, max_num_inputs * k)
        log_probabilities = torch.cat(candidate_log_probabilities, dim=-1)
        # We now merge the predictions. One thing to worry about here is that the sequence
        # lengths may not necessarily be equal to ``max_decoding_steps``. Given a batch of
        # instances, the BeamSearch stops searching if all instances have reached end states.
        # So we need to do some padding here before concatenating output candidates from
        # different input sequences.
        padded_predictions: List[torch.Tensor] = []
        for candidate_output_dict in candidate_output_dicts:
            candidate_predictions = candidate_output_dict["predictions"]
            batch_size, num_sequences, sequence_length = candidate_predictions.size()
            if sequence_length < self._max_decoding_steps:
                padding_length = self._max_decoding_steps - sequence_length
                padding = candidate_predictions.new_full(
                    (batch_size, num_sequences, padding_length), self._end_index)
                candidate_predictions = torch.cat([candidate_predictions, padding], dim=2)
            padded_predictions.append(candidate_predictions)
        # (batch_size, max_num_inputs * k, max_decoding_steps)
        predictions = torch.cat(padded_predictions, dim=1)
        sorted_log_probabilities, indices = torch.sort(log_probabilities, descending=True)
        # (batch_size, max_num_inputs * k, max_decoding_steps)
        indices_for_selection = indices.unsqueeze(-1).repeat_interleave(
            self._max_decoding_steps, dim=2)
        sorted_predictions = predictions.gather(1, indices_for_selection)
        output_dict["class_log_probabilities"] = sorted_log_probabilities
        output_dict["predictions"] = sorted_predictions
        # Now we rank the action sequences according to the log probabilities of their best
        # decoded sequences.
        # (batch_size, max_num_inputs)
        best_log_probabilities = torch.stack(
            [log_probs[:, 0] for log_probs in candidate_log_probabilities]).transpose(0, 1)
        # (batch_size, max_num_inputs)
        _, ranked_input_indices = torch.sort(best_log_probabilities, 1, descending=True)
        int_ranked_input_indices = [[int(x) for x in instance_indices.data.cpu()]
                                    for instance_indices in ranked_input_indices]
        output_dict["sorted_logical_form_indices"] = int_ranked_input_indices
    return output_dict
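# A minimal sketch (hypothetical numbers) of the loss combination at the top of
# `_merge_output_dicts`: each candidate's `loss` is a negative log-likelihood, so negating,
# logsumexp-ing, and negating again yields the negative log of the summed likelihoods.
import torch

# Negative log-likelihoods from, say, three candidate logical forms.
candidate_losses = torch.tensor([1.2, 0.9, 2.5])

# -log(sum_i exp(-loss_i)) == -log(sum_i likelihood_i)
combined_loss = -torch.logsumexp(-candidate_losses, dim=0)

# Equivalent (but numerically less stable) direct computation.
assert torch.allclose(combined_loss, -torch.log(torch.exp(-candidate_losses).sum()))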
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            numbers_in_passage: Dict[str, torch.LongTensor],
            number_indices: torch.LongTensor,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            answer_as_add_sub_expressions: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ, unused-argument
    embedded_question = self._highway_layer(self._text_field_embedder(question))
    embedded_passage = self._highway_layer(self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None
    encoded_question = self._dropout(self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # We replace masked values with something really negative here, so they don't affect the
    # max below.
    masked_similarity = util.replace_masked_values(passage_question_similarity,
                                                   question_mask.unsqueeze(1),
                                                   -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    tiled_question_passage_vector = question_passage_vector.unsqueeze(1).expand(
        batch_size, passage_length, encoding_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat([encoded_passage,
                                      passage_question_vectors,
                                      encoded_passage * passage_question_vectors,
                                      encoded_passage * tiled_question_passage_vector],
                                     dim=-1)

    modeled_passage = self._dropout(self._modeling_layer(final_merged_passage,
                                                         passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout(torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

    # Shape: (batch_size, modeling_dim)
    span_start_representation = util.weighted_sum(modeled_passage, span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(1).expand(
        batch_size, passage_length, modeling_dim)

    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat([final_merged_passage,
                                         modeled_passage,
                                         tiled_start_representation,
                                         modeled_passage * tiled_start_representation],
                                        dim=-1)
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(self._span_end_encoder(span_end_representation,
                                                            passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

    # Shape: (batch_size, passage_length)
    passage_span_start_log_probs = util.masked_log_softmax(span_start_logits, passage_mask)
    passage_span_end_log_probs = util.masked_log_softmax(span_end_logits, passage_mask)

    passage_span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    passage_span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    # Shape: (batch_size, 2)
    best_passage_span = BidirectionalAttentionFlow.get_best_span(passage_span_start_logits,
                                                                 passage_span_end_logits)

    output_dict = {"passage_question_attention": passage_question_attention,
                   "passage_span_start_probs": passage_span_start_log_probs.exp(),
                   "passage_span_end_probs": passage_span_end_log_probs.exp()}

    # If the answer is given, compute the loss for training.
    if answer_as_passage_spans is not None:
        # Shape: (batch_size, # of answer spans)
        gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
        gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
        # Some spans are padded with index -1, so we clamp those paddings to 0 and then mask
        # after ``torch.gather()``.
        gold_passage_span_mask = (gold_passage_span_starts != -1).long()
        clamped_gold_passage_span_starts = util.replace_masked_values(gold_passage_span_starts,
                                                                      gold_passage_span_mask, 0)
        clamped_gold_passage_span_ends = util.replace_masked_values(gold_passage_span_ends,
                                                                    gold_passage_span_mask, 0)
        # Shape: (batch_size, # of answer spans)
        log_likelihood_for_passage_span_starts = torch.gather(
            passage_span_start_log_probs, 1, clamped_gold_passage_span_starts)
        log_likelihood_for_passage_span_ends = torch.gather(
            passage_span_end_log_probs, 1, clamped_gold_passage_span_ends)
        # Shape: (batch_size, # of answer spans)
        log_likelihood_for_passage_spans = (log_likelihood_for_passage_span_starts +
                                            log_likelihood_for_passage_span_ends)
        # For the padded spans, we set their log probabilities to a very small negative value.
        log_likelihood_for_passage_spans = util.replace_masked_values(
            log_likelihood_for_passage_spans, gold_passage_span_mask, -1e32)
        # Shape: (batch_size, )
        log_marginal_likelihood_for_passage_span = util.logsumexp(
            log_likelihood_for_passage_spans)
        output_dict["loss"] = -log_marginal_likelihood_for_passage_span.mean()

    # Compute the metrics and add the tokenized input to the output.
    if metadata is not None:
        output_dict["question_id"] = []
        output_dict["answer"] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            # We did not consider multi-mention answers here.
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['passage_token_offsets']
            predicted_span = tuple(best_passage_span[i].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_answer_str = passage_str[start_offset:end_offset]
            output_dict["question_id"].append(metadata[i]["question_id"])
            output_dict["answer"].append(best_answer_str)
            answer_annotations = metadata[i].get('answer_annotations', [])
            if answer_annotations:
                self._drop_metrics(best_answer_str, answer_annotations)
    return output_dict
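# A minimal sketch (hypothetical shapes and values) of the span marginalization in the loss
# above: gather start/end log-probabilities at every annotated gold span, push padded spans to
# a huge negative value so they vanish from the sum, then logsumexp over the span dimension.
import torch

batch_size, passage_length = 1, 6
start_log_probs = torch.log_softmax(torch.randn(batch_size, passage_length), dim=-1)
end_log_probs = torch.log_softmax(torch.randn(batch_size, passage_length), dim=-1)

# Two real gold spans plus one padding span marked with -1.
gold_starts = torch.tensor([[1, 4, -1]])
gold_ends = torch.tensor([[2, 5, -1]])
span_mask = (gold_starts != -1)

log_likelihood = (start_log_probs.gather(1, gold_starts.clamp(min=0)) +
                  end_log_probs.gather(1, gold_ends.clamp(min=0)))
log_likelihood = log_likelihood.masked_fill(~span_mask, -1e32)  # padded spans contribute ~0 probability

# Shape: (batch_size,); the training loss would be the negation of the mean over the batch.
log_marginal_likelihood = torch.logsumexp(log_likelihood, dim=-1)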
def forward(self,  # type: ignore
            text: TextFieldTensors,
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            ) -> Dict[str, torch.Tensor]:
    """
    # Parameters

    text : `TextFieldTensors`, required.
        The output of a `TextField` representing the text of the document.
    spans : `torch.IntTensor`, required.
        A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
        indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of indices
        into the text of the document.
    span_labels : `torch.IntTensor`, optional (default = None).
        A tensor of shape (batch_size, num_spans), representing the cluster ids of each span,
        or -1 for those which do not appear in any clusters.
    metadata : `List[Dict[str, Any]]`, optional (default = None).
        A metadata dictionary for each instance in the batch. We use the "original_text" and
        "clusters" keys from this dictionary, which respectively have the original text and
        the annotated gold coreference clusters for that instance.

    # Returns

    An output dictionary consisting of:

    top_spans : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing the start and end
        word indices of the top spans that survived the pruning stage.
    antecedent_indices : `torch.IntTensor`
        A tensor of shape `(num_spans_to_keep, max_antecedents)` representing for each top span
        the index (with respect to top_spans) of the possible antecedents the model considered.
    predicted_antecedents : `torch.IntTensor`
        A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span,
        the index (with respect to antecedent_indices) of the most likely antecedent. -1 means
        there was no predicted link.
    loss : `torch.FloatTensor`, optional
        A scalar loss to be optimised.
    """
    # Shape: (batch_size, document_length, embedding_size)
    text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

    document_length = text_embeddings.size(1)
    num_spans = spans.size(1)

    # Shape: (batch_size, document_length)
    text_mask = util.get_text_field_mask(text).float()

    # Shape: (batch_size, num_spans)
    span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
    # SpanFields return -1 when they are used as padding. As we do some comparisons based on
    # span widths when we attend over the span representations that we generate from these
    # indices, we need them to be <= 0. This is only relevant in edge cases where the number
    # of spans we consider after the pruning stage is >= the total number of spans, because in
    # this case, it is possible we might consider a masked span.
    # Shape: (batch_size, num_spans, 2)
    spans = F.relu(spans.float()).long()

    # Shape: (batch_size, document_length, encoding_dim)
    contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
    # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
    endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
    # Shape: (batch_size, num_spans, embedding_size)
    attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

    # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
    span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

    # Prune based on mention scores.
    num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

    # Shape: (batch_size, num_spans)
    span_mention_scores = self._mention_scorer(
        self._mention_feedforward(span_embeddings)).squeeze(-1)
    # Shape: (batch_size, num_spans_to_keep) * 3
    top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
        span_mention_scores, span_mask, num_spans_to_keep)
    top_span_mention_scores = top_span_mention_scores.unsqueeze(-1)
    top_span_mask = top_span_mask.unsqueeze(-1)

    # Shape: (batch_size * num_spans_to_keep)
    # torch.index_select only accepts 1D indices, but here we need to select spans for each
    # element in the batch. This reformats the indices to take into account their index into
    # the batch. We precompute this here to make the multiple calls to
    # util.batched_index_select below more efficient.
    flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

    # Compute final predictions for which spans to consider as mentions.
    # Shape: (batch_size, num_spans_to_keep, 2)
    top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
    # Shape: (batch_size, num_spans_to_keep, embedding_size)
    top_span_embeddings = util.batched_index_select(span_embeddings, top_span_indices,
                                                    flat_top_span_indices)

    # Compute indices for antecedent spans to consider.
    max_antecedents = min(self._max_antecedents, num_spans_to_keep)

    # Now that we have our variables in terms of num_spans_to_keep, we need to compare span
    # pairs to decide each span's antecedent. Each span can only have prior spans as
    # antecedents, and we only consider up to max_antecedents prior spans. So the first thing
    # we do is construct a matrix mapping a span's index to the indices of its allowed
    # antecedents. Note that this is independent of the batch dimension - it's just a function
    # of the span's position in top_spans. The spans are in document order, so we can just use
    # the relative index of the spans to know which other spans are allowed antecedents.

    # Once we have this matrix, we reformat our variables again to get embeddings for all
    # valid antecedents for each span. This gives us variables with shapes like
    # (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which we can use to
    # make coreference decisions between valid span pairs.

    # Shapes:
    # (num_spans_to_keep, max_antecedents),
    # (1, max_antecedents),
    # (1, num_spans_to_keep, max_antecedents)
    (
        valid_antecedent_indices,
        valid_antecedent_offsets,
        valid_antecedent_log_mask,
    ) = self._generate_valid_antecedents(  # noqa
        num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))

    # Select tensors relating to the antecedent spans.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                  valid_antecedent_indices)
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    candidate_antecedent_mention_scores = util.flattened_index_select(
        top_span_mention_scores, valid_antecedent_indices).squeeze(-1)

    # Compute antecedent scores.
    # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
    span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                              candidate_antecedent_embeddings,
                                                              valid_antecedent_offsets)
    # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
    coreference_scores = self._compute_coreference_scores(
        span_pair_embeddings,
        top_span_mention_scores,
        candidate_antecedent_mention_scores,
        valid_antecedent_log_mask,
    )

    # We now have, for each span which survived the pruning stage, a predicted antecedent.
    # This implies a clustering if we group mentions which refer to each other in a chain.
    # Shape: (batch_size, num_spans_to_keep)
    _, predicted_antecedents = coreference_scores.max(2)
    # Subtract one here because index 0 is the "no antecedent" class, so this makes the
    # indices line up with actual spans if the prediction is greater than -1.
    predicted_antecedents -= 1

    output_dict = {
        "top_spans": top_spans,
        "antecedent_indices": valid_antecedent_indices,
        "predicted_antecedents": predicted_antecedents,
    }
    if span_labels is not None:
        # Find the gold labels for the spans which we kept.
        pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                       top_span_indices,
                                                       flat_top_span_indices)
        antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                        valid_antecedent_indices).squeeze(-1)
        antecedent_labels += valid_antecedent_log_mask.long()

        # Compute labels.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
        gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                      antecedent_labels)
        # Now, compute the loss using the negative marginal log-likelihood. This is the
        # negative log of the sum of the probabilities of all antecedent predictions that
        # would be consistent with the data, in the sense that we marginalise, for a given
        # span, over all antecedents which are in the same gold cluster as the span we are
        # currently considering. Each span i predicts a single antecedent j, but there might
        # be several prior mentions k in the same coreference cluster that would be valid
        # antecedents. Our loss is the negative log of the total probability assigned to all
        # valid antecedents. This is a valid objective for clustering as we don't mind which
        # antecedent is predicted, so long as they are in the same coreference cluster.
        coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
        correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
        negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

        self._mention_recall(top_spans, metadata)
        self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents,
                                 metadata)

        output_dict["loss"] = negative_marginal_log_likelihood

    if metadata is not None:
        output_dict["document"] = [x["original_text"] for x in metadata]
    return output_dict
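# A minimal sketch (hypothetical scores, one document, two kept spans) of the coreference
# marginalization above: antecedent log-probabilities survive only where the gold label is 1
# (adding `gold_labels.log()` sends the zeros to -inf), and logsumexp over the antecedent
# dimension gives the log of the total probability mass placed on valid antecedents.
import torch

# Shape: (batch_size=1, num_spans_to_keep=2, 1 + max_antecedents=3); column 0 is the dummy
# "no antecedent" choice.
coreference_scores = torch.tensor([[[0.5, -0.2, 1.0], [2.0, 0.3, -1.0]]])
coreference_log_probs = torch.log_softmax(coreference_scores, dim=-1)

# 1 marks antecedent choices consistent with the gold clusters.
gold_antecedent_labels = torch.tensor([[[0.0, 1.0, 1.0], [1.0, 0.0, 0.0]]])

correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
# Shape: (batch_size, num_spans_to_keep), summed into a scalar loss.
negative_marginal_log_likelihood = -torch.logsumexp(correct_antecedent_log_probs, dim=-1).sum()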
def forward( self, # type: ignore question_passage: Dict[str, torch.LongTensor], number_indices: torch.LongTensor, mask_indices: torch.LongTensor, #num_spans: torch.LongTensor = None, impossible_answer: torch.LongTensor = None, answer_as_passage_spans: torch.LongTensor = None, answer_as_question_spans: torch.LongTensor = None, answer_as_expressions: torch.LongTensor = None, answer_as_expressions_extra: torch.LongTensor = None, answer_as_counts: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ # Shape: (batch_size, seqlen) question_passage_tokens = question_passage["tokens"] # Shape: (batch_size, seqlen) pad_mask = question_passage["mask"] # Shape: (batch_size, seqlen) seqlen_ids = question_passage["tokens-type-ids"] max_seqlen = question_passage_tokens.shape[-1] batch_size = question_passage_tokens.shape[0] # Shape: (batch_size, 3) mask = mask_indices.squeeze(-1) # Shape: (batch_size, seqlen) cls_sep_mask = \ torch.ones(pad_mask.shape, device=pad_mask.device).long().scatter(1, mask, torch.zeros(mask.shape, device=mask.device).long()) # Shape: (batch_size, seqlen) passage_mask = seqlen_ids * pad_mask * cls_sep_mask # Shape: (batch_size, seqlen) question_mask = (1 - seqlen_ids) * pad_mask * cls_sep_mask # Shape: (batch_size, seqlen, bert_dim) bert_out, _ = self.BERT(question_passage_tokens, seqlen_ids, pad_mask, output_all_encoded_layers=False) # Shape: (batch_size, qlen, bert_dim) question_end = max(mask[:, 1]) question_out = bert_out[:, :question_end] # Shape: (batch_size, qlen) question_mask = question_mask[:, :question_end] # Shape: (batch_size, out) question_vector = self.summary_vector(question_out, question_mask, "question") passage_out = bert_out del bert_out # Shape: (batch_size, bert_dim) passage_vector = self.summary_vector(passage_out, passage_mask) if "arithmetic" in self.answering_abilities and self.arithmetic == "advanced": arithmetic_summary = self.summary_vector(passage_out, pad_mask, "arithmetic") # arithmetic_summary = self.summary_vector(question_out, question_mask, "arithmetic") # Shape: (batch_size, # of numbers in the passage) if number_indices.dim() == 3: number_indices = number_indices[:, :, 0].long() number_mask = (number_indices != -1).long() clamped_number_indices = util.replace_masked_values( number_indices, number_mask, 0) encoded_numbers = torch.gather( passage_out, 1, clamped_number_indices.unsqueeze(-1).expand( -1, -1, passage_out.size(-1))) op_mask = torch.ones((batch_size, self.num_ops + 1), device=number_mask.device).long() options_mask = torch.cat([op_mask, number_mask], -1) ops = self.op_embeddings( torch.arange(self.num_ops + 1, device=number_mask.device).expand(batch_size, -1)) options = torch.cat([self.Wo(ops), self.Wc(encoded_numbers)], 1) if len(self.answering_abilities) > 1: # Shape: (batch_size, number_of_abilities) answer_ability_logits = \ self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1)) #print(impossible_answer) #impossible_answer[impossible_answer==0] = -1 #print(impossible_answer) #print(answer_ability_logits) #answer_ability_logits[:, -1] = answer_ability_logits[:, -1] * impossible_answer.float() #print(answer_ability_logits) answer_ability_log_probs = torch.nn.functional.log_softmax( answer_ability_logits, -1) #answer_ability_log_probs_filtered = answer_ability_log_probs.clone() #print(answer_ability_log_probs_filtered.size()) #print(answer_ability_log_probs_filtered[:-1]) #answer_ability_log_probs_filtered[:,-1] = 
answer_ability_log_probs_filtered[:,-1] * impossible_answer.float() #print(answer_ability_log_probs_filtered) #print(">>>>>>>>>>>>>>>>>>.") best_answer_ability = torch.argmax(answer_ability_log_probs, 1) if "counting" in self.answering_abilities: count_number_log_probs, best_count_number = self._count_module( passage_vector) if "passage_span_extraction" in self.answering_abilities: passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span = \ self._passage_span_module(passage_out, passage_mask) if "question_span_extraction" in self.answering_abilities: question_span_start_log_probs, question_span_end_log_probs, best_question_span = \ self._question_span_module(passage_vector, question_out, question_mask) if "arithmetic" in self.answering_abilities: if self.arithmetic == "base": number_mask = (number_indices[:, :, 0].long() != -1).long() number_sign_log_probs, best_signs_for_numbers, number_mask = \ self._base_arithmetic_module(passage_vector, passage_out, number_indices, number_mask) else: arithmetic_logits, best_expression = \ self._adv_arithmetic_module(arithmetic_summary, self.max_explen, options, options_mask, \ passage_out, pad_mask) shapes = arithmetic_logits.shape if (1 - (arithmetic_logits != arithmetic_logits)).sum() != ( shapes[0] * shapes[1] * shapes[2]): print("bad logits") arithmetic_logits = torch.rand( shapes, device=arithmetic_logits.device, requires_grad=True) output_dict = {} if self.training: # If answer is given, compute the loss. if answer_as_passage_spans is not None or answer_as_question_spans is not None \ or answer_as_expressions is not None or answer_as_counts is not None: log_marginal_likelihood_list = [] for answering_ability in self.answering_abilities: if answering_ability == "passage_span_extraction": log_marginal_likelihood_for_passage_span = \ self._passage_span_log_likelihood(answer_as_passage_spans, passage_span_start_log_probs, passage_span_end_log_probs) log_marginal_likelihood_list.append( log_marginal_likelihood_for_passage_span) elif answering_ability == "question_span_extraction": log_marginal_likelihood_for_question_span = \ self._question_span_log_likelihood(answer_as_question_spans, question_span_start_log_probs, question_span_end_log_probs) log_marginal_likelihood_list.append( log_marginal_likelihood_for_question_span) elif answering_ability == "arithmetic": if self.arithmetic == "base": log_marginal_likelihood_for_arithmetic = \ self._base_arithmetic_log_likelihood(answer_as_expressions, number_sign_log_probs, number_mask, answer_as_expressions_extra, metadata) else: max_explen = answer_as_expressions.shape[-1] possible_exps = answer_as_expressions.shape[1] limit = min(possible_exps, 1000) log_marginal_likelihood_for_arithmetic = \ self._adv_arithmetic_log_likelihood(arithmetic_logits[:,:max_explen,:], answer_as_expressions[:,:limit,:].long()) log_marginal_likelihood_list.append( log_marginal_likelihood_for_arithmetic) elif answering_ability == "counting": log_marginal_likelihood_for_count = \ self._count_log_likelihood(answer_as_counts, count_number_log_probs) log_marginal_likelihood_list.append( log_marginal_likelihood_for_count) elif answering_ability == "answer_exists": impossible_answer[impossible_answer == 0] = -1e7 impossible_answer[impossible_answer == 1] = 0 log_marginal_likelihood_list.append( impossible_answer.type_as(passage_out)) else: raise ValueError( f"Unsupported answering ability: {answering_ability}" ) if len(self.answering_abilities) > 1: # Add the ability probabilities if there are more than one abiliti 
all_log_marginal_likelihoods = torch.stack( log_marginal_likelihood_list, dim=-1) all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs marginal_log_likelihood = util.logsumexp( all_log_marginal_likelihoods) else: marginal_log_likelihood = log_marginal_likelihood_list[0] output_dict["loss"] = -marginal_log_likelihood.mean() else: with torch.no_grad(): # Compute the metrics and add the tokenized input to the output. if metadata is not None: output_dict["question_id"] = [] output_dict["answer"] = [] for i in range(batch_size): if len(self.answering_abilities) > 1: predicted_ability_str = self.answering_abilities[ best_answer_ability[i]] else: predicted_ability_str = self.answering_abilities[0] answer_json: Dict[str, Any] = {} # We did not consider multi-mention answers here if predicted_ability_str == "passage_span_extraction": answer_json["answer_type"] = "passage_span" answer_json["value"], answer_json["spans"] = \ self._span_prediction(question_passage_tokens[i], best_passage_span[i]) elif predicted_ability_str == "question_span_extraction": answer_json["answer_type"] = "question_span" answer_json["value"], answer_json["spans"] = \ self._span_prediction(question_passage_tokens[i], best_question_span[i]) elif predicted_ability_str == "arithmetic": # plus_minus combination answer answer_json["answer_type"] = "arithmetic" original_numbers = metadata[i]['original_numbers'] if self.arithmetic == "base": answer_json["value"], answer_json["numbers"] = \ self._base_arithmetic_prediction(original_numbers, number_indices[i], best_signs_for_numbers[i]) else: answer_json["value"], answer_json["expression"] = \ self._adv_arithmetic_prediction(original_numbers, best_expression[i]) elif predicted_ability_str == "counting": answer_json["answer_type"] = "count" answer_json["value"], answer_json["count"] = \ self._count_prediction(best_count_number[i]) elif predicted_ability_str == "answer_exists": answer_json["answer_type"] = "passage_span" answer_json["value"] = "impossible" output_dict["question_id"].append( metadata[i]["question_id"]) output_dict["answer"].append(answer_json) output_dict["prediction"] = answer_json["value"] return output_dict
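# A minimal sketch (hypothetical numbers, three answering abilities) of the combination right
# above: each ability's log marginal likelihood is added to the log probability of selecting
# that ability, and a logsumexp over abilities marginalizes out the choice of head.
import torch

# Shape: (batch_size=2, num_abilities=3)
answer_ability_log_probs = torch.log_softmax(torch.randn(2, 3), dim=-1)
log_marginal_likelihood_list = [-torch.rand(2) for _ in range(3)]  # one per ability

all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs

# log sum_a P(ability=a) * P(answer | ability=a); the loss is the negated batch mean.
marginal_log_likelihood = torch.logsumexp(all_log_marginal_likelihoods, dim=-1)
loss = -marginal_log_likelihood.mean()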
def _gather_final_log_probs(self,
                            generation_log_probs: torch.Tensor,
                            copy_log_probs: torch.Tensor,
                            state: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Combine copy probabilities with generation probabilities for matching tokens.

    Parameters
    ----------
    generation_log_probs : ``torch.Tensor``
        Shape: `(group_size, target_vocab_size)`
    copy_log_probs : ``torch.Tensor``
        Shape: `(group_size, trimmed_source_length)`
    state : ``Dict[str, torch.Tensor]``

    Returns
    -------
    torch.Tensor
        Shape: `(group_size, target_vocab_size + trimmed_source_length)`.
    """
    _, trimmed_source_length = state["source_to_target"].size()
    source_token_ids = state["source_token_ids"]

    # shape: [(batch_size, *)]
    modified_log_probs_list: List[torch.Tensor] = [generation_log_probs]
    for i in range(trimmed_source_length):
        # shape: (group_size,)
        copy_log_probs_slice = copy_log_probs[:, i]
        # `source_to_target` is a matrix of shape (group_size, trimmed_source_length)
        # where element (i, j) is the vocab index of the target token that matches the jth
        # source token in the ith group, if there is one, or the index of the OOV symbol
        # otherwise. We'll use this to add copy scores to corresponding generation scores.
        # shape: (group_size,)
        source_to_target_slice = state["source_to_target"][:, i]
        # The OOV index in the source_to_target_slice indicates that the source token is not
        # in the target vocab, so we don't want to add that copy score to the OOV token.
        copy_log_probs_to_add_mask = (source_to_target_slice != self._oov_index).float()
        copy_log_probs_to_add = copy_log_probs_slice + (copy_log_probs_to_add_mask +
                                                        1e-45).log()
        # shape: (batch_size, 1)
        copy_log_probs_to_add = copy_log_probs_to_add.unsqueeze(-1)
        # shape: (batch_size, 1)
        selected_generation_log_probs = generation_log_probs.gather(
            1, source_to_target_slice.unsqueeze(-1))
        combined_scores = util.logsumexp(
            torch.cat((selected_generation_log_probs, copy_log_probs_to_add), dim=1))
        generation_log_probs.scatter_(-1, source_to_target_slice.unsqueeze(-1),
                                      combined_scores.unsqueeze(-1))
        # We have to combine copy scores for duplicate source tokens so that we can find the
        # overall most likely source token. So, if this is the first occurrence of this
        # particular source token, we add the log_probs from all other occurrences, otherwise
        # we zero it out since it was already accounted for.
        if i < (trimmed_source_length - 1):
            # Sum copy scores from future occurrences of the source token.
            # shape: (group_size, trimmed_source_length - i)
            source_future_occurrences = (source_token_ids[:, (i + 1):] ==
                                         source_token_ids[:, i].unsqueeze(-1)).float()
            # shape: (group_size, trimmed_source_length - i)
            future_copy_log_probs = copy_log_probs[:, (i + 1):] + (source_future_occurrences +
                                                                   1e-45).log()
            # shape: (group_size, 1 + trimmed_source_length - i)
            combined = torch.cat((copy_log_probs_slice.unsqueeze(-1), future_copy_log_probs),
                                 dim=-1)
            # shape: (group_size,)
            copy_log_probs_slice = util.logsumexp(combined)
        if i > 0:
            # Remove copy log_probs that we have already accounted for.
            # shape: (group_size, i)
            source_previous_occurrences = (source_token_ids[:, 0:i] ==
                                           source_token_ids[:, i].unsqueeze(-1))
            # shape: (group_size,)
            duplicate_mask = (source_previous_occurrences.sum(dim=-1) == 0).float()
            copy_log_probs_slice = copy_log_probs_slice + (duplicate_mask + 1e-45).log()

        # Finally, we zero-out copy scores that we added to the generation scores above so
        # that we don't double-count them.
        # shape: (group_size,)
        left_over_copy_log_probs = copy_log_probs_slice + (1.0 - copy_log_probs_to_add_mask +
                                                           1e-45).log()
        modified_log_probs_list.append(left_over_copy_log_probs.unsqueeze(-1))

    # shape: (group_size, target_vocab_size + trimmed_source_length)
    modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)

    return modified_log_probs
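# A minimal sketch (hypothetical numbers) of the core merge step inside the loop above: when a
# source token also exists in the target vocabulary, its copy log-probability and its
# generation log-probability are combined with a logsumexp, i.e. the two probabilities are
# added in log space.
import torch

generation_log_prob = torch.log(torch.tensor([0.10]))  # P(token) from the generation head
copy_log_prob = torch.log(torch.tensor([0.25]))        # P(token) from the copy head

combined = torch.logsumexp(torch.stack([generation_log_prob, copy_log_prob]), dim=0)
print(combined.exp().item())  # ~0.35, the summed probability mass for the token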
def decode(self,
           initial_state: State,
           transition_function: TransitionFunction,
           supervision: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, torch.Tensor]:
    targets, target_mask = supervision
    # batch_size x inter_size x action_size x index_size (unused)
    assert len(targets.size()) == 4

    # -> batch_size * inter_size x action_size
    batch_size, inter_size, _, _ = targets.size()
    # TODO: we must keep the shape because of the loss_mask.
    targets = targets.reshape(batch_size * inter_size, -1)
    target_mask = target_mask.reshape(batch_size * inter_size, -1)
    inter_mask = target_mask.sum(dim=1).ne(0)

    # Unsqueeze the beam search dimension.
    targets = targets.unsqueeze(dim=1)
    target_mask = target_mask.unsqueeze(dim=1)

    beam_search = ConstrainedBeamSearch(self._beam_size, targets, target_mask)
    finished_states: Dict[int, List[State]] = beam_search.search(initial_state,
                                                                 transition_function)

    inter_count = inter_mask.view(batch_size, inter_size).sum(dim=0).float()
    if 0 not in inter_count:
        inter_ratio = 1.0 / inter_count
    else:
        inter_ratio = torch.ones_like(inter_count)

    loss = 0
    for iter_ind, instance_states in finished_states.items():
        scores = [state.score[0].view(-1) for state in instance_states]
        lens = [len(state.action_history[0]) for state in instance_states]
        if not len(lens):
            continue
        # The i-th round of an interaction, starting from 0.
        cur_inter = iter_ind % inter_size
        if self._re_weight:
            loss_coefficient = inter_ratio[cur_inter]
        else:
            loss_coefficient = 1.0

        if self._loss_mask <= cur_inter:
            continue

        cur_loss = -util.logsumexp(torch.cat(scores)) / statistics.mean(lens)
        loss += loss_coefficient * cur_loss

    if self._re_weight:
        return {'loss': loss / len(inter_count)}
    elif self._loss_mask < inter_size:
        valid_counts = inter_count[:self._loss_mask].sum()
        return {'loss': loss / valid_counts}
    else:
        return {'loss': loss / len(finished_states)}
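# A minimal sketch (hypothetical scores and action-history lengths) of the per-round loss
# above, which differs from the plain MML loss earlier by normalizing with the mean
# action-sequence length of the finished states for that round.
import statistics
import torch

# Log scores and action-sequence lengths of the finished states for one interaction round.
scores = [torch.tensor([-3.1]), torch.tensor([-2.4]), torch.tensor([-5.0])]
lens = [7, 9, 8]

# Marginalize over the consistent sequences, then normalize by the average sequence length so
# that rounds with long programs do not dominate the total loss.
cur_loss = -torch.logsumexp(torch.cat(scores), dim=0) / statistics.mean(lens)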
def forward( self, # type: ignore question_passage_tokens: torch.LongTensor, question_passage_token_type_ids: torch.LongTensor, question_passage_special_tokens_mask: torch.LongTensor, question_passage_pad_mask: torch.LongTensor, first_wordpiece_mask: torch.LongTensor, metadata: List[Dict[str, Any]], wordpiece_indices: torch.LongTensor = None, number_indices: torch.LongTensor = None, answer_as_expressions: torch.LongTensor = None, answer_as_expressions_extra: torch.LongTensor = None, answer_as_counts: torch.LongTensor = None, answer_as_text_to_disjoint_bios: torch.LongTensor = None, answer_as_list_of_bios: torch.LongTensor = None, answer_as_passage_spans: torch.LongTensor = None, answer_as_question_spans: torch.LongTensor = None, span_bio_labels: torch.LongTensor = None, is_bio_mask: torch.LongTensor = None) -> Dict[str, Any]: # pylint: disable=arguments-differ question_passage_special_tokens_mask = ( 1 - question_passage_special_tokens_mask) batch_size = question_passage_tokens.shape[0] head_count = len(self._heads) # TODO: (not important) Create a new field that is converted to Dict[str, torch.LongTensor] gold_answer_representations = { 'answer_as_expressions': answer_as_expressions, 'answer_as_expressions_extra': answer_as_expressions_extra, 'answer_as_passage_spans': answer_as_passage_spans, 'answer_as_question_spans': answer_as_question_spans, 'answer_as_counts': answer_as_counts, 'answer_as_text_to_disjoint_bios': answer_as_text_to_disjoint_bios, 'answer_as_list_of_bios': answer_as_list_of_bios, 'span_bio_labels': span_bio_labels } has_answer = False for answer_representation in gold_answer_representations.values(): if answer_representation is not None: has_answer = True break # Shape: (batch_size, seqlen) passage_mask = question_passage_token_type_ids * question_passage_pad_mask * question_passage_special_tokens_mask # Shape: (batch_size, seqlen) question_mask = \ (1 - question_passage_token_type_ids) * question_passage_pad_mask * question_passage_special_tokens_mask question_and_passage_mask = question_mask | passage_mask # Use pre-trained model to compute the representations of the input data # Shape: (batch_size, seqlen, bert_dim) token_type_ids = question_passage_token_type_ids if not self._pretrained_model.startswith( 'roberta-') else None token_representations = self._transformers_model( question_passage_tokens, token_type_ids=token_type_ids, attention_mask=question_passage_pad_mask)[0] # if desired, compute the passage summary vector if self._passage_summary_vector_module is not None: # Shape: (batch_size, bert_dim) passage_summary_vector = self.summary_vector( token_representations, passage_mask, 'passage') else: passage_summary_vector = None # if desired, compute the question summary vector if self._question_summary_vector_module is not None: # Shape: (batch_size, bert_dim) question_summary_vector = self.summary_vector( token_representations, question_mask, 'question') else: question_summary_vector = None if head_count > 1: # use the head_predictor with the summary vectors # Shape: (batch_size, number_of_abilities) answer_ability_logits = \ self._head_predictor(torch.cat([passage_summary_vector, question_summary_vector], -1)) answer_ability_log_probs = torch.nn.functional.log_softmax( answer_ability_logits, -1) top_answer_abilities = torch.argsort(answer_ability_log_probs, descending=True) else: top_answer_abilities = torch.zeros(batch_size, 1, dtype=torch.int) kwargs = { 'token_representations': token_representations, 'passage_summary_vector': passage_summary_vector, 
'question_summary_vector': question_summary_vector, 'gold_answer_representations': gold_answer_representations, 'question_and_passage_mask': question_and_passage_mask, 'first_wordpiece_mask': first_wordpiece_mask, 'is_bio_mask': is_bio_mask, 'wordpiece_indices': wordpiece_indices, 'number_indices': number_indices, 'passage_mask': passage_mask, 'question_mask': question_mask, 'question_passage_special_tokens_mask': question_passage_special_tokens_mask } head_outputs = {} for head_name, head in self._heads.items(): head_outputs[head_name] = head(**kwargs) output_dict = {} # If answer is given, compute the loss. if has_answer: log_marginal_likelihood_list = [] for head_name, head in self._heads.items(): # The marginal log likelihood is calculated for each head separately log_marginal_likelihood = head.gold_log_marginal_likelihood( **kwargs, **head_outputs[head_name]) """ log probability for each head to be selected is added (which is like AND/multiplication, but in logspace). """ log_marginal_likelihood_list.append(log_marginal_likelihood) if head_count > 1: # Add the ability probabilities if there is more than one ability """ all the likelihoods are combined by summation (this is like OR, as we want to maximize the probability that any of the heads is right, and not that all of them are right). """ all_log_marginal_likelihoods = torch.stack( log_marginal_likelihood_list, dim=-1) all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs marginal_log_likelihood = util.logsumexp( all_log_marginal_likelihoods) else: marginal_log_likelihood = log_marginal_likelihood_list[0] # Finally, we compute the mean loss across the batch elements. # put the loss in the output dictionary output_dict['loss'] = -1 * marginal_log_likelihood.mean() with torch.no_grad(): # Compute the metrics and add fields to the output if metadata is not None and self._training_evaluation: if not self.training: output_dict['passage_id'] = [] output_dict['query_id'] = [] output_dict['answer'] = [] output_dict['predicted_ability'] = [] output_dict['maximizing_ground_truth'] = [] output_dict['em'] = [] output_dict['f1'] = [] output_dict['max_passage_length'] = [] if self._output_all_answers: output_dict['all_answers'] = [] i = 0 no_fallback = False ordered_lookup_index = 0 while i < batch_size: predicting_head_index = top_answer_abilities[i][ ordered_lookup_index].item() predicting_head_name = self.heads_indices( )[predicting_head_index] predicting_head = self._heads[predicting_head_name] # construct the arguments to be used for a batch instance prediction instance_kwargs = { 'q_text': metadata[i]['original_question'], 'p_text': metadata[i]['original_passage'], 'qp_tokens': metadata[i]['question_passage_tokens'], 'question_passage_wordpieces': metadata[i]['question_passage_wordpieces'], 'original_numbers': metadata[i]['original_numbers'] if 'original_numbers' in metadata[i] else None, } # keys that cannot be passed because # they are not batch-based in their first level or None unpassable_keys = ['gold_answer_representations'] for key, value in instance_kwargs.items(): if value is None: unpassable_keys.append(key) for key in unpassable_keys: if key in instance_kwargs: del instance_kwargs[key] for key, value in kwargs.items(): if value is not None and key not in unpassable_keys: instance_kwargs[key] = value[i] for key, value in head_outputs[predicting_head_name].items( ): if key not in unpassable_keys: instance_kwargs[key] = value[i] # get prediction for an instance in the batch answer_json = 
predicting_head.decode_answer( **instance_kwargs) if len(answer_json['value']) != 0 or no_fallback: # for the next in the batch ordered_lookup_index = 0 no_fallback = False else: if not self.training: logger.info( "Answer was empty for head: %s, query_id: %s", predicting_head_name, metadata[i]['question_id']) ordered_lookup_index += 1 if ordered_lookup_index == head_count: no_fallback = True ordered_lookup_index = 0 continue maximizing_ground_truth = None em, f1 = None, None answer_annotations = metadata[i].get( 'answer_annotations', []) if answer_annotations: (em, f1), maximizing_ground_truth = self._metrics.call( answer_json['value'], answer_annotations, predicting_head_name) if not self.training: output_dict['passage_id'].append( metadata[i]['passage_id']) output_dict['query_id'].append( metadata[i]['question_id']) output_dict['answer'].append(answer_json) output_dict['predicted_ability'].append( predicting_head_name) output_dict['maximizing_ground_truth'].append( maximizing_ground_truth) output_dict['em'].append(em) output_dict['f1'].append(f1) output_dict['max_passage_length'].append( metadata[i]['max_passage_length']) if self._output_all_answers: answers_dict = {} output_dict['all_answers'].append(answers_dict) for j in range(len(self._heads)): predicting_head_index = top_answer_abilities[ i][j].item() predicting_head_name = self.heads_indices( )[predicting_head_index] predicting_head = self._heads[ predicting_head_name] # construct the arguments to be used for a batch instance prediction instance_kwargs = { 'q_text': metadata[i]['original_question'], 'p_text': metadata[i]['original_passage'], 'qp_tokens': metadata[i]['question_passage_tokens'], 'question_passage_wordpieces': metadata[i]['question_passage_wordpieces'], 'original_numbers': metadata[i]['original_numbers'] if 'original_numbers' in metadata[i] else None, } # keys that cannot be passed because # they are not batch-based in their first level or None unpassable_keys = [ 'gold_answer_representations' ] for key, value in instance_kwargs.items(): if value is None: unpassable_keys.append(key) for key in unpassable_keys: if key in instance_kwargs: del instance_kwargs[key] for key, value in kwargs.items(): if value is not None and key not in unpassable_keys: instance_kwargs[key] = value[i] for key, value in head_outputs[ predicting_head_name].items(): if key not in unpassable_keys: instance_kwargs[key] = value[i] # get prediction for an instance in the batch answer_json = predicting_head.decode_answer( **instance_kwargs) answer_json[ 'probability'] = torch.nn.functional.softmax( answer_ability_logits, -1)[i][predicting_head_index].item() answers_dict[ predicting_head_name] = answer_json i += 1 return output_dict
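# A minimal sketch (hypothetical head names, scores, and decoded answers) of the fallback
# logic in the prediction loop above: heads are ranked per instance by `argsort` on their
# log-probabilities, and if the top-ranked head decodes an empty answer, the next head in the
# ranking is tried. The `decoded` dict stands in for `predicting_head.decode_answer(...)`.
import torch

head_names = ['passage_span', 'arithmetic', 'counting']
# Hypothetical per-instance head log-probabilities, shape: (batch_size=1, num_heads=3).
answer_ability_log_probs = torch.log_softmax(torch.tensor([[0.2, 1.5, -0.3]]), dim=-1)
top_answer_abilities = torch.argsort(answer_ability_log_probs, descending=True)

# Pretend the arithmetic head fails to produce an answer here.
decoded = {'passage_span': 'two touchdowns', 'arithmetic': '', 'counting': '3'}

answer = ''
for rank in range(len(head_names)):
    head = head_names[top_answer_abilities[0][rank].item()]
    answer = decoded[head]
    if answer:  # otherwise fall through to the next-ranked head
        break
print(answer)  # 'two touchdowns', after falling back from the empty arithmetic answer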
def forward( self, # type: ignore question_field: Dict[str, torch.LongTensor], visual_feat: torch.Tensor, pos: torch.Tensor, image_id: List[str], gold_question_attentions: torch.Tensor = None, identifier: List[str] = None, logical_form: List[str] = None, actions: List[List[ProductionRule]] = None, target_action_sequence: torch.LongTensor = None, gold_object_choices: torch.Tensor = None, denotation: torch.Tensor = None, ) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ batch_size, obj_num, feat_size = visual_feat.size() assert obj_num == 36 and feat_size == 2048 text_masks = util.get_text_field_mask(question_field) (l_orig, v_orig, text, vis_only), x_orig = self._encoder( question_field[self._tokens_namespace], text_masks, visual_feat, pos) text_masks = text_masks.float() # NOTE: Taking the lxmert output before cross modality layer (which is the same for both images) # Can also try concatenating (dim=-1) the two encodings encoded_sentence = text initial_state = self._get_initial_state(encoded_sentence, text_masks, actions) initial_state.debug_info = [[] for _ in range(batch_size)] if target_action_sequence is not None: # Remove the trailing dimension (from ListField[ListField[IndexField]]). target_action_sequence = target_action_sequence.squeeze(-1) target_mask = target_action_sequence != self._action_padding_index else: target_mask = None outputs: Dict[str, torch.Tensor] = {} losses = [] if (self.training or self._use_gold_program_for_eval ) and target_action_sequence is not None: # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we # unsqueeze it for the MML trainer. search = ConstrainedBeamSearch( beam_size=None, allowed_sequences=target_action_sequence.unsqueeze(1), allowed_sequence_mask=target_mask.unsqueeze(1), ) final_states = search.search(initial_state, self._transition_function) if self._training_batches_so_far < self._num_parse_only_batches: for batch_index in range(batch_size): if not final_states[batch_index]: logger.error( f"No pogram found for batch index {batch_index}") continue losses.append(-final_states[batch_index][0].score[0]) else: final_states = self._beam_search.search( self._max_decoding_steps, initial_state, self._transition_function, keep_final_unfinished_states=False, ) action_mapping = {} for action_index, action in enumerate(actions[0]): action_mapping[action_index] = action[0] outputs: Dict[str, Any] = {"action_mapping": action_mapping} outputs["best_action_sequence"] = [] outputs["debug_info"] = [] if self._nmn_settings["mask_non_attention"]: zero_one_mult = torch.zeros_like(gold_question_attentions) zero_one_mult.copy_(gold_question_attentions) zero_one_mult[:, :, 0] = 1.0 # sep_indices = text_masks.argmax(1).long() sep_indices = ( (text_masks.long() * (1 + torch.arange(text_masks.shape[1]).unsqueeze(0).repeat( batch_size, 1).to(text_masks.device))).argmax(1).long()) sep_indices = (sep_indices.unsqueeze(1).repeat( 1, gold_question_attentions.shape[2]).unsqueeze(1).repeat( 1, gold_question_attentions.shape[1], 1)) indices_dim2 = (torch.arange( gold_question_attentions.shape[2]).unsqueeze(0).repeat( gold_question_attentions.shape[0], gold_question_attentions.shape[1], 1, ).to(sep_indices.device).long()) zero_one_mult = torch.where( sep_indices == indices_dim2, torch.ones_like(zero_one_mult), zero_one_mult, ).float() reshaped_questions = ( question_field[self._tokens_namespace].unsqueeze(1).repeat( 1, gold_question_attentions.shape[1], 1).view(-1, gold_question_attentions.shape[-1])) reshaped_visual_feat = 
(visual_feat.unsqueeze(1).repeat( 1, gold_question_attentions.shape[1], 1, 1).view(-1, obj_num, visual_feat.shape[-1])) reshaped_pos = (pos.unsqueeze(1).repeat( 1, gold_question_attentions.shape[1], 1, 1).view(-1, obj_num, pos.shape[-1])) zero_one_mult = zero_one_mult.view( -1, gold_question_attentions.shape[-1]) q_att_filter = zero_one_mult.sum(1) > 2 (l_relevant, v_relevant, _, _), x_relevant = self._encoder( reshaped_questions[q_att_filter, :], zero_one_mult[q_att_filter, :], reshaped_visual_feat[q_att_filter, :, :], reshaped_pos[q_att_filter, :, :], ) l = [{} for _ in range(batch_size)] v = [{} for _ in range(batch_size)] x = [{} for _ in range(batch_size)] count = 0 batch_index = -1 for i in range(zero_one_mult.shape[0]): module_num = i % target_action_sequence.shape[1] if module_num == 0: batch_index += 1 if q_att_filter[i].item(): l[batch_index][module_num] = l_relevant[count] v[batch_index][module_num] = v_relevant[count] x[batch_index][module_num] = x_relevant[count] count += 1 else: l = l_orig v = v_orig x = x_orig for batch_index in range(batch_size): if (self.training and self._training_batches_so_far < self._num_parse_only_batches): continue if not final_states[batch_index]: logger.error(f"No pogram found for batch index {batch_index}") outputs["best_action_sequence"].append([]) outputs["debug_info"].append([]) continue world = VisualReasoningGqaLanguage( l[batch_index], v[batch_index], x[batch_index], # initial_state.rnn_state[batch_index].encoder_outputs[batch_index], self._language_parameters, pos[batch_index], nmn_settings=self._nmn_settings, ) denotation_log_prob_list = [] # TODO(mattg): maybe we want to limit the number of states we evaluate (programs we # execute) at test time, just for efficiency. for state_index, state in enumerate(final_states[batch_index]): action_indices = state.action_history[0] action_strings = [ action_mapping[action_index] for action_index in action_indices ] # Shape: (num_denotations,) assert len(action_strings) == len(state.debug_info[0]) # Plug in gold question attentions for i in range(len(state.debug_info[0])): if gold_question_attentions[batch_index, i, :].sum() > 0: state.debug_info[0][i]["question_attention"] = ( gold_question_attentions[batch_index, i, :].float() / gold_question_attentions[batch_index, i, :].sum()) elif self._nmn_settings["mask_non_attention"] and ( action_strings[i][-4:] == "find" or action_strings[i][-6:] == "filter" or action_strings[i][-13:] == "with_relation"): state.debug_info[0][i]["question_attention"] = ( torch.ones_like( gold_question_attentions[batch_index, i, :]).float() / gold_question_attentions[batch_index, i, :].numel()) l[batch_index][i] = l_orig[batch_index] v[batch_index][i] = v_orig[batch_index] x[batch_index][i] = x_orig[batch_index] world = VisualReasoningGqaLanguage( l[batch_index], v[batch_index], x[batch_index], # initial_state.rnn_state[batch_index].encoder_outputs[batch_index], self._language_parameters, pos[batch_index], nmn_settings=self._nmn_settings, ) # print(action_strings) state_denotation_log_probs = world.execute_action_sequence( action_strings, state.debug_info[0]) # prob2 = world.execute_action_sequence(action_strings, state.debug_info[0]) # P(denotation | parse) * P(parse | question) denotation_log_prob_list.append(state_denotation_log_probs) if not self._use_gold_program_for_eval: denotation_log_prob_list[-1] += state.score[0] if state_index == 0: outputs["best_action_sequence"].append(action_strings) outputs["debug_info"].append(state.debug_info[0]) if target_action_sequence is 
not None: targets = target_action_sequence[batch_index].data program_correct = self._action_history_match( action_indices, targets) self._program_accuracy(program_correct) # P(denotation | parse) * P(parse | question) for the all programs on the beam. # Shape: (beam_size, num_denotations) denotation_log_probs = torch.stack(denotation_log_prob_list) # \Sum_parse P(denotation | parse) * P(parse | question) = P(denotation | question) # Shape: (num_denotations,) marginalized_denotation_log_probs = util.logsumexp( denotation_log_probs, dim=0) if denotation is not None: loss = (self.loss( state_denotation_log_probs.unsqueeze(0), denotation[batch_index].unsqueeze(0).float(), ).view(1) * self._denotation_loss_multiplier) losses.append(loss) self._denotation_accuracy( torch.tensor([ 1 - state_denotation_log_probs, state_denotation_log_probs ]).to(denotation.device), denotation[batch_index], ) if gold_object_choices is not None: gold_objects = gold_object_choices[batch_index, :, :] predicted_objects = torch.zeros_like(gold_objects) for index in world.object_scores: predicted_objects[ index, :] = world.object_scores[index] obj_exists = gold_objects.max(1)[0] > 0 # Only look at modules where at least one of the proposals has the object of interest predicted_objects = predicted_objects[obj_exists, :] gold_objects = gold_objects[obj_exists, :] gold_objects = gold_objects.view(-1) predicted_objects = predicted_objects.view(-1) if gold_objects.numel() > 0: loss += self._obj_loss_multiplier * self.loss( predicted_objects, (gold_objects.float() + 1) / 2) self._proposal_accuracy( torch.cat( ( 1.0 - predicted_objects.view(-1, 1), predicted_objects.view(-1, 1), ), dim=-1, ), (gold_objects + 1) / 2, ) if losses: outputs["loss"] = torch.stack(losses).mean() if self.training: self._training_batches_so_far += 1 return outputs
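# A minimal sketch (hypothetical beam of two programs over four candidate denotations) of the
# marginalization noted in the comments above: per-program denotation log-probabilities are
# weighted by the program's own score and combined with a logsumexp over the beam dimension.
import torch

# Shape: (beam_size=2, num_denotations=4): log P(denotation | parse) for each program.
denotation_log_probs = torch.log_softmax(torch.randn(2, 4), dim=-1)
# Shape: (beam_size,): log P(parse | question) from the decoder.
program_scores = torch.log_softmax(torch.randn(2), dim=-1)

# log sum_parse P(denotation | parse) * P(parse | question) = log P(denotation | question)
# Shape: (num_denotations,)
marginalized = torch.logsumexp(denotation_log_probs + program_scores.unsqueeze(-1), dim=0)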
def forward( self, # type: ignore text: Dict[str, torch.LongTensor], spans: torch.IntTensor, span_labels: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ # Shape: (batch_size, document_length, embedding_size) text_embeddings = self._lexical_dropout( self._text_field_embedder(text)) document_length = text_embeddings.size(1) # Shape: (batch_size, document_length) text_mask = util.get_text_field_mask(text).float() # Shape: (batch_size, num_spans) if self._use_gold_mentions: if text_embeddings.is_cuda: device = torch.device('cuda') else: device = torch.device('cpu') s = [ torch.as_tensor(pair, dtype=torch.long, device=device) for cluster in metadata[0]["clusters"] for pair in cluster ] gm = torch.stack(s, dim=0).unsqueeze(0).unsqueeze(1) span_mask = (spans.unsqueeze(2) - gm) span_mask = (span_mask[:, :, :, 0] == 0) + (span_mask[:, :, :, 1] == 0) span_mask, _ = (span_mask == 2).max(-1) num_spans = span_mask.sum().item() span_mask = span_mask.float() else: span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float() num_spans = spans.size(1) # Shape: (batch_size, num_spans, 2) spans = F.relu(spans.float()).long() # Shape: (batch_size, document_length, encoding_dim) contextualized_embeddings = self._context_layer( text_embeddings, text_mask) # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size) endpoint_span_embeddings = self._endpoint_span_extractor( contextualized_embeddings, spans) # Shape: (batch_size, num_spans, emebedding_size) attended_span_embeddings = self._attentive_span_extractor( text_embeddings, spans) # Shape: (batch_size, num_spans, emebedding_size + 2 * encoding_dim + feature_size) span_embeddings = torch.cat( [endpoint_span_embeddings, attended_span_embeddings], -1) # Prune based on mention scores. num_spans_to_keep = int( math.floor(self._spans_per_word * document_length)) (top_span_embeddings, top_span_mask, top_span_indices, top_span_mention_scores) = self._mention_pruner( span_embeddings, span_mask, num_spans_to_keep) top_span_mask = top_span_mask.unsqueeze(-1) # Shape: (batch_size * num_spans_to_keep) flat_top_span_indices = util.flatten_and_batch_shift_indices( top_span_indices, num_spans) # Compute final predictions for which spans to consider as mentions. # Shape: (batch_size, num_spans_to_keep, 2) top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices) # Compute indices for antecedent spans to consider. max_antecedents = min(self._max_antecedents, num_spans_to_keep) # Shapes: # (num_spans_to_keep, max_antecedents), # (1, max_antecedents), # (1, num_spans_to_keep, max_antecedents) valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \ self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask)) # Select tensors relating to the antecedent spans. # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) candidate_antecedent_embeddings = util.flattened_index_select( top_span_embeddings, valid_antecedent_indices) # Shape: (batch_size, num_spans_to_keep, max_antecedents) candidate_antecedent_mention_scores = util.flattened_index_select( top_span_mention_scores, valid_antecedent_indices).squeeze(-1) # Compute antecedent scores. 
# Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) span_pair_embeddings = self._compute_span_pair_embeddings( top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents) coreference_scores = self._compute_coreference_scores( span_pair_embeddings, top_span_mention_scores, candidate_antecedent_mention_scores, valid_antecedent_log_mask) # Shape: (batch_size, num_spans_to_keep) _, predicted_antecedents = coreference_scores.max(2) predicted_antecedents -= 1 output_dict = { "top_spans": top_spans, "antecedent_indices": valid_antecedent_indices, "predicted_antecedents": predicted_antecedents } if span_labels is not None: # Find the gold labels for the spans which we kept. pruned_gold_labels = util.batched_index_select( span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices) antecedent_labels = util.flattened_index_select( pruned_gold_labels, valid_antecedent_indices).squeeze(-1) antecedent_labels += valid_antecedent_log_mask.long() # Compute labels. # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1) gold_antecedent_labels = self._compute_antecedent_gold_labels( pruned_gold_labels, antecedent_labels) coreference_log_probs = util.last_dim_log_softmax( coreference_scores, top_span_mask) correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log( ) negative_marginal_log_likelihood = -util.logsumexp( correct_antecedent_log_probs).sum() self._mention_recall(top_spans, metadata) self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata) output_dict["loss"] = negative_marginal_log_likelihood if metadata is not None: output_dict["document"] = [x["original_text"] for x in metadata] return output_dict
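# --- A minimal sketch (illustrative shapes, not the model's tensors) of the coreference
# loss pattern above: adding log(0/1 gold indicators) to the antecedent log-probs turns
# gold-inconsistent antecedents into -inf, so logsumexp sums probability mass over every
# antecedent that agrees with the gold clustering.
import torch

num_spans, num_candidates = 2, 4  # candidates include the "no antecedent" dummy at index 0
scores = torch.randn(num_spans, num_candidates)
log_probs = torch.log_softmax(scores, dim=-1)

# 1.0 where the candidate antecedent is in the same gold cluster (or the dummy is correct).
gold_labels = torch.tensor([[0., 1., 0., 1.],
                            [1., 0., 0., 0.]])

correct_log_probs = log_probs + gold_labels.log()   # invalid entries become -inf
negative_marginal_log_likelihood = -torch.logsumexp(correct_log_probs, dim=-1).sum()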
def forward( self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], question_and_passage: Dict[str, torch.LongTensor], answer_as_passage_spans: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ, unused-argument # logger.info("="*10) # logger.info([len(metadata[i]["passage_tokens"]) for i in range(len(metadata))]) # logger.info([len(metadata[i]["question_tokens"]) for i in range(len(metadata))]) # logger.info(question_and_passage["bert"].shape) # The segment labels should be as following: # <CLS> + question_word_pieces + <SEP> + passage_word_pieces + <SEP> # 0 0 0 1 1 # We get this in a tricky way here expanded_question_bert_tensor = torch.zeros_like( question_and_passage["bert"]) expanded_question_bert_tensor[:, :question["bert"]. shape[1]] = question["bert"] segment_labels = (question_and_passage["bert"] - expanded_question_bert_tensor > 0).long() question_and_passage["segment_labels"] = segment_labels embedded_question_and_passage = self._text_field_embedder( question_and_passage) # We also get the passage mask for the concatenated question and passage in a similar way expanded_question_mask = torch.zeros_like(question_and_passage["mask"]) # We shift the 1s to one column right here, to mask the [SEP] token in the middle expanded_question_mask[:, 1:question["mask"].shape[1] + 1] = question["mask"] expanded_question_mask[:, 0] = 1 passage_mask = question_and_passage["mask"] - expanded_question_mask batch_size = embedded_question_and_passage.size(0) span_start_logits = self._span_start_predictor( embedded_question_and_passage).squeeze(-1) span_end_logits = self._span_end_predictor( embedded_question_and_passage).squeeze(-1) # Shape: (batch_size, passage_length) passage_span_start_log_probs = util.masked_log_softmax( span_start_logits, passage_mask) passage_span_end_log_probs = util.masked_log_softmax( span_end_logits, passage_mask) passage_span_start_logits = util.replace_masked_values( span_start_logits, passage_mask, -1e32) passage_span_end_logits = util.replace_masked_values( span_end_logits, passage_mask, -1e32) best_passage_span = get_best_span(passage_span_start_logits, passage_span_end_logits) output_dict = { "passage_span_start_probs": passage_span_start_log_probs.exp(), "passage_span_end_probs": passage_span_end_log_probs.exp() } # If answer is given, compute the loss for training. if answer_as_passage_spans is not None: # Shape: (batch_size, # of answer spans) gold_passage_span_starts = answer_as_passage_spans[:, :, 0] gold_passage_span_ends = answer_as_passage_spans[:, :, 1] # Some spans are padded with index -1, # so we clamp those paddings to 0 and then mask after `torch.gather()`. 
gold_passage_span_mask = (gold_passage_span_starts != -1).long() clamped_gold_passage_span_starts = util.replace_masked_values( gold_passage_span_starts, gold_passage_span_mask, 0) clamped_gold_passage_span_ends = util.replace_masked_values( gold_passage_span_ends, gold_passage_span_mask, 0) # Shape: (batch_size, # of answer spans) log_likelihood_for_passage_span_starts = \ torch.gather(passage_span_start_log_probs, 1, clamped_gold_passage_span_starts) log_likelihood_for_passage_span_ends = \ torch.gather(passage_span_end_log_probs, 1, clamped_gold_passage_span_ends) # Shape: (batch_size, # of answer spans) log_likelihood_for_passage_spans = \ log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends # For those padded spans, we set their log probabilities to be very small negative value log_likelihood_for_passage_spans = \ util.replace_masked_values(log_likelihood_for_passage_spans, gold_passage_span_mask, -1e32) # Shape: (batch_size, ) log_marginal_likelihood_for_passage_span = util.logsumexp( log_likelihood_for_passage_spans) output_dict[ "loss"] = -log_marginal_likelihood_for_passage_span.mean() # Compute the metrics and add the tokenized input to the output. if metadata is not None: output_dict["question_id"] = [] output_dict["answer"] = [] question_tokens = [] passage_tokens = [] for i in range(batch_size): question_tokens.append(metadata[i]['question_tokens']) passage_tokens.append(metadata[i]['passage_tokens']) # We did not consider multi-mention answers here passage_str = metadata[i]['original_passage'] offsets = metadata[i]['passage_token_offsets'] predicted_span = tuple( best_passage_span[i].detach().cpu().numpy()) # Remove the offsets of question tokens and the [SEP] token predicted_span = (predicted_span[0] - len(metadata[i]['question_tokens']) - 1, predicted_span[1] - len(metadata[i]['question_tokens']) - 1) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] best_answer_str = passage_str[start_offset:end_offset] output_dict["question_id"].append(metadata[i]["question_id"]) output_dict["answer"].append(best_answer_str) answer_annotations = metadata[i].get('answer_annotations', []) if answer_annotations: self._drop_metrics(best_answer_str, answer_annotations) return output_dict
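# --- A minimal sketch of the span-answer marginalization above, with made-up sizes:
# gather the log-probs of each annotated gold span, push padded spans (index -1) to a
# very negative value, and logsumexp so the model may place probability on any one of
# the gold spans. torch.logsumexp stands in for util.logsumexp.
import torch

batch_size, passage_length, num_answer_spans = 2, 10, 3
span_start_log_probs = torch.log_softmax(torch.randn(batch_size, passage_length), dim=-1)
span_end_log_probs = torch.log_softmax(torch.randn(batch_size, passage_length), dim=-1)

# Gold spans, padded with -1 when an instance has fewer than num_answer_spans answers.
gold_starts = torch.tensor([[0, 4, -1], [2, -1, -1]])
gold_ends = torch.tensor([[1, 6, -1], [3, -1, -1]])

span_mask = (gold_starts != -1)
clamped_starts = gold_starts.clamp(min=0)
clamped_ends = gold_ends.clamp(min=0)

log_likelihood = (torch.gather(span_start_log_probs, 1, clamped_starts)
                  + torch.gather(span_end_log_probs, 1, clamped_ends))
log_likelihood = log_likelihood.masked_fill(~span_mask, -1e32)

# Shape: (batch_size,)
marginal_log_likelihood = torch.logsumexp(log_likelihood, dim=-1)
loss = -marginal_log_likelihood.mean()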
def forward( self, # type: ignore question_passage: Dict[str, torch.LongTensor], number_indices: torch.LongTensor, mask_indices: torch.LongTensor, answer_as_passage_spans: torch.LongTensor = None, answer_as_question_spans: torch.LongTensor = None, answer_as_expressions: torch.LongTensor = None, answer_as_expressions_extra: torch.LongTensor = None, answer_as_unit_spans: torch.LongTensor = None, answer_as_counts: torch.LongTensor = None, answer_as_text_to_disjoint_bios: torch.LongTensor = None, answer_as_list_of_bios: torch.LongTensor = None, answer_as_yesno: torch.LongTensor = None, span_bio_labels: torch.LongTensor = None, bio_wordpiece_mask: torch.LongTensor = None, is_bio_mask: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ # Shape: (batch_size, seqlen) question_passage_tokens = question_passage["tokens"] # Shape: (batch_size, seqlen) pad_mask = question_passage["mask"] # Shape: (batch_size, seqlen) seqlen_ids = question_passage["tokens-type-ids"] max_seqlen = question_passage_tokens.shape[-1] batch_size = question_passage_tokens.shape[0] # Shape: (batch_size, 3) mask = mask_indices.squeeze(-1) # Shape: (batch_size, seqlen) cls_sep_mask = \ torch.ones(pad_mask.shape, device=pad_mask.device).long().scatter(1, mask, torch.zeros(mask.shape, device=mask.device).long()) # Shape: (batch_size, seqlen) passage_mask = seqlen_ids * pad_mask * cls_sep_mask # Shape: (batch_size, seqlen) question_mask = (1 - seqlen_ids) * pad_mask * cls_sep_mask question_and_passage_mask = question_mask | passage_mask if bio_wordpiece_mask is None or not self.multispan_use_bio_wordpiece_mask: multispan_mask = question_and_passage_mask else: multispan_mask = question_and_passage_mask * bio_wordpiece_mask # Shape: (batch_size, seqlen, bert_dim) bert_out, _ = self.BERT(question_passage_tokens, seqlen_ids, pad_mask) # Shape: (batch_size, qlen, bert_dim) question_end = max(mask[:, 1]) question_out = bert_out[:, :question_end] # Shape: (batch_size, qlen) question_mask = question_mask[:, :question_end] # Shape: (batch_size, out) question_vector = self.summary_vector(question_out, question_mask, "question") passage_out = bert_out del bert_out # Shape: (batch_size, bert_dim) passage_vector = self.summary_vector(passage_out, passage_mask) if len(self.answering_abilities) > 1: # Shape: (batch_size, number_of_abilities) answer_ability_logits = \ self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1)) answer_ability_log_probs = torch.nn.functional.log_softmax( answer_ability_logits, -1) best_answer_ability = torch.argmax(answer_ability_log_probs, 1) top_two_answer_abilities = torch.topk(answer_ability_log_probs, k=2, dim=1) if "counting" in self.answering_abilities: count_number_log_probs, best_count_number = self._count_module( passage_vector) if "passage_span_extraction" in self.answering_abilities: passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span = \ self._passage_span_module(passage_out, passage_mask) if "question_span_extraction" in self.answering_abilities: question_span_start_log_probs, question_span_end_log_probs, best_question_span = \ self._question_span_module(passage_vector, question_out, question_mask) if "arithmetic" in self.answering_abilities or "counting" in self.answering_abilities: unit_span_start_log_probs, unit_span_end_log_probs, best_unit_span = \ self._unit_span_module(passage_vector, question_out, question_mask) if "multiple_spans" in self.answering_abilities: if 
self.multispan_head_name == "flexible_loss": multispan_log_probs, multispan_logits = self._multispan_module( passage_out, seq_mask=multispan_mask) else: multispan_log_probs, multispan_logits = self._multispan_module( passage_out) if "arithmetic" in self.answering_abilities: number_mask = (number_indices[:, :, 0].long() != -1).long() number_sign_log_probs, best_signs_for_numbers, number_mask = \ self._base_arithmetic_module(passage_vector, passage_out, number_indices, number_mask) if "yesno" in self.answering_abilities: yesno_log_probs, best_yesno = self._yesno_module(passage_vector) output_dict = {} del passage_out, question_out # If answer is given, compute the loss. if answer_as_passage_spans is not None or answer_as_question_spans is not None \ or answer_as_expressions is not None or answer_as_counts is not None \ or answer_as_yesno is not None or span_bio_labels is not None: log_marginal_likelihood_list = [] ### log_marginal_likelihood_for_unit_span = \ self._question_span_log_likelihood(answer_as_unit_spans, unit_span_start_log_probs, unit_span_end_log_probs) ### for answering_ability in self.answering_abilities: if answering_ability == "passage_span_extraction": log_marginal_likelihood_for_passage_span = \ self._passage_span_log_likelihood(answer_as_passage_spans, passage_span_start_log_probs, passage_span_end_log_probs) log_marginal_likelihood_list.append( log_marginal_likelihood_for_passage_span) elif answering_ability == "question_span_extraction": log_marginal_likelihood_for_question_span = \ self._question_span_log_likelihood(answer_as_question_spans, question_span_start_log_probs, question_span_end_log_probs) log_marginal_likelihood_list.append( log_marginal_likelihood_for_question_span) elif answering_ability == "arithmetic": log_marginal_likelihood_for_arithmetic = \ self._base_arithmetic_log_likelihood(answer_as_expressions, number_sign_log_probs, number_mask, answer_as_expressions_extra) log_marginal_likelihood_list.append( log_marginal_likelihood_for_arithmetic + log_marginal_likelihood_for_unit_span * 0.5) elif answering_ability == "counting": log_marginal_likelihood_for_count = \ self._count_log_likelihood(answer_as_counts, count_number_log_probs) log_marginal_likelihood_list.append( log_marginal_likelihood_for_count + log_marginal_likelihood_for_unit_span * 0.5) elif answering_ability == "multiple_spans": if self.multispan_head_name == "flexible_loss": log_marginal_likelihood_for_multispan = \ self._multispan_log_likelihood(answer_as_text_to_disjoint_bios, answer_as_list_of_bios, span_bio_labels, multispan_log_probs, multispan_logits, multispan_mask, bio_wordpiece_mask, is_bio_mask) else: log_marginal_likelihood_for_multispan = \ self._multispan_log_likelihood(span_bio_labels, multispan_log_probs, multispan_mask, is_bio_mask, logits=multispan_logits) log_marginal_likelihood_list.append( log_marginal_likelihood_for_multispan) elif answering_ability == "yesno": log_marginal_likelihood_for_yesno = \ self._yesno_log_likelihood(answer_as_yesno, yesno_log_probs) log_marginal_likelihood_list.append( log_marginal_likelihood_for_yesno) else: raise ValueError( f"Unsupported answering ability: {answering_ability}") if len(self.answering_abilities) > 1: # import pdb; pdb.set_trace() # Add the ability probabilities if there are more than one abilities all_log_marginal_likelihoods = torch.stack( log_marginal_likelihood_list, dim=-1) all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs marginal_log_likelihood = util.logsumexp( 
all_log_marginal_likelihoods) else: marginal_log_likelihood = log_marginal_likelihood_list[0] output_dict["loss"] = -marginal_log_likelihood.mean() with torch.no_grad(): # Compute the metrics and add the tokenized input to the output. if metadata is not None: if not self.training: output_dict["passage_id"] = [] output_dict["query_id"] = [] output_dict["answer"] = [] output_dict["predicted_ability"] = [] output_dict["maximizing_ground_truth"] = [] output_dict["em"] = [] output_dict["f1"] = [] output_dict["invalid_spans"] = [] output_dict["max_passage_length"] = [] i = 0 while i < batch_size: if len(self.answering_abilities) > 1: predicted_ability_str = self.answering_abilities[ best_answer_ability[i]] else: predicted_ability_str = self.answering_abilities[0] answer_json: Dict[str, Any] = {} invalid_spans = [] q_text = metadata[i]['original_question'] p_text = metadata[i]['original_passage'] qp_tokens = metadata[i]['question_passage_tokens'] ### if predicted_ability_str == "passage_span_extraction": answer_json["answer_type"] = "passage_span" answer_json["value"], answer_json["spans"] = \ self._span_prediction(qp_tokens, best_passage_span[i], p_text, q_text, 'p') elif predicted_ability_str == "question_span_extraction": answer_json["answer_type"] = "question_span" answer_json["value"], answer_json["spans"] = \ self._span_prediction(qp_tokens, best_question_span[i], p_text, q_text, 'q') # import pdb; pdb.set_trace() elif predicted_ability_str == "arithmetic": # plus_minus combination answer answer_json["answer_type"] = "arithmetic" original_numbers = metadata[i]['original_numbers'] answer_json["value"], answer_json["numbers"] = \ self._base_arithmetic_prediction(original_numbers, number_indices[i], best_signs_for_numbers[i]) elif predicted_ability_str == "counting": answer_json["answer_type"] = "count" answer_json["value"], answer_json["count"] = \ self._count_prediction(best_count_number[i]) elif predicted_ability_str == "multiple_spans": answer_json["answer_type"] = "multiple_spans" if self.multispan_head_name == "flexible_loss": answer_json["value"], answer_json["spans"], invalid_spans = \ self._multispan_prediction(multispan_log_probs[i], multispan_logits[i], qp_tokens, p_text, q_text, multispan_mask[i], bio_wordpiece_mask[i], self.multispan_use_prediction_beam_search and not self.training) else: answer_json["value"], answer_json["spans"], invalid_spans = \ self._multispan_prediction(multispan_log_probs[i], multispan_logits[i], qp_tokens, p_text, q_text, multispan_mask[i]) if self._unique_on_multispan: answer_json["value"] = list( OrderedDict.fromkeys(answer_json["value"])) if self._dont_add_substrings_to_ms: answer_json[ "value"] = remove_substring_from_prediction( answer_json["value"]) if len(answer_json["value"]) == 0: best_answer_ability[ i] = top_two_answer_abilities.indices[i][1] continue elif predicted_ability_str == "yesno": answer_json["answer_type"] = "yesno" answer_json["value"], answer_json["yesno"] = \ self._yesno_prediction(best_yesno[i]) else: raise ValueError( f"Unsupported answer ability: {predicted_ability_str}" ) if predicted_ability_str == "counting" or predicted_ability_str == "arithmetic": answer_json["unit_value"], answer_json["unit_spans"] = \ self._span_prediction(qp_tokens, best_unit_span[i], p_text, q_text, 'q') answer_json["value"] = answer_json[ "value"] + answer_json["unit_value"] maximizing_ground_truth = None em, f1 = None, None answer_annotations = metadata[i].get( 'answer_annotations', []) if answer_annotations: (em, f1 ), maximizing_ground_truth = 
self._drop_metrics.call( answer_json["value"], [ dict((key, answer_annotation[key] ) if key != 'number' else ( key, answer_annotation[key] + answer_annotation['unit']) for key in answer_annotation.keys()) for answer_annotation in answer_annotations ], predicted_ability_str) if not self.training: output_dict["passage_id"].append( metadata[i]["passage_id"]) output_dict["query_id"].append( metadata[i]["question_id"]) output_dict["answer"].append(answer_json) output_dict["predicted_ability"].append( predicted_ability_str) output_dict["maximizing_ground_truth"].append( maximizing_ground_truth) output_dict["em"].append(em) output_dict["f1"].append(f1) output_dict["invalid_spans"].append(invalid_spans) output_dict["max_passage_length"].append( metadata[i]["max_passage_length"]) i += 1 return output_dict
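# --- A minimal sketch of the mixture-of-abilities objective above (sizes and values are
# invented): each answering ability contributes a marginal log-likelihood for the gold
# answer, the ability classifier's log-probs are added, and a final logsumexp computes
# log sum_a P(ability = a | x) * P(answer | ability = a, x).
import torch

# Per-ability marginal log-likelihood of the gold answer; shape: (batch_size, num_abilities).
# In the model these come from the span / arithmetic / count / multi-span heads.
per_ability_log_likelihoods = torch.tensor([[-0.5, -2.0, -4.0],
                                            [-3.0, -0.2, -1.5]])
# Ability classifier log-probs; shape: (batch_size, num_abilities)
answer_ability_log_probs = torch.log_softmax(torch.randn(2, 3), dim=-1)

marginal_log_likelihood = torch.logsumexp(
    per_ability_log_likelihoods + answer_ability_log_probs, dim=-1)
loss = -marginal_log_likelihood.mean()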
def forward( self, # type: ignore text: Dict[str, torch.LongTensor], spans: torch.IntTensor, doc_span_offsets: torch.IntTensor, span_labels: torch.IntTensor = None, doc_truth_spans: torch.IntTensor = None, doc_spans_in_truth: torch.IntTensor = None, doc_relation_labels: torch.Tensor = None, truth_spans: List[Set[Tuple[int, int]]] = None, # doc_relations = None, doc_ner_labels: torch.IntTensor = None, **metadata: Dict[str, List[Any]] ) -> Dict[str, torch.Tensor]: # add matrix from datareader # pylint: disable=arguments-differ """ Parameters ---------- text : ``Dict[str, torch.LongTensor]``, required. The output of a ``TextField`` representing the text of the document. spans : ``torch.IntTensor``, required. A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of indices into the text of the document. span_labels : ``torch.IntTensor``, optional (default = None) A tensor of shape (batch_size, num_spans), representing the cluster ids of each span, or -1 for those which do not appear in any clusters. metadata : ``torch.IntTensor``, required. A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of indices into the text of the document. doc_ner_labels : ``torch.IntTensor``. A tensor of shape # TODO, ... doc_span_offsets : ``torch.IntTensor``. A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1), ... doc_truth_spans : ``torch.IntTensor``. A tensor of shape (batch_size, max_sentences, max_truth_spans, 1), ... doc_spans_in_truth : ``torch.IntTensor``. A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1), ... doc_relation_labels : ``torch.Tensor``. A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans), ... Returns ------- An output dictionary consisting of: top_spans : ``torch.IntTensor`` A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing the start and end word indices of the top spans that survived the pruning stage. antecedent_indices : ``torch.IntTensor`` A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span the index (with respect to top_spans) of the possible antecedents the model considered. predicted_antecedents : ``torch.IntTensor`` A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the index (with respect to antecedent_indices) of the most likely antecedent. -1 means there was no predicted link. loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. """ # Shape: (batch_size, document_length, embedding_size) text_embeddings = self._lexical_dropout( self._text_field_embedder(text)) batch_size = len(spans) document_length = text_embeddings.size(1) max_sentence_length = max( len(sentence_tokens) for doc_tokens in metadata['doc_tokens'] for sentence_tokens in doc_tokens) num_spans = spans.size(1) # Shape: (batch_size, document_length) text_mask = util.get_text_field_mask(text).float() # Shape: (batch_size, num_spans) span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float() # SpanFields return -1 when they are used as padding. As we do # some comparisons based on span widths when we attend over the # span representations that we generate from these indices, we # need them to be <= 0. 
This is only relevant in edge cases where # the number of spans we consider after the pruning stage is >= the # total number of spans, because in this case, it is possible we might # consider a masked span. # Shape: (batch_size, num_spans, 2) spans = F.relu(spans.float()).long() # Shape: (batch_size, document_length, encoding_dim) contextualized_embeddings = self._context_layer( text_embeddings, text_mask) # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size) endpoint_span_embeddings = self._endpoint_span_extractor( contextualized_embeddings, spans) # TODO features dropout # Shape: (batch_size, num_spans, embedding_size) attended_span_embeddings = self._attentive_span_extractor( text_embeddings, spans) # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size) span_embeddings = torch.cat( [endpoint_span_embeddings, attended_span_embeddings], -1) # Prune based on mention scores. num_spans_to_keep = int( math.floor(self._spans_per_word * document_length)) num_relex_spans_to_keep = int( math.floor(self._relex_spans_per_word * max_sentence_length)) # Shapes: # (batch_size, num_spans_to_keep, span_dim), # (batch_size, num_spans_to_keep), # (batch_size, num_spans_to_keep), # (batch_size, num_spans_to_keep, 1) (top_span_embeddings, top_span_mask, top_span_indices, top_span_mention_scores) = self._mention_pruner( span_embeddings, span_mask, num_spans_to_keep) # Shape: (batch_size, num_spans_to_keep, 1) top_span_mask = top_span_mask.unsqueeze(-1) # Shape: (batch_size * num_spans_to_keep) # torch.index_select only accepts 1D indices, but here # we need to select spans for each element in the batch. # This reformats the indices to take into account their # index into the batch. We precompute this here to make # the multiple calls to util.batched_index_select below more efficient. flat_top_span_indices = util.flatten_and_batch_shift_indices( top_span_indices, num_spans) # Compute final predictions for which spans to consider as mentions. # Shape: (batch_size, num_spans_to_keep, 2) top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices) # Compute indices for antecedent spans to consider. max_antecedents = min(self._max_antecedents, num_spans_to_keep) # Now that we have our variables in terms of num_spans_to_keep, we need to # compare span pairs to decide each span's antecedent. Each span can only # have prior spans as antecedents, and we only consider up to max_antecedents # prior spans. So the first thing we do is construct a matrix mapping a span's # index to the indices of its allowed antecedents. Note that this is independent # of the batch dimension - it's just a function of the span's position in # top_spans. The spans are in document order, so we can just use the relative # index of the spans to know which other spans are allowed antecedents. # Once we have this matrix, we reformat our variables again to get embeddings # for all valid antecedents for each span. This gives us variables with shapes # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which # we can use to make coreference decisions between valid span pairs. # Shapes: # (num_spans_to_keep, max_antecedents), # (1, max_antecedents), # (1, num_spans_to_keep, max_antecedents) valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \ self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask)) # Select tensors relating to the antecedent spans. 
# Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) candidate_antecedent_embeddings = util.flattened_index_select( top_span_embeddings, valid_antecedent_indices) # Shape: (batch_size, num_spans_to_keep, max_antecedents) candidate_antecedent_mention_scores = util.flattened_index_select( top_span_mention_scores, valid_antecedent_indices).squeeze(-1) # Compute antecedent scores. # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) span_pair_embeddings = self._compute_span_pair_embeddings( top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents) coreference_scores = self._compute_coreference_scores( span_pair_embeddings, top_span_mention_scores, candidate_antecedent_mention_scores, valid_antecedent_log_mask) # We now have, for each span which survived the pruning stage, # a predicted antecedent. This implies a clustering if we group # mentions which refer to each other in a chain. # Shape: (batch_size, num_spans_to_keep) _, predicted_antecedents = coreference_scores.max(2) # Subtract one here because index 0 is the "no antecedent" class, # so this makes the indices line up with actual spans if the prediction # is greater than -1. predicted_antecedents -= 1 output_dict = dict() # Store raw text and tokens for decoding step output_dict["flat_tokens"] = metadata["flat_tokens"] output_dict["flat_text"] = metadata["flat_text"] output_dict["top_spans"] = top_spans output_dict["antecedent_indices"] = valid_antecedent_indices output_dict["predicted_antecedents"] = predicted_antecedents if metadata is not None: output_dict["document"] = metadata["original_text"] # Shape: (,) loss = 0 # Shape: (batch_size, max_sentences, max_spans) doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float() # Shape: (batch_size, max_sentences, num_spans, span_dim) doc_span_embeddings = util.batched_index_select( span_embeddings, doc_span_offsets.squeeze(-1).long().clamp(min=0)) # Shapes: # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim), # (batch_size, max_sentences, num_relex_spans_to_keep), # (batch_size, max_sentences, num_relex_spans_to_keep), # (batch_size, max_sentences, num_relex_spans_to_keep, 1) pruned = self._relex_mention_pruner( doc_span_embeddings, doc_span_mask, num_items_to_keep=num_relex_spans_to_keep, pass_through=['num_items_to_keep']) (top_relex_span_embeddings, top_relex_span_mask, top_relex_span_indices, top_relex_span_mention_scores) = pruned # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1) top_relex_span_mask = top_relex_span_mask.unsqueeze(-1) # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2) # TODO do we need for a mask? doc_spans = util.batched_index_select( spans, doc_span_offsets.clamp(0).squeeze(-1)) # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2) top_relex_spans = nd_batched_index_select(doc_spans, top_relex_span_indices) # Shapes: # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim), # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep). 
(relex_span_pair_embeddings, relex_span_pair_mask) = self._compute_relex_span_pair_embeddings( top_relex_span_embeddings, top_relex_span_mask.squeeze(-1)) # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels) relex_scores = self._compute_relex_scores( relex_span_pair_embeddings, top_relex_span_mention_scores) output_dict['relex_scores'] = relex_scores output_dict['top_relex_spans'] = top_relex_spans if span_labels is not None: # Find the gold labels for the spans which we kept. pruned_gold_labels = util.batched_index_select( span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices) antecedent_labels_ = util.flattened_index_select( pruned_gold_labels, valid_antecedent_indices).squeeze(-1) antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long( ) # Compute labels. # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1) gold_antecedent_labels = self._compute_antecedent_gold_labels( pruned_gold_labels, antecedent_labels) # Now, compute the loss using the negative marginal log-likelihood. # This is equal to the log of the sum of the probabilities of all antecedent predictions # that would be consistent with the data, in the sense that we are minimising, for a # given span, the negative marginal log likelihood of all antecedents which are in the # same gold cluster as the span we are currently considering. Each span i predicts a # single antecedent j, but there might be several prior mentions k in the same # coreference cluster that would be valid antecedents. Our loss is the sum of the # probability x to all valid antecedents. This is a valid objective for # clustering as we don't mind which antecedent is predicted, so long as they are in # the same coreference cluster. coreference_log_probs = util.masked_log_softmax( coreference_scores, top_span_mask) correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log( ) negative_marginal_log_likelihood = -util.logsumexp( correct_antecedent_log_probs) negative_marginal_log_likelihood *= top_span_mask.squeeze( -1).float() negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum( ) # TODO Modify metadata format # self._mention_recall(top_spans, metadata) # self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata) coref_loss = negative_marginal_log_likelihood output_dict['coref_loss'] = coref_loss loss += self._loss_coref_weight * coref_loss if doc_relation_labels is not None: # The adjacency matrix for relation extraction is very sparse. # As it is not just sparse, but row/column sparse (only few # rows and columns are non-zero and in that case these rows/columns # are not sparse), we implemented our own matrix for the case. # Here we have indices of truth spans and mapping, using which # we map prediction matrix on truth matrix. # TODO Add teacher forcing support. 
# Shape: (batch_size, max_sentences, num_relex_spans_to_keep), relative_indices = top_relex_span_indices # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1), compressed_indices = nd_batched_padded_index_select( doc_spans_in_truth, relative_indices) # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans) gold_pruned_rows = nd_batched_padded_index_select( doc_relation_labels, compressed_indices.squeeze(-1), padding_value=0) gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3, 2).contiguous() # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep) gold_pruned_matrices = nd_batched_padded_index_select( gold_pruned_rows, compressed_indices.squeeze(-1), padding_value=0) # pad with epsilon gold_pruned_matrices = gold_pruned_matrices.permute( 0, 1, 3, 2).contiguous() # TODO log_mask relex score before passing relex_loss = nd_cross_entropy_with_logits(relex_scores, gold_pruned_matrices, relex_span_pair_mask) output_dict['relex_loss'] = relex_loss self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2), truth_spans) # To calculate F1 score, we need to to call decode step output_dict = self.decode(output_dict) self._compute_relex_metrics(output_dict['raw_interactions'], metadata['doc_raw_relations']) loss += self._loss_relex_weight * relex_loss if doc_ner_labels is not None: # Shape: (batch_size, max_sentences, num_spans, num_ner_classes) ner_scores = self._ner_scorer(doc_span_embeddings) output_dict['ner_scores'] = ner_scores ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels, doc_span_mask) output_dict['ner_loss'] = ner_loss loss += self._loss_ner_weight * ner_loss if not isinstance(loss, int): # If loss is not yet modified output_dict["loss"] = loss return output_dict
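# --- A minimal sketch (invented shapes) of the masked variant of the coreference loss
# used in this model: the per-span negative marginal log-likelihood is multiplied by the
# top-span mask before summation, so padded spans contribute nothing to the loss.
import torch

batch_size, num_spans_to_keep, num_candidates = 2, 3, 4
log_probs = torch.log_softmax(torch.randn(batch_size, num_spans_to_keep, num_candidates), dim=-1)
gold_labels = (torch.rand(batch_size, num_spans_to_keep, num_candidates) > 0.5).float()
gold_labels[..., 0] = 1.0          # keep the "no antecedent" dummy always available
top_span_mask = torch.tensor([[1., 1., 0.],
                              [1., 0., 0.]])  # 0 marks padded spans

# Shape: (batch_size, num_spans_to_keep)
per_span_nll = -torch.logsumexp(log_probs + gold_labels.log(), dim=-1)
loss = (per_span_nll * top_span_mask).sum()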
def forward( self, # type: ignore question: Dict[str, torch.LongTensor], passage: Dict[str, torch.LongTensor], number_indices: torch.LongTensor, answer_as_passage_spans: torch.LongTensor = None, answer_as_question_spans: torch.LongTensor = None, answer_as_add_sub_expressions: torch.LongTensor = None, answer_as_counts: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ question_mask = util.get_text_field_mask(question).float() passage_mask = util.get_text_field_mask(passage).float() embedded_question = self._dropout(self._text_field_embedder(question)) embedded_passage = self._dropout(self._text_field_embedder(passage)) embedded_question = self._highway_layer( self._embedding_proj_layer(embedded_question)) embedded_passage = self._highway_layer( self._embedding_proj_layer(embedded_passage)) batch_size = embedded_question.size(0) projected_embedded_question = self._encoding_proj_layer( embedded_question) projected_embedded_passage = self._encoding_proj_layer( embedded_passage) encoded_question = self._dropout( self._phrase_layer(projected_embedded_question, question_mask)) encoded_passage = self._dropout( self._phrase_layer(projected_embedded_passage, passage_mask)) # Shape: (batch_size, passage_length, question_length) passage_question_similarity = self._matrix_attention( encoded_passage, encoded_question) # Shape: (batch_size, passage_length, question_length) passage_question_attention = masked_softmax( passage_question_similarity, question_mask, memory_efficient=True) # Shape: (batch_size, passage_length, encoding_dim) passage_question_vectors = util.weighted_sum( encoded_question, passage_question_attention) # Shape: (batch_size, question_length, passage_length) question_passage_attention = masked_softmax( passage_question_similarity.transpose(1, 2), passage_mask, memory_efficient=True) # Shape: (batch_size, passage_length, passage_length) passsage_attention_over_attention = torch.bmm( passage_question_attention, question_passage_attention) # Shape: (batch_size, passage_length, encoding_dim) passage_passage_vectors = util.weighted_sum( encoded_passage, passsage_attention_over_attention) # Shape: (batch_size, passage_length, encoding_dim * 4) merged_passage_attention_vectors = self._dropout( torch.cat([ encoded_passage, passage_question_vectors, encoded_passage * passage_question_vectors, encoded_passage * passage_passage_vectors ], dim=-1)) # The recurrent modeling layers. Since these layers share the same parameters, # we don't construct them conditioned on answering abilities. modeled_passage_list = [ self._modeling_proj_layer(merged_passage_attention_vectors) ] for _ in range(4): modeled_passage = self._dropout( self._modeling_layer(modeled_passage_list[-1], passage_mask)) modeled_passage_list.append(modeled_passage) # Pop the first one, which is input modeled_passage_list.pop(0) # The first modeling layer is used to calculate the vector representation of passage passage_weights = self._passage_weights_predictor( modeled_passage_list[0]).squeeze(-1) passage_weights = masked_softmax(passage_weights, passage_mask) passage_vector = util.weighted_sum(modeled_passage_list[0], passage_weights) # The vector representation of question is calculated based on the unmatched encoding, # because we may want to infer the answer ability only based on the question words. 
question_weights = self._question_weights_predictor( encoded_question).squeeze(-1) question_weights = masked_softmax(question_weights, question_mask) question_vector = util.weighted_sum(encoded_question, question_weights) if len(self.answering_abilities) > 1: # Shape: (batch_size, number_of_abilities) answer_ability_logits = \ self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1)) answer_ability_log_probs = torch.nn.functional.log_softmax( answer_ability_logits, -1) best_answer_ability = torch.argmax(answer_ability_log_probs, 1) if "counting" in self.answering_abilities: # Shape: (batch_size, 10) count_number_logits = self._count_number_predictor(passage_vector) count_number_log_probs = torch.nn.functional.log_softmax( count_number_logits, -1) # Info about the best count number prediction # Shape: (batch_size,) best_count_number = torch.argmax(count_number_log_probs, -1) best_count_log_prob = \ torch.gather(count_number_log_probs, 1, best_count_number.unsqueeze(-1)).squeeze(-1) if len(self.answering_abilities) > 1: best_count_log_prob += answer_ability_log_probs[:, self. _counting_index] if "passage_span_extraction" in self.answering_abilities: # Shape: (batch_size, passage_length, modeling_dim * 2)) passage_for_span_start = torch.cat( [modeled_passage_list[0], modeled_passage_list[1]], dim=-1) # Shape: (batch_size, passage_length) passage_span_start_logits = self._passage_span_start_predictor( passage_for_span_start).squeeze(-1) # Shape: (batch_size, passage_length, modeling_dim * 2) passage_for_span_end = torch.cat( [modeled_passage_list[0], modeled_passage_list[2]], dim=-1) # Shape: (batch_size, passage_length) passage_span_end_logits = self._passage_span_end_predictor( passage_for_span_end).squeeze(-1) # Shape: (batch_size, passage_length) passage_span_start_log_probs = util.masked_log_softmax( passage_span_start_logits, passage_mask) passage_span_end_log_probs = util.masked_log_softmax( passage_span_end_logits, passage_mask) # Info about the best passage span prediction passage_span_start_logits = util.replace_masked_values( passage_span_start_logits, passage_mask, -1e7) passage_span_end_logits = util.replace_masked_values( passage_span_end_logits, passage_mask, -1e7) # Shape: (batch_size, 2) best_passage_span = get_best_span(passage_span_start_logits, passage_span_end_logits) # Shape: (batch_size, 2) best_passage_start_log_probs = \ torch.gather(passage_span_start_log_probs, 1, best_passage_span[:, 0].unsqueeze(-1)).squeeze(-1) best_passage_end_log_probs = \ torch.gather(passage_span_end_log_probs, 1, best_passage_span[:, 1].unsqueeze(-1)).squeeze(-1) # Shape: (batch_size,) best_passage_span_log_prob = best_passage_start_log_probs + best_passage_end_log_probs if len(self.answering_abilities) > 1: best_passage_span_log_prob += answer_ability_log_probs[:, self. 
_passage_span_extraction_index] if "question_span_extraction" in self.answering_abilities: # Shape: (batch_size, question_length) encoded_question_for_span_prediction = \ torch.cat([encoded_question, passage_vector.unsqueeze(1).repeat(1, encoded_question.size(1), 1)], -1) question_span_start_logits = \ self._question_span_start_predictor(encoded_question_for_span_prediction).squeeze(-1) # Shape: (batch_size, question_length) question_span_end_logits = \ self._question_span_end_predictor(encoded_question_for_span_prediction).squeeze(-1) question_span_start_log_probs = util.masked_log_softmax( question_span_start_logits, question_mask) question_span_end_log_probs = util.masked_log_softmax( question_span_end_logits, question_mask) # Info about the best question span prediction question_span_start_logits = \ util.replace_masked_values(question_span_start_logits, question_mask, -1e7) question_span_end_logits = \ util.replace_masked_values(question_span_end_logits, question_mask, -1e7) # Shape: (batch_size, 2) best_question_span = get_best_span(question_span_start_logits, question_span_end_logits) # Shape: (batch_size, 2) best_question_start_log_probs = \ torch.gather(question_span_start_log_probs, 1, best_question_span[:, 0].unsqueeze(-1)).squeeze(-1) best_question_end_log_probs = \ torch.gather(question_span_end_log_probs, 1, best_question_span[:, 1].unsqueeze(-1)).squeeze(-1) # Shape: (batch_size,) best_question_span_log_prob = best_question_start_log_probs + best_question_end_log_probs if len(self.answering_abilities) > 1: best_question_span_log_prob += answer_ability_log_probs[:, self. _question_span_extraction_index] if "addition_subtraction" in self.answering_abilities: # Shape: (batch_size, # of numbers in the passage) number_indices = number_indices.squeeze(-1) number_mask = (number_indices != -1).long() clamped_number_indices = util.replace_masked_values( number_indices, number_mask, 0) encoded_passage_for_numbers = torch.cat( [modeled_passage_list[0], modeled_passage_list[3]], dim=-1) # Shape: (batch_size, # of numbers in the passage, encoding_dim) encoded_numbers = torch.gather( encoded_passage_for_numbers, 1, clamped_number_indices.unsqueeze(-1).expand( -1, -1, encoded_passage_for_numbers.size(-1))) # Shape: (batch_size, # of numbers in the passage) encoded_numbers = torch.cat([ encoded_numbers, passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1) ], -1) # Shape: (batch_size, # of numbers in the passage, 3) number_sign_logits = self._number_sign_predictor(encoded_numbers) number_sign_log_probs = torch.nn.functional.log_softmax( number_sign_logits, -1) # Shape: (batch_size, # of numbers in passage). best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1) # For padding numbers, the best sign masked as 0 (not included). best_signs_for_numbers = util.replace_masked_values( best_signs_for_numbers, number_mask, 0) # Shape: (batch_size, # of numbers in passage) best_signs_log_probs = torch.gather( number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)).squeeze(-1) # the probs of the masked positions should be 1 so that it will not affect the joint probability # TODO: this is not quite right, since if there are many numbers in the passage, # TODO: the joint probability would be very small. best_signs_log_probs = util.replace_masked_values( best_signs_log_probs, number_mask, 0) # Shape: (batch_size,) best_combination_log_prob = best_signs_log_probs.sum(-1) if len(self.answering_abilities) > 1: best_combination_log_prob += answer_ability_log_probs[:, self. 
_addition_subtraction_index] output_dict = {} # If answer is given, compute the loss. if answer_as_passage_spans is not None or answer_as_question_spans is not None \ or answer_as_add_sub_expressions is not None or answer_as_counts is not None: log_marginal_likelihood_list = [] for answering_ability in self.answering_abilities: if answering_ability == "passage_span_extraction": # Shape: (batch_size, # of answer spans) gold_passage_span_starts = answer_as_passage_spans[:, :, 0] gold_passage_span_ends = answer_as_passage_spans[:, :, 1] # Some spans are padded with index -1, # so we clamp those paddings to 0 and then mask after `torch.gather()`. gold_passage_span_mask = (gold_passage_span_starts != -1).long() clamped_gold_passage_span_starts = \ util.replace_masked_values(gold_passage_span_starts, gold_passage_span_mask, 0) clamped_gold_passage_span_ends = \ util.replace_masked_values(gold_passage_span_ends, gold_passage_span_mask, 0) # Shape: (batch_size, # of answer spans) log_likelihood_for_passage_span_starts = \ torch.gather(passage_span_start_log_probs, 1, clamped_gold_passage_span_starts) log_likelihood_for_passage_span_ends = \ torch.gather(passage_span_end_log_probs, 1, clamped_gold_passage_span_ends) # Shape: (batch_size, # of answer spans) log_likelihood_for_passage_spans = \ log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends # For those padded spans, we set their log probabilities to be very small negative value log_likelihood_for_passage_spans = \ util.replace_masked_values(log_likelihood_for_passage_spans, gold_passage_span_mask, -1e7) # Shape: (batch_size, ) log_marginal_likelihood_for_passage_span = util.logsumexp( log_likelihood_for_passage_spans) log_marginal_likelihood_list.append( log_marginal_likelihood_for_passage_span) elif answering_ability == "question_span_extraction": # Shape: (batch_size, # of answer spans) gold_question_span_starts = answer_as_question_spans[:, :, 0] gold_question_span_ends = answer_as_question_spans[:, :, 1] # Some spans are padded with index -1, # so we clamp those paddings to 0 and then mask after `torch.gather()`. gold_question_span_mask = (gold_question_span_starts != -1).long() clamped_gold_question_span_starts = \ util.replace_masked_values(gold_question_span_starts, gold_question_span_mask, 0) clamped_gold_question_span_ends = \ util.replace_masked_values(gold_question_span_ends, gold_question_span_mask, 0) # Shape: (batch_size, # of answer spans) log_likelihood_for_question_span_starts = \ torch.gather(question_span_start_log_probs, 1, clamped_gold_question_span_starts) log_likelihood_for_question_span_ends = \ torch.gather(question_span_end_log_probs, 1, clamped_gold_question_span_ends) # Shape: (batch_size, # of answer spans) log_likelihood_for_question_spans = \ log_likelihood_for_question_span_starts + log_likelihood_for_question_span_ends # For those padded spans, we set their log probabilities to be very small negative value log_likelihood_for_question_spans = \ util.replace_masked_values(log_likelihood_for_question_spans, gold_question_span_mask, -1e7) # Shape: (batch_size, ) # pylint: disable=invalid-name log_marginal_likelihood_for_question_span = \ util.logsumexp(log_likelihood_for_question_spans) log_marginal_likelihood_list.append( log_marginal_likelihood_for_question_span) elif answering_ability == "addition_subtraction": # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here. 
# Shape: (batch_size, # of combinations) gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1) > 0).float() # Shape: (batch_size, # of numbers in the passage, # of combinations) gold_add_sub_signs = answer_as_add_sub_expressions.transpose( 1, 2) # Shape: (batch_size, # of numbers in the passage, # of combinations) log_likelihood_for_number_signs = torch.gather( number_sign_log_probs, 2, gold_add_sub_signs) # the log likelihood of the masked positions should be 0 # so that it will not affect the joint probability log_likelihood_for_number_signs = \ util.replace_masked_values(log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0) # Shape: (batch_size, # of combinations) log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum( 1) # For those padded combinations, we set their log probabilities to be very small negative value log_likelihood_for_add_subs = \ util.replace_masked_values(log_likelihood_for_add_subs, gold_add_sub_mask, -1e7) # Shape: (batch_size, ) log_marginal_likelihood_for_add_sub = util.logsumexp( log_likelihood_for_add_subs) log_marginal_likelihood_list.append( log_marginal_likelihood_for_add_sub) elif answering_ability == "counting": # Count answers are padded with label -1, # so we clamp those paddings to 0 and then mask after `torch.gather()`. # Shape: (batch_size, # of count answers) gold_count_mask = (answer_as_counts != -1).long() # Shape: (batch_size, # of count answers) clamped_gold_counts = util.replace_masked_values( answer_as_counts, gold_count_mask, 0) log_likelihood_for_counts = torch.gather( count_number_log_probs, 1, clamped_gold_counts) # For those padded spans, we set their log probabilities to be very small negative value log_likelihood_for_counts = \ util.replace_masked_values(log_likelihood_for_counts, gold_count_mask, -1e7) # Shape: (batch_size, ) log_marginal_likelihood_for_count = util.logsumexp( log_likelihood_for_counts) log_marginal_likelihood_list.append( log_marginal_likelihood_for_count) else: raise ValueError( f"Unsupported answering ability: {answering_ability}") if len(self.answering_abilities) > 1: # Add the ability probabilities if there are more than one abilities all_log_marginal_likelihoods = torch.stack( log_marginal_likelihood_list, dim=-1) all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs marginal_log_likelihood = util.logsumexp( all_log_marginal_likelihoods) else: marginal_log_likelihood = log_marginal_likelihood_list[0] output_dict["loss"] = -marginal_log_likelihood.mean() # Compute the metrics and add the tokenized input to the output. 
if metadata is not None: output_dict["question_id"] = [] output_dict["answer"] = [] question_tokens = [] passage_tokens = [] for i in range(batch_size): question_tokens.append(metadata[i]['question_tokens']) passage_tokens.append(metadata[i]['passage_tokens']) if len(self.answering_abilities) > 1: predicted_ability_str = self.answering_abilities[ best_answer_ability[i].detach().cpu().numpy()] else: predicted_ability_str = self.answering_abilities[0] answer_json: Dict[str, Any] = {} # We did not consider multi-mention answers here if predicted_ability_str == "passage_span_extraction": answer_json["answer_type"] = "passage_span" passage_str = metadata[i]['original_passage'] offsets = metadata[i]['passage_token_offsets'] predicted_span = tuple( best_passage_span[i].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] predicted_answer = passage_str[start_offset:end_offset] answer_json["value"] = predicted_answer answer_json["spans"] = [(start_offset, end_offset)] elif predicted_ability_str == "question_span_extraction": answer_json["answer_type"] = "question_span" question_str = metadata[i]['original_question'] offsets = metadata[i]['question_token_offsets'] predicted_span = tuple( best_question_span[i].detach().cpu().numpy()) start_offset = offsets[predicted_span[0]][0] end_offset = offsets[predicted_span[1]][1] predicted_answer = question_str[start_offset:end_offset] answer_json["value"] = predicted_answer answer_json["spans"] = [(start_offset, end_offset)] elif predicted_ability_str == "addition_subtraction": # plus_minus combination answer answer_json["answer_type"] = "arithmetic" original_numbers = metadata[i]['original_numbers'] sign_remap = {0: 0, 1: 1, 2: -1} predicted_signs = [ sign_remap[it] for it in best_signs_for_numbers[i].detach().cpu().numpy() ] result = sum([ sign * number for sign, number in zip( predicted_signs, original_numbers) ]) predicted_answer = str(result) offsets = metadata[i]['passage_token_offsets'] number_indices = metadata[i]['number_indices'] number_positions = [ offsets[index] for index in number_indices ] answer_json['numbers'] = [] for offset, value, sign in zip(number_positions, original_numbers, predicted_signs): answer_json['numbers'].append({ 'span': offset, 'value': value, 'sign': sign }) if number_indices[-1] == -1: # There is a dummy 0 number at position -1 added in some cases; we are # removing that here. answer_json["numbers"].pop() answer_json["value"] = result elif predicted_ability_str == "counting": answer_json["answer_type"] = "count" predicted_count = best_count_number[i].detach().cpu( ).numpy() predicted_answer = str(predicted_count) answer_json["count"] = predicted_count else: raise ValueError( f"Unsupported answer ability: {predicted_ability_str}") output_dict["question_id"].append(metadata[i]["question_id"]) output_dict["answer"].append(answer_json) answer_annotations = metadata[i].get('answer_annotations', []) if answer_annotations: self._drop_metrics(predicted_answer, answer_annotations) # This is used for the demo. output_dict[ "passage_question_attention"] = passage_question_attention output_dict["question_tokens"] = question_tokens output_dict["passage_tokens"] = passage_tokens return output_dict
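# --- A minimal sketch (invented numbers) of the arithmetic decoding step above: each
# passage number gets a predicted class in {0, 1, 2}, which is remapped to a sign in
# {0, +1, -1}, and the predicted answer is the signed sum of the original numbers.
original_numbers = [250.0, 40.0, 3.0]
best_signs_for_numbers = [1, 2, 0]        # predicted sign classes, one per number
sign_remap = {0: 0, 1: 1, 2: -1}

predicted_signs = [sign_remap[s] for s in best_signs_for_numbers]
result = sum(sign * number for sign, number in zip(predicted_signs, original_numbers))
# result == 250.0 - 40.0 == 210.0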
def forward(self, # type: ignore question_passage: Dict[str, torch.LongTensor], number_indices: torch.LongTensor, mask_indices: torch.LongTensor, num_spans: torch.LongTensor = None, answer_as_passage_spans: torch.LongTensor = None, answer_as_question_spans: torch.LongTensor = None, answer_as_expressions: torch.LongTensor = None, answer_as_expressions_extra: torch.LongTensor = None, answer_as_counts: torch.LongTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ # Shape: (batch_size, seqlen) question_passage_tokens = question_passage["tokens"] # Shape: (batch_size, seqlen) pad_mask = question_passage["mask"] # Shape: (batch_size, seqlen) seqlen_ids = question_passage["tokens-type-ids"] max_seqlen = question_passage_tokens.shape[-1] batch_size = question_passage_tokens.shape[0] # Shape: (batch_size, 3) mask = mask_indices.squeeze(-1) # Shape: (batch_size, seqlen) cls_sep_mask = \ torch.ones(pad_mask.shape, device=pad_mask.device).long().scatter(1, mask, torch.zeros(mask.shape, device=mask.device).long()) # Shape: (batch_size, seqlen) passage_mask = seqlen_ids * pad_mask * cls_sep_mask # Shape: (batch_size, seqlen) question_mask = (1 - seqlen_ids) * pad_mask * cls_sep_mask # Shape: (batch_size, seqlen, bert_dim) bert_out, _ = self.BERT(question_passage_tokens, seqlen_ids, pad_mask, output_all_encoded_layers=False) # Shape: (batch_size, qlen, bert_dim) question_end = max(mask[:,1]) question_out = bert_out[:,:question_end] # Shape: (batch_size, qlen) question_mask = question_mask[:,:question_end] # Shape: (batch_size, out) question_vector = self.summary_vector(question_out, question_mask, "question") passage_out = bert_out del bert_out # Shape: (batch_size, bert_dim) passage_vector = self.summary_vector(passage_out, passage_mask) if len(self.answering_abilities) > 1: # Shape: (batch_size, number_of_abilities) answer_ability_logits = \ self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1)) answer_ability_log_probs = torch.nn.functional.log_softmax(answer_ability_logits, -1) best_answer_ability = torch.argmax(answer_ability_log_probs, 1) if "counting" in self.answering_abilities: count_passage_vector = self.summary_vector(passage_out, passage_mask, "count_passage") count_number_log_probs, best_count_number = self._count_module(count_passage_vector) if "passage_span_extraction" in self.answering_abilities: passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span = \ self._passage_span_module(passage_out, passage_mask) if "question_span_extraction" in self.answering_abilities: qspan_passage_vector = self.summary_vector(passage_out, passage_mask, "qspan_passage") question_span_start_log_probs, question_span_end_log_probs, best_question_span = \ self._question_span_module(qspan_passage_vector, question_out, question_mask) if "arithmetic" in self.answering_abilities: arithmetic_passage_vector = self.summary_vector(passage_out, passage_mask, "arithmetic_passage") arithmetic_question_vector = self.summary_vector(question_out, question_mask, "arithmetic_question") arithmetic_template_logits = \ self._arithmetic_template_predictor(torch.cat([arithmetic_passage_vector, arithmetic_question_vector], -1)) arithmetic_template_log_probs = arithmetic_template_logits.log_softmax(-1) arithmetic_best_templates = arithmetic_template_log_probs.argmax(-1) number_mask = (number_indices[:,:,0].long() != -1).long() arithmetic_template_slot_log_probs, arithmetic_best_template_slots, number_mask = \ 
self._arithmetic_module(arithmetic_passage_vector, passage_out, number_indices, number_mask) output_dict = {} del passage_out, question_out # If answer is given, compute the loss. if answer_as_passage_spans is not None or answer_as_question_spans is not None \ or answer_as_expressions is not None or answer_as_counts is not None: log_marginal_likelihood_list = [] for answering_ability in self.answering_abilities: if answering_ability == "passage_span_extraction": log_marginal_likelihood_for_passage_span = \ self._passage_span_log_likelihood(answer_as_passage_spans, passage_span_start_log_probs, passage_span_end_log_probs) log_marginal_likelihood_list.append(log_marginal_likelihood_for_passage_span) elif answering_ability == "question_span_extraction": log_marginal_likelihood_for_question_span = \ self._question_span_log_likelihood(answer_as_question_spans, question_span_start_log_probs, question_span_end_log_probs) log_marginal_likelihood_list.append(log_marginal_likelihood_for_question_span) elif answering_ability == "arithmetic": log_marginal_likelihood_for_arithmetic = \ self._arithmetic_log_likelihood(answer_as_expressions, arithmetic_template_slot_log_probs, arithmetic_template_log_probs) log_marginal_likelihood_list.append(log_marginal_likelihood_for_arithmetic) elif answering_ability == "counting": log_marginal_likelihood_for_count = \ self._count_log_likelihood(answer_as_counts, count_number_log_probs) log_marginal_likelihood_list.append(log_marginal_likelihood_for_count) else: raise ValueError(f"Unsupported answering ability: {answering_ability}") if len(self.answering_abilities) > 1: # Add the ability probabilities if there are more than one abilities all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1) all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs marginal_log_likelihood = util.logsumexp(all_log_marginal_likelihoods) else: marginal_log_likelihood = log_marginal_likelihood_list[0] output_dict["loss"] = - marginal_log_likelihood.mean() with torch.no_grad(): # Compute the metrics and add the tokenized input to the output. 
if metadata is not None: output_dict["question_id"] = [] output_dict["answer"] = [] question_tokens = [] passage_tokens = [] for i in range(batch_size): if len(self.answering_abilities) > 1: predicted_ability_str = self.answering_abilities[best_answer_ability[i]] else: predicted_ability_str = self.answering_abilities[0] answer_json: Dict[str, Any] = {} # We did not consider multi-mention answers here if predicted_ability_str == "passage_span_extraction": answer_json["answer_type"] = "passage_span" answer_json["value"], answer_json["spans"] = \ self._span_prediction(question_passage_tokens[i], best_passage_span[i]) elif predicted_ability_str == "question_span_extraction": answer_json["answer_type"] = "question_span" answer_json["value"], answer_json["spans"] = \ self._span_prediction(question_passage_tokens[i], best_question_span[i]) elif predicted_ability_str == "arithmetic": answer_json["answer_type"] = "arithmetic" original_numbers = metadata[i]['original_numbers'] answer_json["value"], answer_json["indices"], answer_json["numbers"] = \ self._arithmetic_prediction(original_numbers, arithmetic_best_templates[i], arithmetic_best_template_slots[i]) answer_json['template'] = arithmetic_best_templates[i].item() elif predicted_ability_str == "counting": answer_json["answer_type"] = "count" answer_json["value"], answer_json["count"] = \ self._count_prediction(best_count_number[i]) else: raise ValueError(f"Unsupported answer ability: {predicted_ability_str}") output_dict["question_id"].append(metadata[i]["question_id"]) output_dict["answer"].append(answer_json) answer_annotations = metadata[i].get('answer_annotations', []) if answer_annotations: self._drop_metrics(answer_json["value"], answer_annotations) return output_dict
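# Minimal numeric sketch of the marginalization used in the forward pass above when
# several answering abilities are active: the per-ability log marginal likelihoods are
# shifted by the ability log probabilities and combined with logsumexp, i.e.
#     log p(answer) = log sum_a p(a) * p(answer | a).
# torch.logsumexp stands in for util.logsumexp here, and all values are made up.
import torch

ability_log_probs = torch.log_softmax(torch.tensor([[2.0, 0.5, -1.0]]), dim=-1)
per_ability_log_likelihoods = torch.tensor([[-0.7, -2.3, -4.0]])

marginal_log_likelihood = torch.logsumexp(
    per_ability_log_likelihoods + ability_log_probs, dim=-1)
loss = -marginal_log_likelihood.mean()

# Sanity check against the direct probability-space computation.
expected = (ability_log_probs.exp() * per_ability_log_likelihoods.exp()).sum(-1).log()
assert torch.allclose(marginal_log_likelihood, expected)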
    def _get_ll_contrib(self,
                        generation_scores: torch.Tensor,
                        generation_scores_mask: torch.Tensor,
                        copy_scores: torch.Tensor,
                        target_tokens: torch.Tensor,
                        target_to_source: torch.Tensor,
                        copy_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the log-likelihood contribution from a single timestep.

        Parameters
        ----------
        generation_scores : ``torch.Tensor``
            Shape: `(batch_size, target_vocab_size)`
        generation_scores_mask : ``torch.Tensor``
            Shape: `(batch_size, target_vocab_size)`. This is just a tensor of 1's.
        copy_scores : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        target_tokens : ``torch.Tensor``
            Shape: `(batch_size,)`
        target_to_source : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        copy_mask : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Shape: `(batch_size,), (batch_size, max_input_sequence_length)`
        """
        _, target_size = generation_scores.size()

        # The point of this mask is to just mask out all source token scores
        # that just represent padding. We apply the mask to the concatenation
        # of the generation scores and the copy scores to normalize the scores
        # correctly during the softmax.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
        # Normalize generation and copy scores.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        log_probs = util.masked_log_softmax(all_scores, mask)
        # Calculate the log probability (`copy_log_probs`) for each token in the source sentence
        # that matches the current target token. We use the sum of these copy probabilities
        # for matching tokens in the source sentence to get the total probability
        # for the target token. We also need to normalize the individual copy probabilities
        # to create `selective_weights`, which are used in the next timestep to create
        # a selective read state.
        # shape: (batch_size, trimmed_source_length)
        copy_log_probs = log_probs[:, target_size:] + (target_to_source.float() + 1e-45).log()
        # Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
        # we use a non-log softmax to get the normalized non-log copy probabilities.
        selective_weights = util.masked_softmax(log_probs[:, target_size:], target_to_source)
        # This mask ensures that each item in the batch has a non-zero generation probability
        # for this timestep only when the gold target token is not OOV or there are no
        # matching tokens in the source sentence.
        # shape: (batch_size, 1)
        gen_mask = ((target_tokens != self._oov_index) | (target_to_source.sum(-1) == 0)).float()
        log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
        # Now we get the generation score for the gold target token.
        # shape: (batch_size, 1)
        generation_log_probs = log_probs.gather(1, target_tokens.unsqueeze(1)) + log_gen_mask
        # ... and add the copy score to get the step log likelihood.
        # shape: (batch_size, 1 + trimmed_source_length)
        combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
        # shape: (batch_size,)
        step_log_likelihood = util.logsumexp(combined_gen_and_copy)

        return step_log_likelihood, selective_weights
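# Toy sketch of the step likelihood computed in ``_get_ll_contrib``: the probability of
# producing the gold token is the sum of generating it from the target vocabulary and of
# copying it from every matching source position, which in log space is a logsumexp over
# the concatenated log probabilities. All tensors below are invented, and plain
# torch.log_softmax / torch.logsumexp stand in for the masked allennlp util helpers.
import torch

# One instance, 4-word target vocab, 3 source tokens; scores are arbitrary.
generation_scores = torch.tensor([[1.0, 0.2, -0.5, 0.3]])
copy_scores = torch.tensor([[0.7, -1.0, 0.1]])
gold_token = torch.tensor([2])                      # index into the target vocab
target_to_source = torch.tensor([[1.0, 0.0, 1.0]])  # source positions equal to the gold token

log_probs = torch.log_softmax(torch.cat([generation_scores, copy_scores], dim=-1), dim=-1)
generation_log_prob = log_probs.gather(1, gold_token.unsqueeze(1))
copy_log_probs = log_probs[:, 4:] + (target_to_source + 1e-45).log()

step_log_likelihood = torch.logsumexp(
    torch.cat([generation_log_prob, copy_log_probs], dim=-1), dim=-1)
print(step_log_likelihood)  # log p(gold) from both generation and copying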
    def forward(self,  # type: ignore
                passage_question: Dict[str, torch.LongTensor],
                number_indices: torch.LongTensor,
                answer_type=None,
                answer_as_spans: torch.LongTensor = None,
                answer_as_add_sub_expressions: torch.LongTensor = None,
                answer_as_counts: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        passage_question_mask = passage_question["mask"].float()
        embedded_passage_question = self._dropout(
            self._text_field_embedder(passage_question))
        # Encode with BERT.
        batch_size = embedded_passage_question.size(0)
        encoded_passage_question = embedded_passage_question
        # Use the encoding of the [CLS] token as the passage-question vector.
        passage_question_vector = encoded_passage_question[:, 0]

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._answer_ability_predictor(passage_question_vector)
            answer_ability_log_probs = torch.nn.functional.log_softmax(
                answer_ability_logits, -1)
            # best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            # Shape: (batch_size, 10)
            count_number_logits = self._count_number_predictor(
                passage_question_vector)
            count_number_log_probs = torch.nn.functional.log_softmax(
                count_number_logits, -1)
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1)
            best_count_log_prob = \
                torch.gather(count_number_log_probs, 1,
                             best_count_number.unsqueeze(-1)).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[:, self._counting_index]

        if "span_extraction" in self.answering_abilities:
            # Shape: (batch_size, passage_length)
            span_start_logits = self._span_start_predictor(
                encoded_passage_question).squeeze(-1)
            # Shape: (batch_size, passage_length)
            span_end_logits = self._span_end_predictor(
                encoded_passage_question).squeeze(-1)
            # Shape: (batch_size, passage_length)
            span_start_log_probs = util.masked_log_softmax(
                span_start_logits, passage_question_mask)
            span_end_log_probs = util.masked_log_softmax(
                span_end_logits, passage_question_mask)
            # Info about the best passage span prediction
            span_start_logits = util.replace_masked_values(
                span_start_logits, passage_question_mask, -1e7)
            span_end_logits = util.replace_masked_values(
                span_end_logits, passage_question_mask, -1e7)
            # Shape: (batch_size, 2)
            best_span = get_best_span(span_start_logits, span_end_logits)
            # Shape: (batch_size,)
            best_start_log_probs = \
                torch.gather(span_start_log_probs, 1,
                             best_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_end_log_probs = \
                torch.gather(span_end_log_probs, 1,
                             best_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_span_log_prob = best_start_log_probs + best_end_log_probs
            if len(self.answering_abilities) > 1:
                best_span_log_prob += answer_ability_log_probs[:, self.
_span_extraction_index] if "addition_subtraction" in self.answering_abilities: # Shape: (batch_size, # of numbers in the passage) number_indices = number_indices.squeeze(-1) number_mask = (number_indices != -1).long() clamped_number_indices = util.replace_masked_values( number_indices, number_mask, 0) #encoded_passage_for_numbers = torch.cat([modeled_passage_list[0], modeled_passage_list[3]], dim=-1) # Shape: (batch_size, # of numbers in the passage, encoding_dim) encoded_numbers = torch.gather( encoded_passage_question, 1, clamped_number_indices.unsqueeze(-1).expand( -1, -1, encoded_passage_question.size(-1))) #self._external_number_embedding = self._external_number_embedding.cuda(device) #encoded_numbers = self.self_attention(encoded_numbers,number_mask) encoded_numbers = self.Concat_attention(encoded_numbers, passage_question_vector, number_mask) # Shape: (batch_size, # of numbers in the passage) #encoded_numbers = torch.cat( # [encoded_numbers, passage_question_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1) # Shape: (batch_size, # of numbers in the passage, 3) number_sign_logits = self._number_sign_predictor(encoded_numbers) number_sign_log_probs = torch.nn.functional.log_softmax( number_sign_logits, -1) # Shape: (batch_size, # of numbers in passage). best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1) # For padding numbers, the best sign masked as 0 (not included). best_signs_for_numbers = util.replace_masked_values( best_signs_for_numbers, number_mask, 0) # Shape: (batch_size, # of numbers in passage) best_signs_log_probs = torch.gather( number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)).squeeze(-1) # the probs of the masked positions should be 1 so that it will not affect the joint probability # TODO: this is not quite right, since if there are many numbers in the passage, # TODO: the joint probability would be very small. best_signs_log_probs = util.replace_masked_values( best_signs_log_probs, number_mask, 0) # Shape: (batch_size,) if len(self.answering_abilities) > 1: # batch_size best_combination_log_prob = best_signs_log_probs.sum(-1) best_combination_log_prob += answer_ability_log_probs[:, self. _addition_subtraction_index] best_answer_ability = torch.argmax( torch.stack([ best_span_log_prob, best_combination_log_prob, best_count_log_prob ], -1), 1) output_dict = {} # If answer is given, compute the loss. if answer_as_spans is not None or answer_as_add_sub_expressions is not None or answer_as_counts is not None: log_marginal_likelihood_list = [] for answering_ability in self.answering_abilities: if answering_ability == "span_extraction": # Shape: (batch_size, # of answer spans) gold_span_starts = answer_as_spans[:, :, 0] gold_span_ends = answer_as_spans[:, :, 1] # Some spans are padded with index -1, # so we clamp those paddings to 0 and then mask after `torch.gather()`. 
gold_span_mask = (gold_span_starts != -1).long() clamped_gold_span_starts = \ util.replace_masked_values(gold_span_starts, gold_span_mask, 0) clamped_gold_span_ends = \ util.replace_masked_values(gold_span_ends, gold_span_mask, 0) # Shape: (batch_size, # of answer spans) log_likelihood_for_span_starts = \ torch.gather(span_start_log_probs, 1, clamped_gold_span_starts) log_likelihood_for_span_ends = \ torch.gather(span_end_log_probs, 1, clamped_gold_span_ends) # Shape: (batch_size, # of answer spans) log_likelihood_for_spans = \ log_likelihood_for_span_starts + log_likelihood_for_span_ends # For those padded spans, we set their log probabilities to be very small negative value log_likelihood_for_spans = \ util.replace_masked_values(log_likelihood_for_spans, gold_span_mask, -1e7) # Shape: (batch_size, ) # log_marginal_likelihood_for_span = torch.sum(log_likelihood_for_spans,-1) log_marginal_likelihood_for_span = util.logsumexp( log_likelihood_for_spans) log_marginal_likelihood_list.append( log_marginal_likelihood_for_span) elif answering_ability == "addition_subtraction": # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here. # Shape: (batch_size, # of combinations) gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1) > 0).float() # Shape: (batch_size, # of numbers in the passage, # of combinations) gold_add_sub_signs = answer_as_add_sub_expressions.transpose( 1, 2) # Shape: (batch_size, # of numbers in the passage, # of combinations) log_likelihood_for_number_signs = torch.gather( number_sign_log_probs, 2, gold_add_sub_signs) # the log likelihood of the masked positions should be 0 # so that it will not affect the joint probability log_likelihood_for_number_signs = \ util.replace_masked_values(log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0) # Shape: (batch_size, # of combinations) log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum( 1) # For those padded combinations, we set their log probabilities to be very small negative value log_likelihood_for_add_subs = \ util.replace_masked_values(log_likelihood_for_add_subs, gold_add_sub_mask, -1e7) # Shape: (batch_size, ) #log_marginal_likelihood_for_add_sub = torch.sum(log_likelihood_for_add_subs,-1) #log_marginal_likelihood_for_add_sub = util.logsumexp(log_likelihood_for_add_subs) #log_marginal_likelihood_list.append(log_marginal_likelihood_for_add_sub) log_marginal_likelihood_for_add_sub = util.logsumexp( log_likelihood_for_add_subs) #log_marginal_likelihood_for_external = util.logsumexp(log_likelihood_for_externals) log_marginal_likelihood_list.append( log_marginal_likelihood_for_add_sub) elif answering_ability == "counting": # Count answers are padded with label -1, # so we clamp those paddings to 0 and then mask after `torch.gather()`. 
# Shape: (batch_size, # of count answers) gold_count_mask = (answer_as_counts != -1).long() # Shape: (batch_size, # of count answers) clamped_gold_counts = util.replace_masked_values( answer_as_counts, gold_count_mask, 0) log_likelihood_for_counts = torch.gather( count_number_log_probs, 1, clamped_gold_counts) # For those padded spans, we set their log probabilities to be very small negative value log_likelihood_for_counts = \ util.replace_masked_values(log_likelihood_for_counts, gold_count_mask, -1e7) # Shape: (batch_size, ) #log_marginal_likelihood_for_count = torch.sum(log_likelihood_for_counts,-1) log_marginal_likelihood_for_count = util.logsumexp( log_likelihood_for_counts) log_marginal_likelihood_list.append( log_marginal_likelihood_for_count) else: raise ValueError( f"Unsupported answering ability: {answering_ability}") if len(self.answering_abilities) > 1: # Add the ability probabilities if there are more than one abilities all_log_marginal_likelihoods = torch.stack( log_marginal_likelihood_list, dim=-1) loss_for_type = -(torch.sum( answer_ability_log_probs * answer_type, -1)).mean() loss_for_answer = -(torch.sum(all_log_marginal_likelihoods, -1)).mean() loss = loss_for_type + loss_for_answer else: marginal_log_likelihood = log_marginal_likelihood_list[0] loss = -marginal_log_likelihood.mean() output_dict["loss"] = loss # Compute the metrics and add the tokenized input to the output. if metadata is not None: output_dict["question_id"] = [] output_dict["answer"] = [] passage_question_tokens = [] for i in range(batch_size): passage_question_tokens.append( metadata[i]['passage_question_tokens']) if len(self.answering_abilities) > 1: predicted_ability_str = self.answering_abilities[ best_answer_ability[i].detach().cpu().numpy()] else: predicted_ability_str = self.answering_abilities[0] answer_json: Dict[str, Any] = {} # We did not consider multi-mention answers here if predicted_ability_str == "span_extraction": answer_json["answer_type"] = "span" passage_question_token = metadata[i][ 'passage_question_tokens'] #offsets = metadata[i]['passage_token_offsets'] predicted_span = tuple(best_span[i].detach().cpu().numpy()) start_offset = predicted_span[0] end_offset = predicted_span[1] predicted_answer = " ".join([ token for token in passage_question_token[start_offset:end_offset + 1] if token != "[SEP]" ]).replace(" ##", "") answer_json["value"] = predicted_answer answer_json["spans"] = [(start_offset, end_offset)] elif predicted_ability_str == "counting": answer_json["answer_type"] = "count" predicted_count = best_count_number[i].detach().cpu( ).numpy() predicted_answer = str(predicted_count) answer_json["count"] = predicted_count elif predicted_ability_str == "addition_subtraction": answer_json["answer_type"] = "arithmetic" original_numbers = metadata[i]['original_numbers'] sign_remap = {0: 0, 1: 1, 2: -1} predicted_signs = [ sign_remap[it] for it in best_signs_for_numbers[i].detach().cpu().numpy() ] result = 0 for j, number in enumerate(original_numbers): sign = predicted_signs[j] if sign != 0: result += sign * number predicted_answer = str(result) #offsets = metadata[i]['passage_token_offsets'] number_indices = metadata[i]['number_indices'] #number_positions = [offsets[index] for index in number_indices] answer_json['numbers'] = [] for indice, value, sign in zip(number_indices, original_numbers, predicted_signs): answer_json['numbers'].append({ 'span': indice, 'value': str(value), 'sign': sign }) if number_indices[-1] == -1: # There is a dummy 0 number at position -1 added in some 
                        # cases; we are removing that here.
                        answer_json["numbers"].pop()
                    answer_json["value"] = str(result)
                else:
                    raise ValueError(
                        f"Unsupported answer ability: {predicted_ability_str}")
                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(answer_json)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(predicted_answer, answer_annotations)
            # This is used for the demo.
            # output_dict["passage_question_attention"] = passage_question_attention
            output_dict["passage_question_tokens"] = passage_question_tokens
            # output_dict["passage_tokens"] = passage_tokens
        return output_dict
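# Condensed sketch of the "addition_subtraction" marginalization in the forward pass
# above: for each gold sign combination we sum the per-number sign log probabilities,
# mask out padded combinations with a large negative value, and logsumexp over the
# combinations. Shapes and values are invented; torch.logsumexp and masked_fill stand in
# for util.logsumexp and util.replace_masked_values.
import torch

batch, num_numbers, num_combinations = 1, 3, 2
number_sign_log_probs = torch.log_softmax(torch.randn(batch, num_numbers, 3), dim=-1)
# Gold signs per combination: 0 = ignore, 1 = plus, 2 = minus.
gold_signs = torch.tensor([[[1, 2, 0], [0, 1, 1]]])            # (batch, combinations, numbers)
gold_mask = (gold_signs.sum(-1) > 0).float()                   # (batch, combinations)

per_number = number_sign_log_probs.gather(2, gold_signs.transpose(1, 2))  # (batch, numbers, combinations)
per_combination = per_number.sum(1)                            # (batch, combinations)
per_combination = per_combination.masked_fill(gold_mask == 0, -1e7)
log_marginal_likelihood = torch.logsumexp(per_combination, dim=-1)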
def forward(self, # type: ignore text: Dict[str, torch.LongTensor], spans: torch.IntTensor, span_labels: torch.IntTensor = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- text : ``Dict[str, torch.LongTensor]``, required. The output of a ``TextField`` representing the text of the document. spans : ``torch.IntTensor``, required. A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of indices into the text of the document. span_labels : ``torch.IntTensor``, optional (default = None) A tensor of shape (batch_size, num_spans), representing the cluster ids of each span, or -1 for those which do not appear in any clusters. Returns ------- An output dictionary consisting of: top_spans : ``torch.IntTensor`` A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing the start and end word indices of the top spans that survived the pruning stage. antecedent_indices : ``torch.IntTensor`` A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span the index (with respect to top_spans) of the possible antecedents the model considered. predicted_antecedents : ``torch.IntTensor`` A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the index (with respect to antecedent_indices) of the most likely antecedent. -1 means there was no predicted link. loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. """ # Shape: (batch_size, document_length, embedding_size) text_embeddings = self._lexical_dropout(self._text_field_embedder(text)) document_length = text_embeddings.size(1) num_spans = spans.size(1) # Shape: (batch_size, document_length) text_mask = util.get_text_field_mask(text).float() # Shape: (batch_size, num_spans) span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float() # SpanFields return -1 when they are used as padding. As we do # some comparisons based on span widths when we attend over the # span representations that we generate from these indices, we # need them to be <= 0. This is only relevant in edge cases where # the number of spans we consider after the pruning stage is >= the # total number of spans, because in this case, it is possible we might # consider a masked span. # Shape: (batch_size, num_spans, 2) spans = F.relu(spans.float()).long() # Shape: (batch_size, document_length, encoding_dim) contextualized_embeddings = self._context_layer(text_embeddings, text_mask) # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size) endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans) # Shape: (batch_size, num_spans, emebedding_size) attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans) # Shape: (batch_size, num_spans, emebedding_size + 2 * encoding_dim + feature_size) span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1) # Prune based on mention scores. num_spans_to_keep = int(math.floor(self._spans_per_word * document_length)) (top_span_embeddings, top_span_mask, top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings, span_mask, num_spans_to_keep) top_span_mask = top_span_mask.unsqueeze(-1) # Shape: (batch_size * num_spans_to_keep) # torch.index_select only accepts 1D indices, but here # we need to select spans for each element in the batch. 
# This reformats the indices to take into account their # index into the batch. We precompute this here to make # the multiple calls to util.batched_index_select below more efficient. flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans) # Compute final predictions for which spans to consider as mentions. # Shape: (batch_size, num_spans_to_keep, 2) top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices) # Compute indices for antecedent spans to consider. max_antecedents = min(self._max_antecedents, num_spans_to_keep) # Now that we have our variables in terms of num_spans_to_keep, we need to # compare span pairs to decide each span's antecedent. Each span can only # have prior spans as antecedents, and we only consider up to max_antecedents # prior spans. So the first thing we do is construct a matrix mapping a span's # index to the indices of its allowed antecedents. Note that this is independent # of the batch dimension - it's just a function of the span's position in # top_spans. The spans are in document order, so we can just use the relative # index of the spans to know which other spans are allowed antecedents. # Once we have this matrix, we reformat our variables again to get embeddings # for all valid antecedents for each span. This gives us variables with shapes # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which # we can use to make coreference decisions between valid span pairs. # Shapes: # (num_spans_to_keep, max_antecedents), # (1, max_antecedents), # (1, num_spans_to_keep, max_antecedents) valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \ self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask)) # Select tensors relating to the antecedent spans. # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings, valid_antecedent_indices) # Shape: (batch_size, num_spans_to_keep, max_antecedents) candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores, valid_antecedent_indices).squeeze(-1) # Compute antecedent scores. # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents) coreference_scores = self._compute_coreference_scores(span_pair_embeddings, top_span_mention_scores, candidate_antecedent_mention_scores, valid_antecedent_log_mask) # We now have, for each span which survived the pruning stage, # a predicted antecedent. This implies a clustering if we group # mentions which refer to each other in a chain. # Shape: (batch_size, num_spans_to_keep) _, predicted_antecedents = coreference_scores.max(2) # Subtract one here because index 0 is the "no antecedent" class, # so this makes the indices line up with actual spans if the prediction # is greater than -1. predicted_antecedents -= 1 output_dict = {"top_spans": top_spans, "antecedent_indices": valid_antecedent_indices, "predicted_antecedents": predicted_antecedents} if span_labels is not None: # Find the gold labels for the spans which we kept. 
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the negative log of
            # the total probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            # the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
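# Small numeric sketch of the coreference loss just above: softmax the antecedent
# scores, keep only the log probabilities of antecedents that are in the gold cluster
# (via the log of a 0/1 label matrix), and logsumexp to marginalize over them. Plain
# torch is used instead of allennlp's util helpers, and the scores/labels are made up.
import torch

# One document, 2 kept spans; column 0 is "no antecedent", then 2 candidate antecedents.
coreference_scores = torch.tensor([[[0.1, -0.3, 0.0],
                                    [1.2, 0.4, -2.0]]])
# Gold: span 0 has no antecedent; for span 1 both candidates are in its cluster.
gold_antecedent_labels = torch.tensor([[[1.0, 0.0, 0.0],
                                        [0.0, 1.0, 1.0]]])

coreference_log_probs = torch.log_softmax(coreference_scores, dim=-1)
correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
negative_marginal_log_likelihood = -torch.logsumexp(correct_antecedent_log_probs, dim=-1).sum()
print(negative_marginal_log_likelihood)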
def forward( self, # type: ignore text: Dict[str, torch.LongTensor], spans: torch.IntTensor, span_labels: torch.IntTensor = None, keep_antecedent_alternatives: Optional[ScatterableList] = None, metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: # pylint: disable=arguments-differ """ Parameters ---------- text : ``Dict[str, torch.LongTensor]``, required. The output of a ``TextField`` representing the text of the document. spans : ``torch.IntTensor``, required. A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of indices into the text of the document. span_labels : ``torch.IntTensor``, optional (default = None) A tensor of shape (batch_size, num_spans), representing the cluster ids of each span, or -1 for those which do not appear in any clusters. output_alternative_antecedents: ``Optional[ScatterableList]`` - if non-`None` and any contained value is ``True``, the output dictionary will contain antecedent scores and antecedent_mask (see below). Returns ------- An output dictionary consisting of: top_spans : ``torch.IntTensor`` A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing the start and end word indices of the top spans that survived the pruning stage. antecedent_indices : ``torch.IntTensor`` A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span the index (with respect to top_spans) of the possible antecedents the model considered. predicted_antecedents : ``torch.IntTensor`` A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the index (with respect to antecedent_indices) of the most likely antecedent. -1 means there was no predicted link. antecedent_scores : ``torch.FloatTensor``, optional A tensor of shape ``(batch_size, num_spans_to_keep, max_antecedents+1)`` giving the antecedent scores for each mention. Each i-th batch element is associated with a matrix whose the j-th row contains the antecedent scores for the j-th mention of that batch (corresponding to top_spans), and k-th column contains the score for the antecedent_indices[k - 1]-th mention being the antecedent of the j-th mention. The first column (index k = 0) contains the score for the j-th mention having no antecedent. antecedent_mask : ``torch.FloatTensor``, optional A tensor of shape ``(batch_size, num_spans_to_keep, max_antecedent)``. The (i, j)-th entry will be 1 if the (i, j)-th entry of `antecedent_scores` gives valid antecedent score and 0 otherwise. This is necessary because, for example, the first mention of a document has no antecedents to score. loss : ``torch.FloatTensor``, optional A scalar loss to be optimised. """ # Shape: (batch_size, document_length, embedding_size) text_embeddings = self._lexical_dropout( self._text_field_embedder(text)) document_length = text_embeddings.size(1) num_spans = spans.size(1) # Shape: (batch_size, document_length) text_mask = util.get_text_field_mask(text).float() # Shape: (batch_size, num_spans) # In order to get the correct shape, we collapse the last dimension (it should only be one # index long). We then reshape it to make sure the shape is correct in edge cases (namely # when there is exactly one input mention). span_mask = (spans[:, :, 0] >= 0).squeeze(-1).reshape((spans.shape[0], spans.shape[1])) \ .float() # SpanFields return -1 when they are used as padding. 
        # As we do some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = min(num_spans,
                                int(math.floor(self._spans_per_word * document_length)))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        # index to the indices of its allowed antecedents. Note that this is independent
        # of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        # we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents,
                                             util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
# Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) candidate_antecedent_embeddings = util.flattened_index_select( top_span_embeddings, valid_antecedent_indices) # Shape: (batch_size, num_spans_to_keep, max_antecedents) candidate_antecedent_mention_scores = util.flattened_index_select( top_span_mention_scores, valid_antecedent_indices).squeeze(-1) # Compute antecedent scores. # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size) span_pair_embeddings = self._compute_span_pair_embeddings( top_span_embeddings, candidate_antecedent_embeddings, valid_antecedent_offsets) # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents) coreference_scores = self._compute_coreference_scores( span_pair_embeddings, top_span_mention_scores, candidate_antecedent_mention_scores, valid_antecedent_log_mask) # We now have, for each span which survived the pruning stage, # a predicted antecedent. This implies a clustering if we group # mentions which refer to each other in a chain. # Shape: (batch_size, num_spans_to_keep) _, predicted_antecedents = coreference_scores.max(2) # Subtract one here because index 0 is the "no antecedent" class, # so this makes the indices line up with actual spans if the prediction # is greater than -1. predicted_antecedents -= 1 batch_size = top_spans.shape[0] output_dict = { "top_spans": top_spans, # because antecedent_indices is the same for all batch elements (since it # doesn't depend on the batch content), we need to expand it to have # batch_size as its first dimension or else model.forward_on_instances # will discard it "antecedent_indices": valid_antecedent_indices.expand(batch_size, valid_antecedent_indices.shape[0], valid_antecedent_indices.shape[1]), "predicted_antecedents": predicted_antecedents } if keep_antecedent_alternatives and any(keep_antecedent_alternatives): output_dict["antecedent_scores"] = coreference_scores output_dict["antecedent_mask"] = (valid_antecedent_log_mask >= 0) if span_labels is not None: # Find the gold labels for the spans which we kept. pruned_gold_labels = util.batched_index_select( span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices) antecedent_labels = util.flattened_index_select( pruned_gold_labels, valid_antecedent_indices).squeeze(-1) antecedent_labels += valid_antecedent_log_mask.long() # Compute labels. # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1) gold_antecedent_labels = self._compute_antecedent_gold_labels( pruned_gold_labels, antecedent_labels) # Now, compute the loss using the negative marginal log-likelihood. # This is equal to the log of the sum of the probabilities of all antecedent predictions # that would be consistent with the data, in the sense that we are minimising, for a # given span, the negative marginal log likelihood of all antecedents which are in the # same gold cluster as the span we are currently considering. Each span i predicts a # single antecedent j, but there might be several prior mentions k in the same # coreference cluster that would be valid antecedents. Our loss is the sum of the # probability assigned to all valid antecedents. This is a valid objective for # clustering as we don't mind which antecedent is predicted, so long as they are in # the same coreference cluster. 
coreference_log_probs = util.masked_log_softmax( coreference_scores, top_span_mask) correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log( ) negative_marginal_log_likelihood = -util.logsumexp( correct_antecedent_log_probs).sum() self._mention_recall(top_spans, metadata) self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata) output_dict["loss"] = negative_marginal_log_likelihood if metadata is not None: output_dict["document"] = [x["original_text"] for x in metadata] return output_dict
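# Sketch of how a valid-antecedent structure like the one used above can be built for a
# handful of kept spans: each span may only take earlier spans as antecedents, and the
# log mask is 0 for valid (earlier) positions and -inf-like otherwise. This mirrors the
# intent of ``_generate_valid_antecedents`` but is a standalone illustration, not the
# model's implementation.
import torch

num_spans_to_keep, max_antecedents = 4, 2

target_indices = torch.arange(num_spans_to_keep).unsqueeze(1)                  # (spans, 1)
valid_antecedent_offsets = torch.arange(1, max_antecedents + 1).unsqueeze(0)   # (1, max_antecedents)
raw_antecedent_indices = target_indices - valid_antecedent_offsets             # (spans, max_antecedents)

valid_antecedent_log_mask = (raw_antecedent_indices >= 0).float().unsqueeze(0).log()
valid_antecedent_indices = raw_antecedent_indices.clamp(min=0)

print(valid_antecedent_indices)
# tensor([[0, 0], [0, 0], [1, 0], [2, 1]])
print((valid_antecedent_log_mask >= 0).squeeze(0))  # boolean antecedent mask, as exposed in the output dict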