def forward(self,
            tokens: Dict[str, torch.LongTensor],
            labels: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
    # Embed the input tokens and build the padding mask.
    embedded_text_input = self.text_field_embedder(tokens)
    mask = get_text_field_mask(tokens)
    if self.dropout is not None:
        embedded_text_input = self.dropout(embedded_text_input)
    # Contextualize the embeddings.
    encoded_text = self.encoder(embedded_text_input, mask)
    if self.dropout is not None:
        encoded_text = self.dropout(encoded_text)
    if self.feedforward is not None:
        encoded_text = self.feedforward(encoded_text)
    # Per-token tag logits: (batch_size, seq_len, num_tags).
    logits = self.tag_projection_layer(encoded_text)
    output = {'logits': logits, 'mask': mask}
    if labels is not None:
        # Replace labels at padded positions with -1 so the loss ignores them.
        flipped_mask = (mask == 0)
        masked_labels = labels.masked_fill(flipped_mask, -1)
        # The loss expects class scores on dim 1, hence the transpose.
        output['loss'] = self.loss(logits.transpose(1, 2), masked_labels)
        for name, metric in self.metrics.items():
            metric(logits, labels, mask.float())
    return output
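
# A minimal sketch of the masked-loss pattern used above, assuming `self.loss`
# is a `torch.nn.CrossEntropyLoss` constructed with `ignore_index=-1` (not shown
# in this snippet). Padded positions receive label -1 and are skipped by the
# loss; the transpose is needed because CrossEntropyLoss expects the class
# dimension second, i.e. (batch_size, num_tags, seq_len).
import torch

logits = torch.randn(2, 4, 3)                 # (batch, seq_len, num_tags)
labels = torch.tensor([[0, 2, 1, 1],
                       [2, 0, 1, 1]])
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])

masked_labels = labels.masked_fill(mask == 0, -1)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-1)
loss = loss_fn(logits.transpose(1, 2), masked_labels)
print(loss)
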
def replace_token(target: torch.LongTensor, old: int, new: int):
    """Replace occurrences of one token id with another.

    Arguments:
        target: the tensor of token ids to modify.
        old: the token id to be replaced by ``new``.
        new: the token id used to replace ``old``.
    """
    return target.masked_fill(target == old, new)
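
# Example usage of `replace_token` (made-up token ids): every occurrence of
# token id 3 is replaced with token id 0 via `masked_fill`.
import torch

tokens = torch.tensor([[5, 3, 7, 3],
                       [3, 9, 2, 4]])
print(replace_token(tokens, old=3, new=0))
# tensor([[5, 0, 7, 0],
#         [0, 9, 2, 4]])
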
def _update_seq_length_for_generation(
    sequence_lengths: torch.LongTensor,
    unfinished_sequences: torch.LongTensor,
    cur_len: int,
    is_eos_in_next_token: torch.BoolTensor,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    # Sequences that were still unfinished and whose next token is EOS
    # finish at this step.
    is_sent_finished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool()
    # Record the current length for the sequences that just finished.
    sequence_lengths = sequence_lengths.masked_fill(is_sent_finished, cur_len)
    # Mark sequences that emitted EOS as finished (flag set to 0).
    unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long())
    return sequence_lengths, unfinished_sequences
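
# A small sketch of one decoding step, with made-up state for a batch of three
# sequences: sequence 2 has already finished, and at step 7 sequences 0 and 2
# emit EOS. Only sequence 0 transitions from unfinished to finished, so only
# its length is updated.
import torch

sequence_lengths = torch.tensor([20, 20, 20])      # max-length placeholders
unfinished_sequences = torch.tensor([1, 1, 0])     # sequence 2 already finished
cur_len = 7
is_eos_in_next_token = torch.tensor([True, False, True])

sequence_lengths, unfinished_sequences = _update_seq_length_for_generation(
    sequence_lengths, unfinished_sequences, cur_len, is_eos_in_next_token
)
print(sequence_lengths)      # tensor([ 7, 20, 20])
print(unfinished_sequences)  # tensor([0, 1, 0])
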
def forward(self,
            tokens: Dict[str, torch.LongTensor],
            token_lengths: torch.Tensor,
            target_tokens: torch.LongTensor = None,
            punct_labels: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
    # Embed the tokens and build the padding mask.
    mask = get_text_field_mask(tokens)
    embedded_text_input = self.text_field_embedder(tokens)
    if self.embedding_dropout is not None:
        embedded_text_input = self.embedding_dropout(embedded_text_input)
    # Contextualize the embeddings.
    encoded_text = self.encoder(embedded_text_input, mask.bool())
    if self.encoded_dropout is not None:
        encoded_text = self.encoded_dropout(encoded_text)
    if self.feedforward is not None:
        encoded_text = self.feedforward(encoded_text)
    # Per-token punctuation logits: (batch_size, seq_len, num_punct_labels).
    punct_logits = self.punct_projection(encoded_text)
    # Append the token-length feature as an extra channel of the encoding.
    token_lengths = token_lengths.unsqueeze(-1)
    encoded_text = torch.cat((encoded_text, token_lengths), dim=-1)
    output = {
        'mask': mask,
        'punct_logits': punct_logits,
        'embeddings': encoded_text
    }
    if target_tokens is not None:
        output['loss'] = self.__compute_spellchecker_loss(encoded_text, target_tokens)
    if punct_labels is not None:
        # Replace labels at padded positions with -1 so the loss ignores them.
        flipped_mask = (mask == 0)
        masked_punct_labels = punct_labels.masked_fill(flipped_mask, -1)
        punct_loss = self.losses['punct'](punct_logits.transpose(1, 2), masked_punct_labels)
        # Add the punctuation loss to the spellchecker loss if both are present.
        if 'loss' in output:
            output['loss'] += punct_loss
        else:
            output['loss'] = punct_loss
        for name, metric in self.metrics.items():
            metric(punct_logits, punct_labels, mask.float())
    return output
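
# A sketch of the feature concatenation above, with made-up shapes: a per-token
# scalar (here a token length) is appended as one extra channel of the encoded
# representation, growing the last dimension by 1.
import torch

encoded_text = torch.randn(2, 3, 4)            # (batch, seq_len, hidden)
token_lengths = torch.tensor([[5., 2., 7.],
                              [3., 8., 1.]])   # (batch, seq_len)

token_lengths = token_lengths.unsqueeze(-1)    # (batch, seq_len, 1)
combined = torch.cat((encoded_text, token_lengths), dim=-1)
print(combined.shape)  # torch.Size([2, 3, 5])
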
def forward(self, cw_idxs, cc_idxs, qw_idxs, qc_idxs, ids,
            answer_start_as_passage_spans: torch.LongTensor = None,
            answer_end_as_passage_spans: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            number_indices=None):
    batch_size = cw_idxs.size(0)

    # The forward pass is identical to QANet up to the last layer.
    spans_start, spans_end = super().forward(cw_idxs, cc_idxs, qw_idxs, qc_idxs)

    # A weighted sum over the passage-aware representation gives a single
    # vector representation of the passage.
    passage_weights = masked_softmax(
        self.passage_weights_layer(self.passage_aware_rep).squeeze(-1),
        self.c_mask_c2q,
        log_softmax=False)
    passage_vector_rep = passage_weights.unsqueeze(1).bmm(self.passage_aware_rep).squeeze(1)
    # Likewise for the question.
    question_weights = masked_softmax(
        self.question_weights_layer(self.qb).squeeze(-1),
        self.q_mask_c2q,
        log_softmax=False)
    question_vector_rep = question_weights.unsqueeze(1).bmm(self.qb).squeeze(1)

    if len(self.answering_abilities) > 1:
        # Shape: (batch_size, number_of_abilities)
        answer_ability_logits = self.answer_ability_predictor(
            torch.cat([passage_vector_rep, question_vector_rep], -1)
        )
        answer_ability_log_probs = torch.nn.functional.log_softmax(answer_ability_logits, -1)
        # Shape: (batch_size,)
        best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

    if "counting" in self.answering_abilities:
        # Shape: (batch_size, self.max_count)
        count_number_logits = self.count_number_predictor(passage_vector_rep)
        # Log-softmax over the possible count values.
        count_number_log_probs = torch.nn.functional.log_softmax(count_number_logits, -1)
        # Info about the best count prediction.
        # Shape: (batch_size,)
        best_count_number = torch.argmax(count_number_log_probs, -1)
        best_count_log_prob = torch.gather(
            count_number_log_probs, 1, best_count_number.unsqueeze(-1)
        ).squeeze(-1)
        if len(self.answering_abilities) > 1:
            best_count_log_prob += answer_ability_log_probs[:, self.counting_index]

    # TODO: test or remove
    if "addition_subtraction" in self.answering_abilities:
        # M3 (see the NAQANet paper)
        modeled_passage = self.modeled_passage_list[-1]
        for block in self.modeling_encoder_blocks:
            modeled_passage = self.dropout_layer(block(modeled_passage, self.c_mask_enc))
            self.modeled_passage_list.append(modeled_passage)
        encoded_passage_for_numbers = torch.cat(
            [self.modeled_passage_list[0], self.modeled_passage_list[3]], dim=-1
        )
        # Create a mask on the number indices (padding value is -1).
        number_mask = number_indices != -1
        clamped_number_indices = number_indices.masked_fill(~number_mask, 0).type(torch.int64).to(self.device)
        number_mask = number_mask.to(self.device)
        if number_mask.size(1) > 0:
            # Shape: (batch_size, max_len_context, 3 * hidden_size)
            encoded_numbers = torch.cat(
                [
                    encoded_passage_for_numbers,
                    passage_vector_rep.unsqueeze(1).repeat(1, encoded_passage_for_numbers.size(1), 1),
                ],
                -1,
            )
            # Shape: (batch_size, max # of numbers in passage, 3 * hidden_size)
            encoded_numbers = torch.gather(
                encoded_numbers, 1,
                clamped_number_indices.unsqueeze(-1).expand(-1, -1, encoded_numbers.size(-1))
            )
            number_sign_logits = self.number_sign_predictor(encoded_numbers)
            number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)
            # Shape: (batch_size, # of numbers in passage)
            best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
            # For padding numbers, the best sign is masked to 0 (not included).
            best_signs_for_numbers = best_signs_for_numbers.masked_fill(~number_mask, 0)
            # Shape: (batch_size, # of numbers in passage)
            best_signs_log_probs = torch.gather(
                number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)
            ).squeeze(-1)
            # The log-probs of masked positions are set to 0 (probability 1) so they
            # do not affect the joint probability.
            # TODO: this is not quite right, since if there are many numbers in the passage,
            # TODO: the joint probability would be very small.
            best_signs_log_probs = best_signs_log_probs.masked_fill(~number_mask, 0)
            # Shape: (batch_size,)
            best_combination_log_prob = best_signs_log_probs.sum(-1)
            if len(self.answering_abilities) > 1:
                best_combination_log_prob += answer_ability_log_probs[:, self.addition_subtraction_index]
        else:
            print("No numbers in the batch")

    if "passage_span_extraction" in self.answering_abilities:
        # Shape: (batch_size, passage_length, modeling_dim * 2)
        passage_for_span_start = torch.cat(
            [self.modeled_passage_list[0], self.modeled_passage_list[1]], dim=-1
        )
        # Shape: (batch_size, passage_length)
        passage_span_start_logits = self.passage_span_start_predictor(passage_for_span_start).squeeze(-1)
        # Shape: (batch_size, passage_length, modeling_dim * 2)
        passage_for_span_end = torch.cat(
            [self.modeled_passage_list[0], self.modeled_passage_list[2]], dim=-1
        )
        # Shape: (batch_size, passage_length)
        passage_span_end_logits = self.passage_span_end_predictor(passage_for_span_end).squeeze(-1)
        # Shape: (batch_size, passage_length). Log-probabilities, ranging from -inf to 0.
        passage_span_start_log_probs = util.masked_log_softmax(passage_span_start_logits, self.c_mask_c2q)
        passage_span_end_log_probs = util.masked_log_softmax(passage_span_end_logits, self.c_mask_c2q)
        # Info about the best passage span prediction.
        passage_span_start_logits = replace_masked_values_with_big_negative_number(
            passage_span_start_logits, self.c_mask_c2q
        )
        passage_span_end_logits = replace_masked_values_with_big_negative_number(
            passage_span_end_logits, self.c_mask_c2q
        )
        # Shape: (batch_size, 2)
        best_passage_span = get_best_span(passage_span_start_logits, passage_span_end_logits)
        # Shape: (batch_size,)
        best_passage_start_log_probs = torch.gather(
            passage_span_start_log_probs, 1, best_passage_span[:, 0].unsqueeze(-1)
        ).squeeze(-1)
        best_passage_end_log_probs = torch.gather(
            passage_span_end_log_probs, 1, best_passage_span[:, 1].unsqueeze(-1)
        ).squeeze(-1)
        # Shape: (batch_size,)
        best_passage_span_log_prob = best_passage_start_log_probs + best_passage_end_log_probs
        if len(self.answering_abilities) > 1:
            best_passage_span_log_prob += answer_ability_log_probs[:, self.passage_span_extraction_index]

    output_dict = dict()

    # If gold answers are given, compute the loss.
    if answer_start_as_passage_spans is not None or answer_as_counts is not None:
        log_marginal_likelihood_list = []
        for answering_ability in self.answering_abilities:
            if answering_ability == "passage_span_extraction":
                # Shape: (batch_size, # of answer spans)
                gold_passage_span_starts = answer_start_as_passage_spans
                gold_passage_span_ends = answer_end_as_passage_spans
                # Some spans are padded with index -1, so we clamp those paddings
                # to 0 and then mask after `torch.gather()`.
                # Start and end positions share the same mask.
                gold_passage_span_mask = gold_passage_span_starts != -1
                clamped_gold_passage_span_starts = \
                    gold_passage_span_starts.masked_fill(~gold_passage_span_mask, 0)
                clamped_gold_passage_span_ends = \
                    gold_passage_span_ends.masked_fill(~gold_passage_span_mask, 0)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_passage_span_starts = torch.gather(
                    passage_span_start_log_probs, 1, clamped_gold_passage_span_starts
                )
                log_likelihood_for_passage_span_ends = torch.gather(
                    passage_span_end_log_probs, 1, clamped_gold_passage_span_ends
                )
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_passage_spans = (
                    log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
                )
                # For padded spans, set the log-likelihood to a very large negative value.
                log_likelihood_for_passage_spans = replace_masked_values_with_big_negative_number(
                    log_likelihood_for_passage_spans, gold_passage_span_mask
                )
                # Shape: (batch_size,)
                log_marginal_likelihood_for_passage_span = util.logsumexp(log_likelihood_for_passage_spans)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_passage_span)
            elif answering_ability == "counting":
                # Count answers are padded with label -1, so we clamp those paddings
                # to 0 and then mask after `torch.gather()`.
                # Shape: (batch_size, # of count answers)
                gold_count_mask = answer_as_counts != -1
                # Shape: (batch_size, # of count answers)
                clamped_gold_counts = answer_as_counts.masked_fill(~gold_count_mask, 0)
                log_likelihood_for_counts = torch.gather(count_number_log_probs, 1, clamped_gold_counts)
                # For padded counts, set the log-likelihood to a very large negative value.
                log_likelihood_for_counts = replace_masked_values_with_big_negative_number(
                    log_likelihood_for_counts, gold_count_mask
                )
                # Shape: (batch_size,)
                log_marginal_likelihood_for_count = util.logsumexp(log_likelihood_for_counts)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_count)
            else:
                raise ValueError(f"Unsupported answering ability: {answering_ability}")

        if len(self.answering_abilities) > 1:
            # Add the ability log-probabilities if there is more than one ability.
            all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
            all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
            marginal_log_likelihood = util.logsumexp(all_log_marginal_likelihoods)
        else:
            marginal_log_likelihood = log_marginal_likelihood_list[0]
        output_dict["loss"] = -marginal_log_likelihood.mean()

    if self.eval_data:
        output_dict["predictions"] = dict()
        for i in range(batch_size):
            id = ids[i].item()
            if len(self.answering_abilities) > 1:
                predicted_ability_str = self.answering_abilities[
                    best_answer_ability[i].detach().cpu().numpy()
                ]
            else:
                predicted_ability_str = self.answering_abilities[0]
            if predicted_ability_str == "passage_span_extraction":
                start = best_passage_span[i, 0]
                end = best_passage_span[i, 1]
                preds = convert_tokens(self.eval_data, id, start.item(), end.item())
                output_dict["predictions"][str(id)] = preds
            elif predicted_ability_str == "counting":
                predicted_count = str(best_count_number[i].detach().cpu().numpy())
                output_dict["predictions"][str(id)] = predicted_count

    return output_dict
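
# The loss above marginalizes over all gold answers of a given type. Below is a
# minimal sketch of that pattern using plain torch: `torch.logsumexp` stands in
# for `util.logsumexp`, and an explicit large negative constant stands in for
# `replace_masked_values_with_big_negative_number`. Shapes and values are made up.
import torch

count_number_log_probs = torch.log_softmax(torch.randn(2, 10), dim=-1)
answer_as_counts = torch.tensor([[3, 5, -1],
                                 [2, -1, -1]])   # gold counts, padded with -1

# Clamp the -1 paddings to 0 so gather() receives valid indices...
gold_count_mask = answer_as_counts != -1
clamped_gold_counts = answer_as_counts.masked_fill(~gold_count_mask, 0)
log_likelihood_for_counts = torch.gather(count_number_log_probs, 1, clamped_gold_counts)

# ...then push the padded entries to a very large negative value so they
# contribute (almost) nothing to the marginal.
log_likelihood_for_counts = log_likelihood_for_counts.masked_fill(~gold_count_mask, -1e7)

# Marginal log-likelihood over all gold counts, one value per batch element.
log_marginal_likelihood = torch.logsumexp(log_likelihood_for_counts, dim=-1)
loss = -log_marginal_likelihood.mean()
print(loss)
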