def test_get_best_span(self):
    """Check that get_best_span picks the highest-probability valid (start <= end) span."""
    # Each case: (start probabilities, end probabilities, expected [[begin, end]]).
    # The (1, 1) and (0, 0) cases used to be edge cases of the dynamic program
    # back when we used exclusive span ends; we keep them to make sure the
    # inclusive-span-end implementation still gets them right.
    cases = [
        ([[0.1, 0.3, 0.05, 0.3, 0.25]], [[0.65, 0.05, 0.2, 0.05, 0.05]], [[0, 0]]),
        ([[0.4, 0.5, 0.1]], [[0.3, 0.6, 0.1]], [[1, 1]]),
        ([[0.8, 0.1, 0.1]], [[0.8, 0.1, 0.1]], [[0, 0]]),
        ([[0.1, 0.2, 0.05, 0.3, 0.25]], [[0.1, 0.2, 0.5, 0.05, 0.15]], [[1, 2]]),
    ]
    for begin_probs, end_probs, expected in cases:
        span_begin_probs = torch.FloatTensor(begin_probs).log()
        span_end_probs = torch.FloatTensor(end_probs).log()
        begin_end_idxs = get_best_span(span_begin_probs, span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), expected)
def _passage_span_module(self, passage_out, passage_mask):
    """Predict an answer span over the passage.

    Returns a 3-tuple ``(start_log_probs, end_log_probs, best_span)`` where the
    log-prob tensors have shape (batch_size, passage_length) and ``best_span``
    has shape (batch_size, 2) with inclusive start/end token indices.
    """
    # Per-token start/end scores. Shape: (batch_size, passage_length)
    start_logits = self._passage_span_start_predictor(passage_out).squeeze(-1)
    end_logits = self._passage_span_end_predictor(passage_out).squeeze(-1)

    # Normalize over unmasked passage positions only.
    start_log_probs = util.masked_log_softmax(start_logits, passage_mask)
    end_log_probs = util.masked_log_softmax(end_logits, passage_mask)

    # Drive masked logits to -1e7 so padding positions can never win the
    # argmax inside get_best_span.
    start_logits = util.replace_masked_values(start_logits, passage_mask, -1e7)
    end_logits = util.replace_masked_values(end_logits, passage_mask, -1e7)

    # Shape: (batch_size, 2)
    best_passage_span = get_best_span(start_logits, end_logits)
    return start_log_probs, end_log_probs, best_passage_span
def _question_span_module(self, passage_vector, question_out, question_mask):
    """Predict an answer span over the question, conditioned on the passage vector.

    Returns a 3-tuple ``(start_log_probs, end_log_probs, best_span)`` where the
    log-prob tensors have shape (batch_size, question_length) and ``best_span``
    has shape (batch_size, 2) with inclusive start/end token indices.
    """
    # Tile the passage vector along the question length and append it to
    # every question token representation.
    tiled_passage_vector = passage_vector.unsqueeze(1).repeat(1, question_out.size(1), 1)
    span_prediction_input = torch.cat([question_out, tiled_passage_vector], -1)

    # Per-token start/end scores. Shape: (batch_size, question_length)
    start_logits = self._question_span_start_predictor(span_prediction_input).squeeze(-1)
    end_logits = self._question_span_end_predictor(span_prediction_input).squeeze(-1)

    # Normalize over unmasked question positions only.
    start_log_probs = util.masked_log_softmax(start_logits, question_mask)
    end_log_probs = util.masked_log_softmax(end_logits, question_mask)

    # Drive masked logits to -1e7 so padding positions can never be chosen.
    start_logits = util.replace_masked_values(start_logits, question_mask, -1e7)
    end_logits = util.replace_masked_values(end_logits, question_mask, -1e7)

    # Shape: (batch_size, 2)
    best_question_span = get_best_span(start_logits, end_logits)
    return start_log_probs, end_log_probs, best_question_span
def ensemble(subresults: List[Dict[str, torch.Tensor]]) -> torch.Tensor:
    """
    Identifies the best prediction given the results from the submodels.

    Parameters
    ----------
    subresults : List[Dict[str, torch.Tensor]]
        Results of each submodel.

    Returns
    -------
    The index of the best submodel.
    """
    # Average the start/end probability distributions across submodels, then
    # pick the span that is best under those averaged distributions.
    num_models = len(subresults)
    avg_start_probs = sum(result['span_start_probs'] for result in subresults) / num_models
    avg_end_probs = sum(result['span_end_probs'] for result in subresults) / num_models
    # get_best_span expects log-space scores.
    return get_best_span(avg_start_probs.log(), avg_end_probs.log())  # type: ignore
def forward(
    self,
    passage_attention: torch.Tensor,
    passage_lengths: List[int],
    answer_as_passage_spans: torch.LongTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """Predict a passage span from a passage-attention distribution and compute the loss.

    Builds a length mask from ``passage_lengths``, feeds the (optionally scaled)
    attention through an encoder and a start/end predictor, and accumulates the
    marginal span log-likelihood per instance.  Returns a dict containing only
    ``"loss"``.

    NOTE(review): ``answer_as_passage_spans`` is unconditionally ``.long()``-ed
    below, so despite the ``= None`` default it is effectively required — confirm
    callers always pass it.
    """
    batch_size, max_passage_length = passage_attention.size()

    # Build a 0/1 mask: position j of row i is 1 iff j < passage_lengths[i].
    passage_mask = passage_attention.new_zeros(batch_size, max_passage_length)
    for i, passage_length in enumerate(passage_lengths):
        passage_mask[i, 0:passage_length] = 1.0

    answer_as_passage_spans = answer_as_passage_spans.long()

    # Zero out attention over padding before encoding.
    passage_attention = passage_attention * passage_mask

    if self._scaling:
        # One scaled copy of the attention per scaling factor, stacked as channels.
        scaled_attentions = [passage_attention * sf for sf in self.scaling_vals]
        passage_attention_input = torch.stack(scaled_attentions, dim=2)
    else:
        passage_attention_input = passage_attention.unsqueeze(2)

    # Shape: (batch_size, passage_length, span_rnn_hsize)
    passage_span_logits_repr = self.passage_attention_to_span(passage_attention_input, passage_mask)

    # Shape: (batch_size, passage_length, 2) — channel 0 = start, channel 1 = end.
    passage_span_logits = self.passage_startend_predictor(passage_span_logits_repr)

    # Shape: (batch_size, passage_length)
    span_start_logits = passage_span_logits[:, :, 0]
    span_end_logits = passage_span_logits[:, :, 1]

    # Masked positions get a very negative logit so they cannot be selected.
    span_start_logits = allenutil.replace_masked_values(span_start_logits, passage_mask, -1e32)
    span_end_logits = allenutil.replace_masked_values(span_end_logits, passage_mask, -1e32)

    span_start_log_probs = allenutil.masked_log_softmax(span_start_logits, passage_mask)
    span_end_log_probs = allenutil.masked_log_softmax(span_end_logits, passage_mask)

    # Also force masked log-probs to -1e32 so gathers at padded gold indices
    # contribute (effectively) zero probability.
    span_start_log_probs = allenutil.replace_masked_values(span_start_log_probs, passage_mask, -1e32)
    span_end_log_probs = allenutil.replace_masked_values(span_end_log_probs, passage_mask, -1e32)

    # Loss computation, one instance at a time.
    batch_likelihood = 0
    output_dict = {}
    for i in range(batch_size):
        log_likelihood = self._get_span_answer_log_prob(
            answer_as_spans=answer_as_passage_spans[i],
            span_log_probs=(span_start_log_probs[i], span_end_log_probs[i]),
        )

        # get_best_span expects a batch dim; add it, then strip it again.
        best_span = get_best_span(
            span_start_logits=span_start_log_probs[i].unsqueeze(0),
            span_end_logits=span_end_log_probs[i].unsqueeze(0),
        ).squeeze(0)

        # Accuracy metrics compare against the FIRST gold span only
        # (answer_as_passage_spans[i][0]).
        correct_start, correct_end = False, False
        if best_span[0] == answer_as_passage_spans[i][0][0]:
            self.start_acc_metric(1)
            correct_start = True
        else:
            self.start_acc_metric(0)

        if best_span[1] == answer_as_passage_spans[i][0][1]:
            self.end_acc_metric(1)
            correct_end = True
        else:
            self.end_acc_metric(0)

        if correct_start and correct_end:
            self.span_acc_metric(1)
        else:
            self.span_acc_metric(0)

        batch_likelihood += log_likelihood

    # Negative log-likelihood averaged over the batch.
    loss = -1.0 * batch_likelihood
    batch_loss = loss / batch_size
    output_dict["loss"] = batch_loss

    return output_dict
def _get_best_spans(
    batch_denotations,
    batch_denotation_types,
    question_char_offsets,
    question_strs,
    passage_char_offsets,
    passage_strs,
    *args,
):
    """
    For all SpanType denotations, get the best span

    Parameters:
    ----------
    batch_denotations: List[List[Any]]
    batch_denotation_types: List[List[str]]

    Returns a pair ``(batch_best_spans, batch_predicted_answers)`` — for each
    instance, the list of best (start, end) spans per program and the list of
    answer strings recovered from the character offsets.
    """
    (question_num_tokens, passage_num_tokens, question_mask_aslist, passage_mask_aslist) = args

    batch_best_spans = []
    batch_predicted_answers = []

    for instance_idx in range(len(batch_denotations)):
        instance_prog_denotations = batch_denotations[instance_idx]
        instance_prog_types = batch_denotation_types[instance_idx]

        instance_best_spans = []
        instance_predicted_ans = []

        for denotation, progtype in zip(instance_prog_denotations, instance_prog_types):
            # Distinction between QuestionSpanAnswer and PassageSpanAnswer is not
            # needed currently, since both classes store the start/end logits as
            # a tuple.
            # Shape: (2, )
            best_span = get_best_span(
                span_start_logits=denotation._value[0].unsqueeze(0),
                span_end_logits=denotation._value[1].unsqueeze(0),
            ).squeeze(0)
            instance_best_spans.append(best_span)

            predicted_span = tuple(best_span.detach().cpu().numpy())
            # Fallback: the original code left `predicted_answer` unbound when the
            # offset lookup below raised, which then crashed with UnboundLocalError
            # at the append; default to an empty answer instead.
            predicted_answer = ""
            if progtype == "QuestionSpanAnswer":
                try:
                    start_offset = question_char_offsets[instance_idx][predicted_span[0]][0]
                    end_offset = question_char_offsets[instance_idx][predicted_span[1]][1]
                    predicted_answer = question_strs[instance_idx][start_offset:end_offset]
                except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt
                    # Best-effort diagnostics for out-of-range spans; answer stays "".
                    print()
                    print(f"PredictedSpan: {predicted_span}")
                    print(f"Question numtoksn: {question_num_tokens[instance_idx]}")
                    print(f"QuesMaskLen: {question_mask_aslist[instance_idx].size()}")
                    print(f"StartLogProbs:{denotation._value[0]}")
                    print(f"EndLogProbs:{denotation._value[1]}")
                    print(f"LenofOffsets: {len(question_char_offsets[instance_idx])}")
                    print(f"QuesStrLen: {len(question_strs[instance_idx])}")
            elif progtype == "PassageSpanAnswer":
                try:
                    start_offset = passage_char_offsets[instance_idx][predicted_span[0]][0]
                    end_offset = passage_char_offsets[instance_idx][predicted_span[1]][1]
                    predicted_answer = passage_strs[instance_idx][start_offset:end_offset]
                except Exception:  # was a bare `except:`
                    print()
                    print(f"PredictedSpan: {predicted_span}")
                    print(f"Passagenumtoksn: {passage_num_tokens[instance_idx]}")
                    print(f"PassageMaskLen: {passage_mask_aslist[instance_idx].size()}")
                    print(f"LenofOffsets: {len(passage_char_offsets[instance_idx])}")
                    print(f"PassageStrLen: {len(passage_strs[instance_idx])}")
            else:
                raise NotImplementedError

            instance_predicted_ans.append(predicted_answer)

        batch_best_spans.append(instance_best_spans)
        batch_predicted_answers.append(instance_predicted_ans)

    return batch_best_spans, batch_predicted_answers
def forward(self,  # type: ignore
            passage_question: Dict[str, torch.LongTensor],
            # passage: Dict[str, torch.LongTensor],
            number_indices: torch.LongTensor,
            answer_type = None,
            # answer_as_passage_spans: torch.LongTensor = None,
            answer_as_spans: torch.LongTensor = None,
            answer_as_add_sub_expressions: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """Multi-head DROP-style forward pass over a joint passage+question encoding.

    Runs up to three answering heads (counting, span extraction,
    addition/subtraction), computes a marginal-likelihood loss when gold
    answers are provided, and decodes a per-instance answer when ``metadata``
    is given.

    NOTE(review): ``best_answer_ability`` below stacks the three best-head
    log-probs in the fixed order [span, add/sub, count]; this assumes all three
    abilities are configured and in that order — confirm against the
    constructor.
    """
    passage_question_mask = passage_question["mask"].float()
    # Convert token ids to vectors and apply dropout.
    embedded_passage_question = self._dropout(self._text_field_embedder(passage_question))
    batch_size = embedded_passage_question.size(0)  # added by bzw
    encoded_passage_question = embedded_passage_question
    """The passage vector is taken from the [CLS] position's representation."""
    passage_question_vector = encoded_passage_question[:, 0]

    if len(self.answering_abilities) > 1:
        # Shape: (batch_size, number_of_abilities)
        answer_ability_logits = \
            self._answer_ability_predictor(passage_question_vector)
        answer_ability_log_probs = torch.nn.functional.log_softmax(answer_ability_logits, -1)
        #best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

    if "counting" in self.answering_abilities:
        # Shape: (batch_size, 10)
        count_number_logits = self._count_number_predictor(passage_question_vector)
        count_number_log_probs = torch.nn.functional.log_softmax(count_number_logits, -1)
        # Info about the best count number prediction
        # Shape: (batch_size,)
        best_count_number = torch.argmax(count_number_log_probs, -1)
        best_count_log_prob = \
            torch.gather(count_number_log_probs, 1, best_count_number.unsqueeze(-1)).squeeze(-1)
        if len(self.answering_abilities) > 1:
            best_count_log_prob += answer_ability_log_probs[:, self._counting_index]

    if "span_extraction" in self.answering_abilities:
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(encoded_passage_question).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_end_logits = self._span_end_predictor(encoded_passage_question).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_log_probs = util.masked_log_softmax(span_start_logits, passage_question_mask)
        span_end_log_probs = util.masked_log_softmax(span_end_logits, passage_question_mask)
        # Info about the best passage span prediction; masked positions are
        # replaced with -1e7 so they can never be selected.
        span_start_logits = util.replace_masked_values(span_start_logits, passage_question_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_question_mask, -1e7)
        # Shape: (batch_size, 2)
        best_span = get_best_span(span_start_logits, span_end_logits)
        # Shape: (batch_size, 2)
        best_start_log_probs = \
            torch.gather(span_start_log_probs, 1, best_span[:, 0].unsqueeze(-1)).squeeze(-1)
        best_end_log_probs = \
            torch.gather(span_end_log_probs, 1, best_span[:, 1].unsqueeze(-1)).squeeze(-1)
        # Shape: (batch_size,)
        best_span_log_prob = best_start_log_probs + best_end_log_probs
        if len(self.answering_abilities) > 1:
            best_span_log_prob += answer_ability_log_probs[:, self._span_extraction_index]

    if "addition_subtraction" in self.answering_abilities:
        # Shape: (batch_size, # of numbers in the passage)
        number_indices = number_indices.squeeze(-1)
        # -1 marks padding number slots.
        number_mask = (number_indices != -1).long()
        clamped_number_indices = util.replace_masked_values(number_indices, number_mask, 0)
        #encoded_passage_for_numbers = torch.cat([modeled_passage_list[0], modeled_passage_list[3]], dim=-1)
        # Shape: (batch_size, # of numbers in the passage, encoding_dim)
        encoded_numbers = torch.gather(
            encoded_passage_question,
            1,
            clamped_number_indices.unsqueeze(-1).expand(-1, -1, encoded_passage_question.size(-1)))
        #self._external_number_embedding = self._external_number_embedding.cuda(device)
        #encoded_numbers = self.self_attention(encoded_numbers,number_mask)
        encoded_numbers = self.Concat_attention(encoded_numbers, passage_question_vector, number_mask)
        # Shape: (batch_size, # of numbers in the passage)
        #encoded_numbers = torch.cat(
        #        [encoded_numbers, passage_question_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)
        # Shape: (batch_size, # of numbers in the passage, 3) — signs {0, +, -}.
        number_sign_logits = self._number_sign_predictor(encoded_numbers)
        number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)
        # Shape: (batch_size, # of numbers in passage).
        best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
        # For padding numbers, the best sign masked as 0 (not included).
        best_signs_for_numbers = util.replace_masked_values(best_signs_for_numbers, number_mask, 0)
        # Shape: (batch_size, # of numbers in passage)
        best_signs_log_probs = torch.gather(
            number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
        # the probs of the masked positions should be 1 so that it will not affect the joint probability
        # TODO: this is not quite right, since if there are many numbers in the passage,
        # TODO: the joint probability would be very small.
        best_signs_log_probs = util.replace_masked_values(best_signs_log_probs, number_mask, 0)
        # Shape: (batch_size,)
        if len(self.answering_abilities) > 1:
            # batch_size
            best_combination_log_prob = best_signs_log_probs.sum(-1)
            best_combination_log_prob += answer_ability_log_probs[:, self._addition_subtraction_index]

    # Pick the ability whose best answer has the highest (ability + answer) log-prob.
    # NOTE(review): only defined when all three heads above ran — see docstring.
    best_answer_ability = torch.argmax(torch.stack([best_span_log_prob, best_combination_log_prob, best_count_log_prob], -1), 1)

    output_dict = {}

    # If answer is given, compute the loss.
    if answer_as_spans is not None or answer_as_add_sub_expressions is not None or answer_as_counts is not None:
        log_marginal_likelihood_list = []
        for answering_ability in self.answering_abilities:
            if answering_ability == "span_extraction":
                # Shape: (batch_size, # of answer spans)
                gold_span_starts = answer_as_spans[:, :, 0]
                gold_span_ends = answer_as_spans[:, :, 1]
                # Some spans are padded with index -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                gold_span_mask = (gold_span_starts != -1).long()
                clamped_gold_span_starts = \
                    util.replace_masked_values(gold_span_starts, gold_span_mask, 0)
                clamped_gold_span_ends = \
                    util.replace_masked_values(gold_span_ends, gold_span_mask, 0)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_span_starts = \
                    torch.gather(span_start_log_probs, 1, clamped_gold_span_starts)
                log_likelihood_for_span_ends = \
                    torch.gather(span_end_log_probs, 1, clamped_gold_span_ends)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_spans = \
                    log_likelihood_for_span_starts + log_likelihood_for_span_ends
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_spans = \
                    util.replace_masked_values(log_likelihood_for_spans, gold_span_mask, -1e7)
                # Shape: (batch_size, )
                # log_marginal_likelihood_for_span = torch.sum(log_likelihood_for_spans,-1)
                log_marginal_likelihood_for_span = util.logsumexp(log_likelihood_for_spans)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_span)
            elif answering_ability == "addition_subtraction":
                # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                # Shape: (batch_size, # of combinations)
                gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1) > 0).float()
                # Shape: (batch_size, # of numbers in the passage, # of combinations)
                gold_add_sub_signs = answer_as_add_sub_expressions.transpose(1, 2)
                # Shape: (batch_size, # of numbers in the passage, # of combinations)
                log_likelihood_for_number_signs = torch.gather(number_sign_log_probs, 2, gold_add_sub_signs)
                # the log likelihood of the masked positions should be 0
                # so that it will not affect the joint probability
                log_likelihood_for_number_signs = \
                    util.replace_masked_values(log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0)
                # Shape: (batch_size, # of combinations)
                log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(1)
                # For those padded combinations, we set their log probabilities to be very small negative value
                log_likelihood_for_add_subs = \
                    util.replace_masked_values(log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
                # Shape: (batch_size, )
                #log_marginal_likelihood_for_add_sub = torch.sum(log_likelihood_for_add_subs,-1)
                #log_marginal_likelihood_for_add_sub = util.logsumexp(log_likelihood_for_add_subs)
                #log_marginal_likelihood_list.append(log_marginal_likelihood_for_add_sub)
                log_marginal_likelihood_for_add_sub = util.logsumexp(log_likelihood_for_add_subs)
                #log_marginal_likelihood_for_external = util.logsumexp(log_likelihood_for_externals)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_add_sub)
            elif answering_ability == "counting":
                # Count answers are padded with label -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                # Shape: (batch_size, # of count answers)
                gold_count_mask = (answer_as_counts != -1).long()
                # Shape: (batch_size, # of count answers)
                clamped_gold_counts = util.replace_masked_values(answer_as_counts, gold_count_mask, 0)
                log_likelihood_for_counts = torch.gather(count_number_log_probs, 1, clamped_gold_counts)
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_counts = \
                    util.replace_masked_values(log_likelihood_for_counts, gold_count_mask, -1e7)
                # Shape: (batch_size, )
                #log_marginal_likelihood_for_count = torch.sum(log_likelihood_for_counts,-1)
                log_marginal_likelihood_for_count = util.logsumexp(log_likelihood_for_counts)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_count)
            else:
                raise ValueError(f"Unsupported answering ability: {answering_ability}")

        if len(self.answering_abilities) > 1:
            # Add the ability probabilities if there are more than one abilities
            all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
            # Supervised ability-type loss (cross-entropy against answer_type)
            # plus the summed per-ability answer losses.
            # NOTE(review): summing marginal likelihoods over abilities (rather
            # than logsumexp weighted by ability prob) looks deliberate here,
            # paired with the explicit type loss — confirm with training code.
            loss_for_type = -(torch.sum(answer_ability_log_probs * answer_type, -1)).mean()
            loss_for_answer = -(torch.sum(all_log_marginal_likelihoods, -1)).mean()
            loss = loss_for_type + loss_for_answer
        else:
            marginal_log_likelihood = log_marginal_likelihood_list[0]
            loss = - marginal_log_likelihood.mean()
        output_dict["loss"] = loss

    # Compute the metrics and add the tokenized input to the output.
    if metadata is not None:
        output_dict["question_id"] = []
        output_dict["answer"] = []
        passage_question_tokens = []
        for i in range(batch_size):
            passage_question_tokens.append(metadata[i]['passage_question_tokens'])
            if len(self.answering_abilities) > 1:
                predicted_ability_str = self.answering_abilities[best_answer_ability[i].detach().cpu().numpy()]
            else:
                predicted_ability_str = self.answering_abilities[0]

            answer_json: Dict[str, Any] = {}
            # We did not consider multi-mention answers here
            if predicted_ability_str == "span_extraction":
                answer_json["answer_type"] = "span"
                passage_question_token = metadata[i]['passage_question_tokens']
                #offsets = metadata[i]['passage_token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = predicted_span[0]
                end_offset = predicted_span[1]
                # Join wordpiece tokens back into a string, dropping [SEP] and
                # re-attaching "##" continuation pieces.
                predicted_answer = " ".join([token for token in passage_question_token[start_offset:end_offset+1] if token != "[SEP]"]).replace(" ##", "")
                answer_json["value"] = predicted_answer
                answer_json["spans"] = [(start_offset, end_offset)]
            elif predicted_ability_str == "counting":
                answer_json["answer_type"] = "count"
                predicted_count = best_count_number[i].detach().cpu().numpy()
                predicted_answer = str(predicted_count)
                answer_json["count"] = predicted_count
            elif predicted_ability_str == "addition_subtraction":
                answer_json["answer_type"] = "arithmetic"
                original_numbers = metadata[i]['original_numbers']
                # Sign classes: 0 -> ignore, 1 -> +, 2 -> -.
                sign_remap = {0: 0, 1: 1, 2: -1}
                predicted_signs = [sign_remap[it] for it in best_signs_for_numbers[i].detach().cpu().numpy()]
                result = 0
                for j, number in enumerate(original_numbers):
                    sign = predicted_signs[j]
                    if sign != 0:
                        result += sign * number
                predicted_answer = str(result)
                #offsets = metadata[i]['passage_token_offsets']
                number_indices = metadata[i]['number_indices']
                #number_positions = [offsets[index] for index in number_indices]
                answer_json['numbers'] = []
                for indice, value, sign in zip(number_indices, original_numbers, predicted_signs):
                    answer_json['numbers'].append({'span': indice,
                                                   'value': str(value),
                                                   'sign': sign})
                if number_indices[-1] == -1:
                    # There is a dummy 0 number at position -1 added in some cases; we are
                    # removing that here.
                    answer_json["numbers"].pop()
                answer_json["value"] = str(result)
            else:
                raise ValueError(f"Unsupported answer ability: {predicted_ability_str}")

            output_dict["question_id"].append(metadata[i]["question_id"])
            output_dict["answer"].append(answer_json)
            answer_annotations = metadata[i].get('answer_annotations', [])
            if answer_annotations:
                self._drop_metrics(predicted_answer, answer_annotations)
        # This is used for the demo.
        #output_dict["passage_question_attention"] = passage_question_attention
        output_dict["passage_question_tokens"] = passage_question_tokens
        #output_dict["passage_tokens"] = passage_tokens
    return output_dict
def forward(self,
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            passageidx2numberidx: torch.LongTensor,
            passage_number_values: List[int],
            passageidx2dateidx: torch.LongTensor,
            passage_date_values: List[List[Date]],
            actions: List[List[ProductionRule]],
            datecomp_ques_event_date_groundings: List[Tuple[List[int], List[int]]] = None,
            numcomp_qspan_num_groundings: List[Tuple[List[int], List[int]]] = None,
            strongly_supervised: List[bool] = None,
            qtypes: List[str] = None,
            qattn_supervision: torch.FloatTensor = None,
            answer_types: List[str] = None,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            epoch_num: List[int] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """QANet-style passage-span prediction with bidirectional attention.

    Encodes question and passage, computes passage<->question attention (plus
    an attention-over-attention passage self-summary), runs three stacked
    modeling passes, and predicts start/end logits.  Computes a span
    negative-log-likelihood loss when ``answer_types`` is provided and decodes
    best-span strings when ``metadata`` is provided.

    Several declared parameters (dates, numbers, groundings, qtypes, ...) are
    unused in this body; they appear to be accepted for dataset/API
    compatibility — confirm against the reader.
    """
    batch_size = len(actions)

    if epoch_num is not None:
        # epoch_num in allennlp starts from 0
        epoch = epoch_num[0] + 1
    else:
        epoch = None

    question_mask = allenutil.get_text_field_mask(question).float()
    passage_mask = allenutil.get_text_field_mask(passage).float()

    # Embed, project through highway layers, then contextually encode.
    embedded_question = self._dropout(self._text_field_embedder(question))
    embedded_passage = self._dropout(self._text_field_embedder(passage))
    embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
    embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

    projected_embedded_question = self._encoding_proj_layer(embedded_question)
    projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

    encoded_question = self._dropout(self._phrase_layer(projected_embedded_question, question_mask))
    encoded_passage = self._dropout(self._phrase_layer(projected_embedded_passage, passage_mask))

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = allenutil.masked_softmax(
        passage_question_similarity,
        question_mask,
        memory_efficient=True)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = allenutil.weighted_sum(encoded_question, passage_question_attention)

    # Shape: (batch_size, question_length, passage_length)
    question_passage_attention = allenutil.masked_softmax(
        passage_question_similarity.transpose(1, 2),
        passage_mask,
        memory_efficient=True)

    # Attention-over-attention: passage -> question -> passage.
    # Shape: (batch_size, passage_length, passage_length)
    attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_passage_vectors = allenutil.weighted_sum(encoded_passage, attention_over_attention)

    # Shape: (batch_size, passage_length, encoding_dim * 4)
    merged_passage_attention_vectors = self._dropout(
        torch.cat([encoded_passage, passage_question_vectors,
                   encoded_passage * passage_question_vectors,
                   encoded_passage * passage_passage_vectors],
                  dim=-1)
    )

    # Three modeling passes; list holds [projected input, pass1, pass2, pass3].
    modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]
    for _ in range(3):
        modeled_passage = self._dropout(self._modeling_layer(modeled_passage_list[-1], passage_mask))
        modeled_passage_list.append(modeled_passage)

    # Start uses passes 1+2; end uses passes 1+3 (QANet convention).
    # Shape: (batch_size, passage_length, modeling_dim * 2))
    span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)

    # Shape: (batch_size, passage_length, modeling_dim * 2)
    span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

    # Masked positions get -1e32 so they can never be selected as a span end.
    span_start_logits = allenutil.replace_masked_values(span_start_logits, passage_mask, -1e32)
    span_end_logits = allenutil.replace_masked_values(span_end_logits, passage_mask, -1e32)

    # Shape: (batch_size, passage_length)
    span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
    span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

    span_start_logprob = allenutil.masked_log_softmax(span_start_logits, mask=passage_mask, dim=-1)
    span_end_logprob = allenutil.masked_log_softmax(span_end_logits, mask=passage_mask, dim=-1)
    span_start_logprob = allenutil.replace_masked_values(span_start_logprob, passage_mask, -1e32)
    span_end_logprob = allenutil.replace_masked_values(span_end_logprob, passage_mask, -1e32)

    best_span = get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Loss: per-instance marginal span log-likelihood, averaged over the batch.
    if answer_types is not None:
        loss = 0
        for i in range(batch_size):
            loss += self._get_span_answer_log_prob(answer_as_spans=answer_as_passage_spans[i],
                                                   span_log_probs=(span_start_logprob[i], span_end_logprob[i]))

        loss = (-1.0 * loss) / batch_size
        self.modelloss_metric(myutils.tocpuNPList(loss)[0])
        output_dict["loss"] = loss

    # Decode best-span strings via the stored character offsets and update metrics.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['passage_token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_annotations = metadata[i].get('answer_annotation')
            self._drop_metrics(best_span_string, [answer_annotations])
        output_dict.update({'metadata': metadata})

    return output_dict
def forward(self,  # type: ignore
            bert_input: Dict[str, torch.LongTensor],
            sim_bert_input: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains the
        answer to the question, and predicts the beginning and ending positions
        of the answer within the passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict -
        the beginning position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to predict -
        the ending position of the answer with the passage.  This is an `inclusive` token index.
        If this is given, we will compute a loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question tokens, passage tokens, original passage
        text, and token offsets into the passage for each instance in the batch.  The length
        of this list should be the batch size, and each dictionary should have the keys
        ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

    Returns
    -------
    An output dictionary consisting of:
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span start position.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
        probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
        and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch, we also return the
        string from the original passage that the model thinks is the best answer to the
        question.
    """
    if self.use_scenario_encoding:
        # --- Scenario-similarity sub-model: tag each passage wordpiece with a
        # 4-way label, optionally compute its loss, and feed its predictions
        # into the main model's 'scenario_encoding' input. ---
        # Shape: (batch_size, sim_bert_input_len_wp)
        sim_bert_input_token_labels_wp = sim_bert_input['scenario_gold_encoding']
        # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
        sim_bert_output_wp = self._sim_text_field_embedder(sim_bert_input)
        # Shape: (batch_size, sim_bert_input_len_wp) — nonzero ids are real tokens.
        sim_input_mask_wp = (sim_bert_input['bert'] != 0).float()
        # Shape: (batch_size, sim_bert_input_len_wp)
        sim_passage_mask_wp = sim_input_mask_wp - sim_bert_input['bert-type-ids'].float()  # works only with one [SEP]
        # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
        sim_passage_representation_wp = sim_bert_output_wp * sim_passage_mask_wp.unsqueeze(2)
        # Keep only columns where at least one batch element has a passage
        # token.  NOTE(review): the reduction is over dim=0 (the batch), so
        # column selection is shared across the batch — confirm intended.
        # Shape: (batch_size, passage_len_wp, embedding_dim)
        sim_passage_representation_wp = sim_passage_representation_wp[:,
                                                                      sim_passage_mask_wp.sum(dim=0) > 0, :]
        # Shape: (batch_size, passage_len_wp)
        sim_passage_token_labels_wp = sim_bert_input_token_labels_wp[:,
                                                                     sim_passage_mask_wp.sum(dim=0) > 0]
        # Shape: (batch_size, passage_len_wp)
        sim_passage_mask_wp = sim_passage_mask_wp[:, sim_passage_mask_wp.sum(dim=0) > 0]
        # Shape: (batch_size, passage_len_wp, 4)
        sim_token_logits_wp = self._sim_token_label_predictor(sim_passage_representation_wp)

        if span_start is not None:  # during training and validation
            class_weights = torch.tensor(self.sim_class_weights,
                                         device=sim_token_logits_wp.device,
                                         dtype=torch.float)
            # Label 0 is the padding class (ignore_index=0).
            sim_loss = cross_entropy(sim_token_logits_wp.view(-1, 4),
                                     sim_passage_token_labels_wp.view(-1),
                                     ignore_index=0,
                                     weight=class_weights)
            self._sim_loss_metric(sim_loss.item())
            self._sim_yes_f1(sim_token_logits_wp, sim_passage_token_labels_wp, sim_passage_mask_wp)
            self._sim_no_f1(sim_token_logits_wp, sim_passage_token_labels_wp, sim_passage_mask_wp)

        if self.sim_pretraining:
            # In pretraining mode only the similarity sub-model is trained.
            return {'loss': sim_loss}

        if not self.sim_pretraining:
            # Use the sub-model's predicted labels as the main model's
            # scenario encoding, truncated/zero-padded to the expected length.
            # Shape: (batch_size, passage_len_wp)
            bert_input['scenario_encoding'] = (sim_token_logits_wp.argmax(dim=2)) * sim_passage_mask_wp.long()
            # Shape: (batch_size, bert_input_len_wp)
            bert_input_wp_len = bert_input['history_encoding'].size(1)

            if bert_input['scenario_encoding'].size(1) > bert_input_wp_len:
                # Shape: (batch_size, bert_input_len_wp)
                bert_input['scenario_encoding'] = bert_input['scenario_encoding'][:, :bert_input_wp_len]
            else:
                batch_size = bert_input['scenario_encoding'].size(0)
                difference = bert_input_wp_len - bert_input['scenario_encoding'].size(1)
                zeros = torch.zeros(batch_size,
                                    difference,
                                    dtype=bert_input['scenario_encoding'].dtype,
                                    device=bert_input['scenario_encoding'].device)
                # Shape: (batch_size, bert_input_len_wp)
                bert_input['scenario_encoding'] = torch.cat([bert_input['scenario_encoding'], zeros], dim=1)

    # --- Main model: BERT encoding, action classification, span prediction. ---
    # Shape: (batch_size, bert_input_len + 1, embedding_dim)
    bert_output = self._text_field_embedder(bert_input)
    # [CLS] representation. Shape: (batch_size, embedding_dim)
    pooled_output = bert_output[:, 0]
    # Shape: (batch_size, bert_input_len, embedding_dim)
    bert_output = bert_output[:, 1:, :]
    # Shape: (batch_size, passage_len, embedding_dim), (batch_size, passage_len)
    passage_representation, passage_mask = self.get_passage_representation(bert_output, bert_input)

    # Shape: (batch_size, 4)
    action_logits = self._action_predictor(pooled_output)
    # Shape: (batch_size, passage_len, 2)
    span_logits = self._span_predictor(passage_representation)
    # Shape: (batch_size, passage_len, 1), (batch_size, passage_len, 1)
    span_start_logits, span_end_logits = span_logits.split(1, dim=2)
    # Shape: (batch_size, passage_len)
    span_start_logits = span_start_logits.squeeze(2)
    # Shape: (batch_size, passage_len)
    span_end_logits = span_end_logits.squeeze(2)

    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    # Masked positions get -1e7 so they can never be part of the best span.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
    best_span = get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "pooled_output": pooled_output,
        "passage_representation": passage_representation,
        "action_logits": action_logits,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    if self.use_scenario_encoding:
        output_dict["sim_token_logits"] = sim_token_logits_wp

    # Compute the loss for training (and for validation)
    if span_start is not None:
        # Span loss is only counted for instances whose action label is 'More'
        # (i.e. instances that actually have an answer span).
        # Shape: (batch_size,)
        span_loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                             span_start.squeeze(1),
                             reduction='none')
        # Shape: (batch_size,)
        span_loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                              span_end.squeeze(1),
                              reduction='none')
        # Shape: (batch_size,)
        more_mask = (label == self.vocab.get_token_index('More', namespace="labels")).float()
        # Shape: (batch_size,) — mean over 'More' instances (epsilon avoids 0/0).
        span_loss = (span_loss * more_mask).sum() / (more_mask.sum() + 1e-6)

        if more_mask.sum() > 1e-7:
            self._span_start_accuracy(span_start_logits, span_start.squeeze(1), more_mask)
            self._span_end_accuracy(span_end_logits, span_end.squeeze(1), more_mask)
            # Shape: (batch_size, 2)
            span_acc_mask = more_mask.unsqueeze(1).expand(-1, 2).long()
            self._span_accuracy(best_span, torch.cat([span_start, span_end], dim=1), span_acc_mask)

        action_loss = cross_entropy(action_logits, label)
        self._action_accuracy(action_logits, label)

        self._span_loss_metric(span_loss.item())
        self._action_loss_metric(action_loss.item())
        output_dict['loss'] = self.loss_weights['span_loss'] * span_loss + self.loss_weights['action_loss'] * action_loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if not self.training:  # true during validation and test
        # NOTE(review): assumes metadata is always provided at eval time.
        output_dict['best_span_str'] = []
        batch_size = len(metadata)
        for i in range(batch_size):
            passage_text = metadata[i]['passage_text']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_str = passage_text[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_str)

            if 'gold_span' in metadata[i]:
                if metadata[i]['action'] == 'More':
                    gold_span = metadata[i]['gold_span']
                    self._squad_metrics(best_span_str, [gold_span])

    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            numbers_in_passage: Dict[str, torch.LongTensor],
            number_indices: torch.LongTensor,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            answer_as_add_sub_expressions: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ, unused-argument
    """
    Multi-ability reading-comprehension forward pass.

    The model encodes the question and passage, runs bidirectional
    passage<->question attention, then — depending on the abilities listed in
    ``self.answering_abilities`` — scores four answer types: a passage span, a
    question span, a signed add/sub combination of passage numbers, and a
    count in [0, 9].  When any of the ``answer_as_*`` gold tensors is given,
    the loss is the negative marginal log-likelihood over all abilities and
    all gold annotations.  When ``metadata`` is given, a string answer is
    decoded per instance from the most probable ability.

    Gold tensors use -1 as padding; padded entries are clamped to 0 before
    ``torch.gather`` and then masked out with a large negative value so they
    vanish under ``logsumexp``.
    """
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    embedded_question = self._dropout(self._text_field_embedder(question))
    embedded_passage = self._dropout(self._text_field_embedder(passage))
    embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
    embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))
    batch_size = embedded_question.size(0)
    projected_embedded_question = self._encoding_proj_layer(embedded_question)
    projected_embedded_passage = self._encoding_proj_layer(embedded_passage)
    encoded_question = self._dropout(self._phrase_layer(projected_embedded_question, question_mask))
    encoded_passage = self._dropout(self._phrase_layer(projected_embedded_passage, passage_mask))
    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = masked_softmax(passage_question_similarity,
                                                question_mask,
                                                memory_efficient=True)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)
    # Shape: (batch_size, question_length, passage_length)
    question_passage_attention = masked_softmax(passage_question_similarity.transpose(1, 2),
                                                passage_mask,
                                                memory_efficient=True)
    # Passage-to-passage attention routed through the question ("attention
    # over attention").  Shape: (batch_size, passage_length, passage_length)
    passsage_attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_passage_vectors = util.weighted_sum(encoded_passage, passsage_attention_over_attention)
    # Shape: (batch_size, passage_length, encoding_dim * 4)
    merged_passage_attention_vectors = self._dropout(
        torch.cat([encoded_passage, passage_question_vectors,
                   encoded_passage * passage_question_vectors,
                   encoded_passage * passage_passage_vectors],
                  dim=-1))
    # The recurrent modeling layers. Since these layers share the same parameters,
    # we don't construct them conditioned on answering abilities.
    modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]
    for _ in range(4):
        modeled_passage = self._dropout(self._modeling_layer(modeled_passage_list[-1], passage_mask))
        modeled_passage_list.append(modeled_passage)
    # Pop the first one, which is input
    modeled_passage_list.pop(0)
    # The first modeling layer is used to calculate the vector representation of passage
    passage_weights = self._passage_weights_predictor(modeled_passage_list[0]).squeeze(-1)
    passage_weights = masked_softmax(passage_weights, passage_mask)
    passage_vector = util.weighted_sum(modeled_passage_list[0], passage_weights)
    # The vector representation of question is calculated based on the unmatched encoding,
    # because we may want to infer the answer ability only based on the question words.
    question_weights = self._question_weights_predictor(encoded_question).squeeze(-1)
    question_weights = masked_softmax(question_weights, question_mask)
    question_vector = util.weighted_sum(encoded_question, question_weights)

    if len(self.answering_abilities) > 1:
        # Shape: (batch_size, number_of_abilities)
        answer_ability_logits = \
            self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1))
        answer_ability_log_probs = torch.nn.functional.log_softmax(answer_ability_logits, -1)
        best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

    if "counting" in self.answering_abilities:
        # Shape: (batch_size, 10)
        count_number_logits = self._count_number_predictor(passage_vector)
        count_number_log_probs = torch.nn.functional.log_softmax(count_number_logits, -1)
        # Info about the best count number prediction
        # Shape: (batch_size,)
        best_count_number = torch.argmax(count_number_log_probs, -1)
        best_count_log_prob = \
            torch.gather(count_number_log_probs, 1, best_count_number.unsqueeze(-1)).squeeze(-1)
        if len(self.answering_abilities) > 1:
            best_count_log_prob += answer_ability_log_probs[:, self._counting_index]

    if "passage_span_extraction" in self.answering_abilities:
        # Span start/end each combine the shared base modeling layer with a
        # dedicated later layer.
        # Shape: (batch_size, passage_length, modeling_dim * 2))
        passage_for_span_start = torch.cat([modeled_passage_list[0], modeled_passage_list[1]], dim=-1)
        # Shape: (batch_size, passage_length)
        passage_span_start_logits = self._passage_span_start_predictor(passage_for_span_start).squeeze(-1)
        # Shape: (batch_size, passage_length, modeling_dim * 2)
        passage_for_span_end = torch.cat([modeled_passage_list[0], modeled_passage_list[2]], dim=-1)
        # Shape: (batch_size, passage_length)
        passage_span_end_logits = self._passage_span_end_predictor(passage_for_span_end).squeeze(-1)
        # Shape: (batch_size, passage_length)
        passage_span_start_log_probs = util.masked_log_softmax(passage_span_start_logits, passage_mask)
        passage_span_end_log_probs = util.masked_log_softmax(passage_span_end_logits, passage_mask)
        # Info about the best passage span prediction
        passage_span_start_logits = util.replace_masked_values(passage_span_start_logits, passage_mask, -1e7)
        passage_span_end_logits = util.replace_masked_values(passage_span_end_logits, passage_mask, -1e7)
        # Shape: (batch_size, 2)
        best_passage_span = get_best_span(passage_span_start_logits, passage_span_end_logits)
        # Shape: (batch_size, 2)
        best_passage_start_log_probs = \
            torch.gather(passage_span_start_log_probs, 1, best_passage_span[:, 0].unsqueeze(-1)).squeeze(-1)
        best_passage_end_log_probs = \
            torch.gather(passage_span_end_log_probs, 1, best_passage_span[:, 1].unsqueeze(-1)).squeeze(-1)
        # Shape: (batch_size,)
        best_passage_span_log_prob = best_passage_start_log_probs + best_passage_end_log_probs
        if len(self.answering_abilities) > 1:
            best_passage_span_log_prob += answer_ability_log_probs[:, self._passage_span_extraction_index]

    if "question_span_extraction" in self.answering_abilities:
        # Condition question-span scores on the passage vector by tiling it
        # across the question length.
        # Shape: (batch_size, question_length)
        encoded_question_for_span_prediction = \
            torch.cat([encoded_question,
                       passage_vector.unsqueeze(1).repeat(1, encoded_question.size(1), 1)], -1)
        question_span_start_logits = \
            self._question_span_start_predictor(encoded_question_for_span_prediction).squeeze(-1)
        # Shape: (batch_size, question_length)
        question_span_end_logits = \
            self._question_span_end_predictor(encoded_question_for_span_prediction).squeeze(-1)
        question_span_start_log_probs = util.masked_log_softmax(question_span_start_logits, question_mask)
        question_span_end_log_probs = util.masked_log_softmax(question_span_end_logits, question_mask)
        # Info about the best question span prediction
        question_span_start_logits = \
            util.replace_masked_values(question_span_start_logits, question_mask, -1e7)
        question_span_end_logits = \
            util.replace_masked_values(question_span_end_logits, question_mask, -1e7)
        # Shape: (batch_size, 2)
        best_question_span = get_best_span(question_span_start_logits, question_span_end_logits)
        # Shape: (batch_size, 2)
        best_question_start_log_probs = \
            torch.gather(question_span_start_log_probs, 1, best_question_span[:, 0].unsqueeze(-1)).squeeze(-1)
        best_question_end_log_probs = \
            torch.gather(question_span_end_log_probs, 1, best_question_span[:, 1].unsqueeze(-1)).squeeze(-1)
        # Shape: (batch_size,)
        best_question_span_log_prob = best_question_start_log_probs + best_question_end_log_probs
        if len(self.answering_abilities) > 1:
            best_question_span_log_prob += answer_ability_log_probs[:, self._question_span_extraction_index]

    if "addition_subtraction" in self.answering_abilities:
        # Shape: (batch_size, # of numbers in the passage)
        number_indices = number_indices.squeeze(-1)
        number_mask = (number_indices != -1).long()
        clamped_number_indices = util.replace_masked_values(number_indices, number_mask, 0)
        encoded_passage_for_numbers = torch.cat([modeled_passage_list[0], modeled_passage_list[3]], dim=-1)
        # Gather the passage encoding at each number's token position.
        # Shape: (batch_size, # of numbers in the passage, encoding_dim)
        encoded_numbers = torch.gather(
            encoded_passage_for_numbers,
            1,
            clamped_number_indices.unsqueeze(-1).expand(-1, -1, encoded_passage_for_numbers.size(-1)))
        # Shape: (batch_size, # of numbers in the passage)
        encoded_numbers = torch.cat(
            [encoded_numbers, passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)
        # Per-number 3-way sign classification: 0 (unused), + or -.
        # Shape: (batch_size, # of numbers in the passage, 3)
        number_sign_logits = self._number_sign_predictor(encoded_numbers)
        number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)
        # Shape: (batch_size, # of numbers in passage).
        best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
        # For padding numbers, the best sign masked as 0 (not included).
        best_signs_for_numbers = util.replace_masked_values(best_signs_for_numbers, number_mask, 0)
        # Shape: (batch_size, # of numbers in passage)
        best_signs_log_probs = torch.gather(
            number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
        # the probs of the masked positions should be 1 so that it will not affect the joint probability
        # TODO: this is not quite right, since if there are many numbers in the passage,
        # TODO: the joint probability would be very small.
        best_signs_log_probs = util.replace_masked_values(best_signs_log_probs, number_mask, 0)
        # Shape: (batch_size,)
        best_combination_log_prob = best_signs_log_probs.sum(-1)
        if len(self.answering_abilities) > 1:
            best_combination_log_prob += answer_ability_log_probs[:, self._addition_subtraction_index]

    output_dict = {}

    # If answer is given, compute the loss.
    if answer_as_passage_spans is not None or answer_as_question_spans is not None \
            or answer_as_add_sub_expressions is not None or answer_as_counts is not None:

        log_marginal_likelihood_list = []

        for answering_ability in self.answering_abilities:
            if answering_ability == "passage_span_extraction":
                # Shape: (batch_size, # of answer spans)
                gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                # Some spans are padded with index -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                gold_passage_span_mask = (gold_passage_span_starts != -1).long()
                clamped_gold_passage_span_starts = \
                    util.replace_masked_values(gold_passage_span_starts, gold_passage_span_mask, 0)
                clamped_gold_passage_span_ends = \
                    util.replace_masked_values(gold_passage_span_ends, gold_passage_span_mask, 0)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_passage_span_starts = \
                    torch.gather(passage_span_start_log_probs, 1, clamped_gold_passage_span_starts)
                log_likelihood_for_passage_span_ends = \
                    torch.gather(passage_span_end_log_probs, 1, clamped_gold_passage_span_ends)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_passage_spans = \
                    log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_passage_spans = \
                    util.replace_masked_values(log_likelihood_for_passage_spans, gold_passage_span_mask, -1e7)
                # Marginalize over all gold spans.  Shape: (batch_size, )
                log_marginal_likelihood_for_passage_span = util.logsumexp(log_likelihood_for_passage_spans)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_passage_span)
            elif answering_ability == "question_span_extraction":
                # Shape: (batch_size, # of answer spans)
                gold_question_span_starts = answer_as_question_spans[:, :, 0]
                gold_question_span_ends = answer_as_question_spans[:, :, 1]
                # Some spans are padded with index -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                gold_question_span_mask = (gold_question_span_starts != -1).long()
                clamped_gold_question_span_starts = \
                    util.replace_masked_values(gold_question_span_starts, gold_question_span_mask, 0)
                clamped_gold_question_span_ends = \
                    util.replace_masked_values(gold_question_span_ends, gold_question_span_mask, 0)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_question_span_starts = \
                    torch.gather(question_span_start_log_probs, 1, clamped_gold_question_span_starts)
                log_likelihood_for_question_span_ends = \
                    torch.gather(question_span_end_log_probs, 1, clamped_gold_question_span_ends)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_question_spans = \
                    log_likelihood_for_question_span_starts + log_likelihood_for_question_span_ends
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_question_spans = \
                    util.replace_masked_values(log_likelihood_for_question_spans, gold_question_span_mask, -1e7)
                # Shape: (batch_size, )
                # pylint: disable=invalid-name
                log_marginal_likelihood_for_question_span = \
                    util.logsumexp(log_likelihood_for_question_spans)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_question_span)
            elif answering_ability == "addition_subtraction":
                # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                # Shape: (batch_size, # of combinations)
                gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1) > 0).float()
                # Shape: (batch_size, # of numbers in the passage, # of combinations)
                gold_add_sub_signs = answer_as_add_sub_expressions.transpose(1, 2)
                # Shape: (batch_size, # of numbers in the passage, # of combinations)
                log_likelihood_for_number_signs = torch.gather(number_sign_log_probs, 2, gold_add_sub_signs)
                # the log likelihood of the masked positions should be 0
                # so that it will not affect the joint probability
                log_likelihood_for_number_signs = \
                    util.replace_masked_values(log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0)
                # Joint probability of a combination is the product over numbers.
                # Shape: (batch_size, # of combinations)
                log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(1)
                # For those padded combinations, we set their log probabilities to be very small negative value
                log_likelihood_for_add_subs = \
                    util.replace_masked_values(log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
                # Shape: (batch_size, )
                log_marginal_likelihood_for_add_sub = util.logsumexp(log_likelihood_for_add_subs)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_add_sub)
            elif answering_ability == "counting":
                # Count answers are padded with label -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                # Shape: (batch_size, # of count answers)
                gold_count_mask = (answer_as_counts != -1).long()
                # Shape: (batch_size, # of count answers)
                clamped_gold_counts = util.replace_masked_values(answer_as_counts, gold_count_mask, 0)
                log_likelihood_for_counts = torch.gather(count_number_log_probs, 1, clamped_gold_counts)
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_counts = \
                    util.replace_masked_values(log_likelihood_for_counts, gold_count_mask, -1e7)
                # Shape: (batch_size, )
                log_marginal_likelihood_for_count = util.logsumexp(log_likelihood_for_counts)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_count)
            else:
                raise ValueError(f"Unsupported answering ability: {answering_ability}")

        if len(self.answering_abilities) > 1:
            # Add the ability probabilities if there are more than one abilities
            all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
            all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
            marginal_log_likelihood = util.logsumexp(all_log_marginal_likelihoods)
        else:
            marginal_log_likelihood = log_marginal_likelihood_list[0]

        output_dict["loss"] = - marginal_log_likelihood.mean()

    # Compute the metrics and add the tokenized input to the output.
    if metadata is not None:
        output_dict["question_id"] = []
        output_dict["answer"] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            if len(self.answering_abilities) > 1:
                predicted_ability_str = self.answering_abilities[best_answer_ability[i].detach().cpu().numpy()]
            else:
                predicted_ability_str = self.answering_abilities[0]
            # We did not consider multi-mention answers here
            if predicted_ability_str == "passage_span_extraction":
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['passage_token_offsets']
                predicted_span = tuple(best_passage_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                predicted_answer = passage_str[start_offset:end_offset]
            elif predicted_ability_str == "question_span_extraction":
                question_str = metadata[i]['original_question']
                offsets = metadata[i]['question_token_offsets']
                predicted_span = tuple(best_question_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                predicted_answer = question_str[start_offset:end_offset]
            elif predicted_ability_str == "addition_subtraction":  # plus_minus combination answer
                original_numbers = metadata[i]['original_numbers']
                # Sign class 2 means subtraction.
                sign_remap = {0: 0, 1: 1, 2: -1}
                predicted_signs = [sign_remap[it] for it in best_signs_for_numbers[i].detach().cpu().numpy()]
                result = sum([sign * number for sign, number in zip(predicted_signs, original_numbers)])
                predicted_answer = str(result)
            elif predicted_ability_str == "counting":
                predicted_count = best_count_number[i].detach().cpu().numpy()
                predicted_answer = str(predicted_count)
            else:
                raise ValueError(f"Unsupported answer ability: {predicted_ability_str}")
            output_dict["question_id"].append(metadata[i]["question_id"])
            output_dict["answer"].append(predicted_answer)
            answer_annotations = metadata[i].get('answer_annotations', [])
            if answer_annotations:
                self._drop_metrics(predicted_answer, answer_annotations)
        # This is used for the demo.
        output_dict["passage_question_attention"] = passage_question_attention
        output_dict["question_tokens"] = question_tokens
        output_dict["passage_tokens"] = passage_tokens
        # The demo takes `best_span_str` as a key to find the predicted answer
        output_dict["best_span_str"] = output_dict["answer"]
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Oracle span "model": instead of predicting span scores from the text, it
    builds one-hot start/end logits directly from the gold ``span_start`` /
    ``span_end`` indices, so ``best_span`` recovers the gold span and the
    resulting metrics give an upper bound for span extraction.

    Note that although ``span_start``/``span_end`` default to ``None`` (to
    keep the standard forward signature), they are used unconditionally below
    to build the one-hot logits, so they are effectively required.

    Returns an output dictionary with ``span_start_logits``,
    ``span_start_probs``, ``span_end_logits``, ``span_end_probs``,
    ``best_span``; plus ``loss`` when gold spans are given, and
    ``best_span_str`` when ``metadata`` is given.
    """
    passage_mask = util.get_text_field_mask(passage).float()
    batch_size, passage_len = passage_mask.shape

    # One-hot "logits" built from the gold indices.  Allocate on the same
    # device/dtype as the mask: the previous torch.FloatTensor(batch, len)
    # always produced an (uninitialized) CPU tensor, which breaks when the
    # inputs live on CUDA.
    span_start_logits = torch.zeros_like(passage_mask)
    span_start_logits.scatter_(1, span_start, 1)
    span_end_logits = torch.zeros_like(passage_mask)
    span_end_logits.scatter_(1, span_end, 1)

    # Push masked positions to a huge negative value so softmax gives them
    # (numerically) zero probability.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

    # Shape: (batch_size, passage_length)
    span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
    span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

    best_span = get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                        span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                         span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute EM/F1 (and ROUGE at eval time) and add the answer strings.
    if metadata is not None:
        output_dict['best_span_str'] = []
        all_reference_answers_text = []
        all_best_spans = []
        for i in range(batch_size):
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            start_span = predicted_span[0]
            end_span = predicted_span[1]
            # Span end is inclusive, hence the +1 when slicing tokens.
            best_span_tokens = metadata[i]['passage_tokens'][start_span:end_span + 1]
            best_span_string = " ".join(best_span_tokens)
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._metrics(best_span_string, answer_texts)
                # Keep both lists appended under the same guard so the
                # prediction/reference pairs stay index-aligned for ROUGE.
                all_best_spans.append(best_span_string)
                all_reference_answers_text.append(answer_texts)
        if not self.training:
            self.calculate_rouge(all_best_spans, all_reference_answers_text)
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            passage_sem_views_q: torch.IntTensor = None,
            passage_sem_views_k: torch.IntTensor = None,
            question_sem_views_q: torch.IntTensor = None,
            question_sem_views_k: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Oracle span forward pass with semantic-view features: span logits are
    built one-hot from the gold ``span_start``/``span_end`` indices (so
    ``best_span`` recovers the gold span), and per-instance semantic-view
    features for the predicted span are optionally exported.

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains
        the answer to the question.
    span_start : ``torch.IntTensor``, optional
        Gold inclusive start index from an ``IndexField``.  Although optional
        in the signature, it is used unconditionally below to build the
        one-hot logits.
    span_end : ``torch.IntTensor``, optional
        Gold inclusive end index from an ``IndexField``.  Same caveat as
        ``span_start``.
    passage_sem_views_q : ``torch.IntTensor``, optional
        Paragraph semantic views features for multihead attention Query (Q)
    passage_sem_views_k : ``torch.IntTensor``, optional
        Paragraph semantic views features for multihead attention Key (K)
    question_sem_views_q : ``torch.IntTensor``, optional
        Question semantic views features for multihead attention Query (Q)
    question_sem_views_k : ``torch.IntTensor``, optional
        Question semantic views features for multihead attention Key (K)
    metadata : ``List[Dict[str, Any]]``, optional
        Per-instance dicts with ``question_tokens``, ``passage_tokens``,
        ``original_passage``, and optionally ``answer_texts``.

    Returns
    -------
    An output dictionary with span logits/probs, ``best_span``, ``loss`` when
    gold spans are given, ``best_span_str`` (and per-item ``metrics`` at eval
    time) when metadata is given, and optional ``output_metadata`` /
    ``best_span_semantic_features`` when ``self.return_output_metadata``.
    """
    return_output_metadata = self.return_output_metadata
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    batch_size, passage_len = passage_mask.shape

    # Semantic-view features arrive as IntTensors; indexing/embedding wants long.
    if passage_sem_views_q is not None:
        passage_sem_views_q = passage_sem_views_q.long()
    if passage_sem_views_k is not None:
        passage_sem_views_k = passage_sem_views_k.long()
    if question_sem_views_q is not None:
        question_sem_views_q = question_sem_views_q.long()
    if question_sem_views_k is not None:
        question_sem_views_k = question_sem_views_k.long()

    # One-hot "logits" from the gold span indices.
    # NOTE(review): torch.FloatTensor(...) allocates an uninitialized CPU
    # tensor; this will fail/misbehave if passage_mask lives on CUDA — confirm
    # this path is CPU-only.
    span_start_logits = torch.FloatTensor(batch_size, passage_len)
    span_start_logits.zero_()
    span_start_logits.scatter_(1, span_start, 1)

    span_end_logits = torch.FloatTensor(batch_size, passage_len)
    span_end_logits.zero_()
    span_end_logits.scatter_(1, span_end, 1)

    # Masked positions get a huge negative logit -> ~0 probability.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

    # Shape: (batch_size, passage_length)
    span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
    span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

    best_span = get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask),
                        span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask),
                         span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        metrics_per_item = None
        all_reference_answers_text = []
        all_best_spans = []
        return_metrics_per_item = True
        if not self.training:
            metrics_per_item = [{} for x in range(batch_size)]
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            # NOTE(review): passage_str is assigned but never used below.
            passage_str = metadata[i]['original_passage']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            start_span = predicted_span[0]
            end_span = predicted_span[1]
            # Span end is inclusive, hence the +1 when slicing tokens.
            best_span_tokens = metadata[i]['passage_tokens'][start_span:end_span + 1]
            best_span_string = " ".join(best_span_tokens)
            output_dict['best_span_str'].append(best_span_string)
            # NOTE(review): overwritten every iteration — only the LAST
            # instance's tokens survive in the output; confirm intended.
            output_dict['best_span_tokens'] = best_span_tokens
            answer_texts = metadata[i].get('answer_texts', [])
            if return_output_metadata:
                # Export, per semantic view, the feature slice covering the
                # predicted span.
                best_span_semantic_features = []
                curr_item_features = passage_sem_views_q[i]
                for view_id in range(curr_item_features.shape[0]):
                    curr_view_feats = curr_item_features[view_id][start_span:end_span + 1]
                    best_span_semantic_features.append(curr_view_feats.tolist())
                # NOTE(review): also overwritten per iteration, like
                # best_span_tokens above.
                output_dict['best_span_semantic_features'] = best_span_semantic_features
            all_best_spans.append(best_span_string)
            if answer_texts:
                curr_item_em, curr_item_f1 = self._squad_metrics(
                    best_span_string, answer_texts, return_score=True)
                if not self.training and return_metrics_per_item:
                    metrics_per_item[i]["em"] = curr_item_em
                    metrics_per_item[i]["f1"] = curr_item_f1
                all_reference_answers_text.append(answer_texts)
        # output metadata
        if return_output_metadata:
            output_dict["output_metadata"] = {
                "modeling_layer": {
                    "modeling_layer_iter_000": {
                        "encoder_block_001": {
                            "semantic_views_q": passage_sem_views_q,
                            "semantic_views_sent_mask": passage_sem_views_k,
                        },
                    }
                }
            }
        if not self.training and len(all_reference_answers_text) > 0:
            # NOTE(review): all_reference_answers_text only holds entries for
            # instances that HAD answer_texts, while metrics_per_item is
            # indexed by batch position — if any instance lacks answer_texts
            # the enumerate() below mis-aligns rouge scores; verify inputs
            # always carry answer_texts here.
            metrics_per_item_rouge = self.calculate_rouge(
                all_best_spans,
                all_reference_answers_text,
                return_metrics_per_item=return_metrics_per_item)
            for i, curr_metrics in enumerate(metrics_per_item_rouge):
                metrics_per_item[i].update(curr_metrics)
        if metrics_per_item is not None:
            output_dict['metrics'] = metrics_per_item
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ, unused-argument
    """
    Predict an answer span over the passage alone (the ``question`` input is
    accepted but not used by the computation below — note the
    ``unused-argument`` pragma).

    The passage is embedded, projected, passed through a highway layer, then
    re-encoded three times; span-start scores are read off a concatenation of
    the two intermediate encodings and span-end scores off the first and last.

    Returns a dict with ``loss`` (when ``span_start``/``span_end`` are given)
    and, when ``metadata`` is given, ``question_id`` and ``answer`` lists plus
    DROP-style metric updates via ``self._drop_metrics``.
    """
    passage_mask = util.get_text_field_mask(passage).float()
    embedded_passage = self._dropout(self._text_field_embedder(passage))
    batch_size = embedded_passage.size(0)
    embedded_passage = self._highway_layer(
        self._embedding_proj_layer(embedded_passage))
    # Three stacked encoder passes; keep every intermediate so the span
    # predictors can mix different depths of the stack.
    encoded_passage_list = [embedded_passage]
    for _ in range(3):
        encoded_passage = self._dropout(
            self._encoding_layer(encoded_passage_list[-1], passage_mask))
        encoded_passage_list.append(encoded_passage)
    # Shape: (batch_size, passage_length, modeling_dim * 2)
    span_start_input = torch.cat(
        [encoded_passage_list[-3], encoded_passage_list[-2]], dim=-1)
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(
        span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length, modeling_dim * 2)
    span_end_input = torch.cat(
        [encoded_passage_list[-3], encoded_passage_list[-1]], dim=-1)
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    # Push masked positions to a huge negative value so they can never win
    # the best-span search below.  NOTE(review): the masked (-1e32) logits
    # are also the ones fed to masked_log_softmax for the loss — other
    # forward() variants in this file mask after computing log-probs;
    # numerically equivalent here since those positions are masked anyway.
    span_start_logits = util.replace_masked_values(span_start_logits,
                                                   passage_mask, -1e32)
    span_end_logits = util.replace_masked_values(span_end_logits,
                                                 passage_mask, -1e32)
    best_span = get_best_span(span_start_logits, span_end_logits)
    output_dict = {}
    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(
            util.masked_log_softmax(span_start_logits, passage_mask),
            span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(
            util.masked_log_softmax(span_end_logits, passage_mask),
            span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span,
                            torch.stack([span_start, span_end], -1))
        output_dict["loss"] = loss
    # Compute the metrics and add the tokenized input to the output.
    if metadata is not None:
        output_dict["question_id"] = []
        output_dict["answer"] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # Map token indices back to character offsets in the original
            # passage text; span end is inclusive, so take its end offset.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict["question_id"].append(metadata[i]["question_id"])
            output_dict["answer"].append(best_span_string)
            answer_annotations = metadata[i].get('answer_annotations', [])
            if answer_annotations:
                self._drop_metrics(best_span_string, answer_annotations)
    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        passage_sem_views_q: torch.IntTensor = None,
        passage_sem_views_k: torch.IntTensor = None,
        question_sem_views_q: torch.IntTensor = None,
        question_sem_views_k: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains
        the answer to the question, and predicts the beginning and ending
        positions of the answer within the passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to
        predict - the beginning position of the answer with the passage.
        This is an `inclusive` token index.  If this is given, we will
        compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to
        predict - the ending position of the answer with the passage.  This
        is an `inclusive` token index.  If this is given, we will compute a
        loss that gets included in the output dictionary.
    passage_sem_views_q : ``torch.IntTensor``, optional
        Passage semantic views features for multihead attention Query (Q).
    passage_sem_views_k : ``torch.IntTensor``, optional
        Passage semantic views features for multihead attention Key (K).
    question_sem_views_q : ``torch.IntTensor``, optional
        Question semantic views features for multihead attention Query (Q).
    question_sem_views_k : ``torch.IntTensor``, optional
        Question semantic views features for multihead attention Key (K).
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question tokens, passage tokens,
        original passage text, and token offsets into the passage for each
        instance in the batch.  The length of this list should be the batch
        size, and each dictionary should have the keys ``question_tokens``,
        ``passage_tokens``, ``original_passage``, and ``token_offsets``.

    Returns
    -------
    An output dictionary consisting of:
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized log probabilities of the span start position.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized log probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is
        ``(batch_size, 2)`` and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch,
        we also return the string from the original passage that the model
        thinks is the best answer to the question.
    """
    return_output_metadata = self.return_output_metadata
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    # The semantic-view encoders index embedding tables with these features,
    # so they must be int64.
    if isinstance(self._phrase_layer, QaNetSemanticFlatEncoder) \
            or isinstance(self._phrase_layer, QaNetSemanticFlatConcatEncoder) \
            or isinstance(self._modeling_layer, QaNetSemanticFlatEncoder) \
            or isinstance(self._modeling_layer, QaNetSemanticFlatConcatEncoder):
        if passage_sem_views_q is not None:
            passage_sem_views_q = passage_sem_views_q.long()
        if passage_sem_views_k is not None:
            passage_sem_views_k = passage_sem_views_k.long()
        if question_sem_views_q is not None:
            question_sem_views_q = question_sem_views_q.long()
        if question_sem_views_k is not None:
            question_sem_views_k = question_sem_views_k.long()
    # Explicitly move every input onto the GPU when one is available.
    # NOTE(review): this keys off global CUDA availability rather than the
    # device the model parameters live on — verify against the trainer setup.
    if torch.cuda.is_available():
        # indices
        question_mask = to_cuda(question_mask, move_to_cuda=True)
        passage_mask = to_cuda(passage_mask, move_to_cuda=True)
        question = {
            k: to_cuda(v, move_to_cuda=True)
            for k, v in question.items()
        }
        passage = {
            k: to_cuda(v, move_to_cuda=True)
            for k, v in passage.items()
        }
        # span
        if span_start is not None:
            span_start = to_cuda(span_start, move_to_cuda=True)
        if span_end is not None:
            span_end = to_cuda(span_end, move_to_cuda=True)
        # semantic views
        if passage_sem_views_q is not None:
            passage_sem_views_q = to_cuda(passage_sem_views_q,
                                          move_to_cuda=True)
        if passage_sem_views_k is not None:
            passage_sem_views_k = to_cuda(passage_sem_views_k,
                                          move_to_cuda=True)
        if question_sem_views_q is not None:
            question_sem_views_q = to_cuda(question_sem_views_q,
                                           move_to_cuda=True)
        if question_sem_views_k is not None:
            question_sem_views_k = to_cuda(question_sem_views_k,
                                           move_to_cuda=True)
    embedded_question = self._dropout(self._text_field_embedder(question))
    embedded_passage = self._dropout(self._text_field_embedder(passage))
    embedded_question = self._highway_layer(
        self._embedding_proj_layer(embedded_question))
    embedded_passage = self._highway_layer(
        self._embedding_proj_layer(embedded_passage))
    batch_size = embedded_question.size(0)
    projected_embedded_question = self._encoding_proj_layer(
        embedded_question)
    projected_embedded_passage = self._encoding_proj_layer(
        embedded_passage)
    encoded_passage_output_metadata = None
    encoded_question_output_metadata = None
    # Semantic-view phrase layers take the view features (and optionally
    # return per-layer metadata); plain phrase layers only take the mask.
    if isinstance(self._phrase_layer, QaNetSemanticFlatEncoder) \
            or isinstance(self._phrase_layer, QaNetSemanticFlatConcatEncoder):
        if is_output_meta_supported(self._phrase_layer):
            encoded_passage, encoded_passage_output_metadata = self._phrase_layer(
                projected_embedded_passage, passage_sem_views_q,
                passage_sem_views_k, passage_mask, return_output_metadata)
            encoded_passage = self._dropout(encoded_passage)
            encoded_question, encoded_question_output_metadata = self._phrase_layer(
                projected_embedded_question, question_sem_views_q,
                question_sem_views_k, question_mask, return_output_metadata)
            encoded_question = self._dropout(encoded_question)
        else:
            encoded_passage = self._dropout(
                self._phrase_layer(projected_embedded_passage,
                                   passage_sem_views_q, passage_sem_views_k,
                                   passage_mask))
            encoded_question = self._dropout(
                self._phrase_layer(projected_embedded_question,
                                   question_sem_views_q, question_sem_views_k,
                                   question_mask))
    else:
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask))
        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask))
    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(
        encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = masked_softmax(
        passage_question_similarity, question_mask, memory_efficient=True)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(
        encoded_question, passage_question_attention)
    # Shape: (batch_size, question_length, passage_length)
    question_passage_attention = masked_softmax(
        passage_question_similarity.transpose(1, 2),
        passage_mask,
        memory_efficient=True)
    # Shape: (batch_size, passage_length, passage_length)
    attention_over_attention = torch.bmm(passage_question_attention,
                                         question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_passage_vectors = util.weighted_sum(encoded_passage,
                                                attention_over_attention)
    # Shape: (batch_size, passage_length, encoding_dim * 4)
    merged_passage_attention_vectors = self._dropout(
        torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * passage_passage_vectors
        ],
                  dim=-1))
    modeled_passage_list = [
        self._modeling_proj_layer(merged_passage_attention_vectors)
    ]
    modeled_passage_output_metadata_list = {}
    # Three modeling passes; every intermediate is kept so the span
    # predictors can combine different depths, and per-pass metadata is
    # recorded under a "modeling_layer_iter_NNN" key.
    for modeling_layer_id in range(3):
        modeled_passage_output_metadata = None
        if isinstance(self._modeling_layer, QaNetSemanticFlatEncoder) \
                or isinstance(self._modeling_layer, QaNetSemanticFlatConcatEncoder):
            if is_output_meta_supported(self._modeling_layer):
                modeled_passage, modeled_passage_output_metadata = self._modeling_layer(
                    modeled_passage_list[-1], passage_sem_views_q,
                    passage_sem_views_k, passage_mask,
                    return_output_metadata)
            else:
                modeled_passage = self._modeling_layer(
                    modeled_passage_list[-1], passage_sem_views_q,
                    passage_sem_views_k, passage_mask)
        else:
            modeled_passage = self._modeling_layer(
                modeled_passage_list[-1], passage_mask)
        modeled_passage = self._dropout(modeled_passage)
        modeled_passage_list.append(modeled_passage)
        modeled_passage_output_metadata_list[
            "modeling_layer_iter_{0:03d}".format(
                modeling_layer_id)] = modeled_passage_output_metadata
    # Shape: (batch_size, passage_length, modeling_dim * 2)
    span_start_input = torch.cat(
        [modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(
        span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length, modeling_dim * 2)
    span_end_input = torch.cat(
        [modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    # Mask out padding with a huge negative value before the best-span search.
    span_start_logits = util.replace_masked_values(span_start_logits,
                                                   passage_mask, -1e32)
    span_end_logits = util.replace_masked_values(span_end_logits,
                                                 passage_mask, -1e32)
    # Shape: (batch_size, passage_length)
    span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                   dim=-1)
    span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
    best_span = get_best_span(span_start_logits, span_end_logits)
    output_dict = {
        # "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }
    # Compute the loss for training.  The try/except keeps a single bad
    # batch from aborting a long training run; the failure is still logged.
    if span_start is not None:
        try:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss
        except Exception as e:
            logging.exception(e)
    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        metrics_per_item = None
        all_reference_answers_text = []
        all_best_spans = []
        return_metrics_per_item = True
        # Per-item metrics are only collected at evaluation time.
        if not self.training:
            metrics_per_item = [{} for x in range(batch_size)]
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # The answer string is rebuilt from passage tokens (span end is
            # inclusive) rather than from character offsets.
            start_span = predicted_span[0]
            end_span = predicted_span[1]
            best_span_tokens = metadata[i]['passage_tokens'][
                start_span:end_span + 1]
            best_span_string = " ".join(best_span_tokens)
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            all_best_spans.append(best_span_string)
            if answer_texts:
                curr_item_em, curr_item_f1 = self._squad_metrics(
                    best_span_string, answer_texts, return_score=True)
                if not self.training and return_metrics_per_item:
                    metrics_per_item[i]["em"] = curr_item_em
                    metrics_per_item[i]["f1"] = curr_item_f1
                all_reference_answers_text.append(answer_texts)
        # Attach the per-layer attention metadata gathered above.
        if return_output_metadata:
            output_dict["output_metadata"] = {
                "encoded_passage": encoded_passage_output_metadata,
                "encoded_question": encoded_question_output_metadata,
                "modeling_layer": modeled_passage_output_metadata_list,
            }
        if not self.training and len(all_reference_answers_text) > 0:
            metrics_per_item_rouge = self.calculate_rouge(
                all_best_spans,
                all_reference_answers_text,
                return_metrics_per_item=return_metrics_per_item)
            for i, curr_metrics in enumerate(metrics_per_item_rouge):
                metrics_per_item[i].update(curr_metrics)
        if metrics_per_item is not None:
            output_dict['metrics'] = metrics_per_item
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        question_and_passage: Dict[str, torch.LongTensor],
        answer_as_passage_spans: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ, unused-argument
    """
    BERT-style span model over the concatenated ``[CLS] question [SEP]
    passage [SEP]`` sequence.  Segment labels and a passage-only mask are
    derived from the concatenated ids, start/end logits are predicted per
    word piece, and training marginalizes the log-likelihood over all gold
    answer spans.  Predicted spans are mapped back to passage-local token
    offsets for the string answer and DROP metrics.
    """
    # The segment labels should be as follows:
    #   <CLS> + question_word_pieces + <SEP> + passage_word_pieces + <SEP>
    #     0            0                0              1              1
    # We get this in a tricky way here: positions where the concatenated ids
    # differ from the question ids (zero-padded to full length) belong to the
    # passage half.  NOTE(review): assumes question word-piece ids line up
    # with the prefix of the concatenated ids — verify against the reader.
    expanded_question_bert_tensor = torch.zeros_like(
        question_and_passage["bert"])
    expanded_question_bert_tensor[:, :question["bert"].
                                  shape[1]] = question["bert"]
    segment_labels = (question_and_passage["bert"] -
                      expanded_question_bert_tensor > 0).long()
    question_and_passage["segment_labels"] = segment_labels
    embedded_question_and_passage = self._text_field_embedder(
        question_and_passage)
    # We also get the passage mask for the concatenated question and passage
    # in a similar way.
    expanded_question_mask = torch.zeros_like(question_and_passage["mask"])
    # We shift the 1s one column right here, to mask the [SEP] token in the
    # middle, and explicitly mask the leading [CLS].
    expanded_question_mask[:, 1:question["mask"].shape[1] +
                           1] = question["mask"]
    expanded_question_mask[:, 0] = 1
    passage_mask = question_and_passage["mask"] - expanded_question_mask
    batch_size = embedded_question_and_passage.size(0)
    span_start_logits = self._span_start_predictor(
        embedded_question_and_passage).squeeze(-1)
    span_end_logits = self._span_end_predictor(
        embedded_question_and_passage).squeeze(-1)
    # Shape: (batch_size, passage_length)
    passage_span_start_log_probs = util.masked_log_softmax(
        span_start_logits, passage_mask)
    passage_span_end_log_probs = util.masked_log_softmax(
        span_end_logits, passage_mask)
    # Huge negative value on masked positions so they cannot win the
    # best-span search.
    passage_span_start_logits = util.replace_masked_values(
        span_start_logits, passage_mask, -1e32)
    passage_span_end_logits = util.replace_masked_values(
        span_end_logits, passage_mask, -1e32)
    best_passage_span = get_best_span(passage_span_start_logits,
                                      passage_span_end_logits)
    output_dict = {
        "passage_span_start_probs": passage_span_start_log_probs.exp(),
        "passage_span_end_probs": passage_span_end_log_probs.exp()
    }
    # If answer is given, compute the loss for training.
    if answer_as_passage_spans is not None:
        # Shape: (batch_size, # of answer spans)
        gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
        gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
        # Some spans are padded with index -1,
        # so we clamp those paddings to 0 and then mask after `torch.gather()`.
        gold_passage_span_mask = (gold_passage_span_starts != -1).long()
        clamped_gold_passage_span_starts = util.replace_masked_values(
            gold_passage_span_starts, gold_passage_span_mask, 0)
        clamped_gold_passage_span_ends = util.replace_masked_values(
            gold_passage_span_ends, gold_passage_span_mask, 0)
        # Shape: (batch_size, # of answer spans)
        log_likelihood_for_passage_span_starts = \
            torch.gather(passage_span_start_log_probs, 1,
                         clamped_gold_passage_span_starts)
        log_likelihood_for_passage_span_ends = \
            torch.gather(passage_span_end_log_probs, 1,
                         clamped_gold_passage_span_ends)
        # Shape: (batch_size, # of answer spans) — start and end are treated
        # as independent, so the span log-likelihood is their sum.
        log_likelihood_for_passage_spans = \
            log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
        # For those padded spans, we set their log probabilities to be very
        # small negative value so they vanish in the logsumexp below.
        log_likelihood_for_passage_spans = \
            util.replace_masked_values(log_likelihood_for_passage_spans,
                                       gold_passage_span_mask, -1e32)
        # Shape: (batch_size, ) — marginal likelihood over all gold spans.
        log_marginal_likelihood_for_passage_span = util.logsumexp(
            log_likelihood_for_passage_spans)
        output_dict[
            "loss"] = -log_marginal_likelihood_for_passage_span.mean()
    # Compute the metrics and add the tokenized input to the output.
    if metadata is not None:
        output_dict["question_id"] = []
        output_dict["answer"] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            # We did not consider multi-mention answers here.
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['passage_token_offsets']
            predicted_span = tuple(
                best_passage_span[i].detach().cpu().numpy())
            # Remove the offsets of question tokens and the [SEP] token so
            # the span indexes into the passage-only offset list.
            predicted_span = (predicted_span[0] -
                              len(metadata[i]['question_tokens']) - 1,
                              predicted_span[1] -
                              len(metadata[i]['question_tokens']) - 1)
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_answer_str = passage_str[start_offset:end_offset]
            output_dict["question_id"].append(metadata[i]["question_id"])
            output_dict["answer"].append(best_answer_str)
            answer_annotations = metadata[i].get('answer_annotations', [])
            if answer_annotations:
                self._drop_metrics(best_answer_str, answer_annotations)
    return output_dict
def forward(  # type: ignore
    self,
    question: Dict[str, torch.LongTensor],
    passage: Dict[str, torch.LongTensor],
    span_start: torch.IntTensor = None,
    span_end: torch.IntTensor = None,
    metadata: List[Dict[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """
    BiDAF-style forward pass: bidirectional passage/question attention,
    a modeling layer, and separate start/end span predictors (the end
    predictor is conditioned on a start-weighted summary of the passage).

    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains
        the answer to the question, and predicts the beginning and ending
        positions of the answer within the passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to
        predict - the beginning position of the answer with the passage.
        This is an `inclusive` token index.  If this is given, we will
        compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to
        predict - the ending position of the answer with the passage.  This
        is an `inclusive` token index.  If this is given, we will compute a
        loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question tokens, passage tokens,
        original passage text, and token offsets into the passage for each
        instance in the batch.  The length of this list should be the batch
        size, and each dictionary should have the keys ``question_tokens``,
        ``passage_tokens``, ``original_passage``, and ``token_offsets``.

    Returns
    -------
    An output dictionary consisting of:
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized log probabilities of the span start position.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized log probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is
        ``(batch_size, 2)`` and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch,
        we also return the string from the original passage that the model
        thinks is the best answer to the question.
    """
    embedded_question = self._highway_layer(
        self._text_field_embedder(question))
    embedded_passage = self._highway_layer(
        self._text_field_embedder(passage))
    batch_size = embedded_question.size(0)
    passage_length = embedded_passage.size(1)
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()
    # Masks are only applied to the LSTMs when configured to do so.
    question_lstm_mask = question_mask if self._mask_lstms else None
    passage_lstm_mask = passage_mask if self._mask_lstms else None
    encoded_question = self._dropout(
        self._phrase_layer(embedded_question, question_lstm_mask))
    encoded_passage = self._dropout(
        self._phrase_layer(embedded_passage, passage_lstm_mask))
    encoding_dim = encoded_question.size(-1)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(
        encoded_passage, encoded_question)
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = util.masked_softmax(
        passage_question_similarity, question_mask)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(
        encoded_question, passage_question_attention)
    # We replace masked values with something really negative here, so they
    # don't affect the max below.
    masked_similarity = util.replace_masked_values(
        passage_question_similarity, question_mask.unsqueeze(1), -1e7)
    # Shape: (batch_size, passage_length)
    question_passage_similarity = masked_similarity.max(
        dim=-1)[0].squeeze(-1)
    # Shape: (batch_size, passage_length)
    question_passage_attention = util.masked_softmax(
        question_passage_similarity, passage_mask)
    # Shape: (batch_size, encoding_dim)
    question_passage_vector = util.weighted_sum(
        encoded_passage, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim) — one question-aware
    # passage summary, repeated across all passage positions.
    tiled_question_passage_vector = question_passage_vector.unsqueeze(
        1).expand(batch_size, passage_length, encoding_dim)
    # Shape: (batch_size, passage_length, encoding_dim * 4)
    final_merged_passage = torch.cat(
        [
            encoded_passage,
            passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector,
        ],
        dim=-1,
    )
    modeled_passage = self._dropout(
        self._modeling_layer(final_merged_passage, passage_lstm_mask))
    modeling_dim = modeled_passage.size(-1)
    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    span_start_input = self._dropout(
        torch.cat([final_merged_passage, modeled_passage], dim=-1))
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(
        span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length)
    span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
    # Shape: (batch_size, modeling_dim) — passage summary weighted by the
    # predicted start distribution, used to condition the end predictor.
    span_start_representation = util.weighted_sum(modeled_passage,
                                                  span_start_probs)
    # Shape: (batch_size, passage_length, modeling_dim)
    tiled_start_representation = span_start_representation.unsqueeze(
        1).expand(batch_size, passage_length, modeling_dim)
    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
    span_end_representation = torch.cat(
        [
            final_merged_passage,
            modeled_passage,
            tiled_start_representation,
            modeled_passage * tiled_start_representation,
        ],
        dim=-1,
    )
    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end = self._dropout(
        self._span_end_encoder(span_end_representation, passage_lstm_mask))
    # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
    span_end_input = self._dropout(
        torch.cat([final_merged_passage, encoded_span_end], dim=-1))
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
    span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
    # Huge negative value on masked positions so they cannot win the
    # best-span search.
    span_start_logits = util.replace_masked_values(span_start_logits,
                                                   passage_mask, -1e7)
    span_end_logits = util.replace_masked_values(span_end_logits,
                                                 passage_mask, -1e7)
    best_span = get_best_span(span_start_logits, span_end_logits)
    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }
    # Compute the loss for training.
    if span_start is not None:
        loss = nll_loss(
            util.masked_log_softmax(span_start_logits, passage_mask),
            span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits,
                                  span_start.squeeze(-1))
        loss += nll_loss(
            util.masked_log_softmax(span_end_logits, passage_mask),
            span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        # NOTE(review): sibling forward() variants in this file use
        # torch.stack here; cat on (batch, 1) gold tensors yields
        # (batch, 2) — confirm the gold tensor shape before unifying.
        self._span_accuracy(best_span,
                            torch.cat([span_start, span_end], -1))
        output_dict["loss"] = loss
    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict["best_span_str"] = []
        question_tokens = []
        passage_tokens = []
        token_offsets = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]["question_tokens"])
            passage_tokens.append(metadata[i]["passage_tokens"])
            token_offsets.append(metadata[i]["token_offsets"])
            passage_str = metadata[i]["original_passage"]
            offsets = metadata[i]["token_offsets"]
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # Map token indices back to character offsets; span end is
            # inclusive, so take its end offset.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict["best_span_str"].append(best_span_string)
            answer_texts = metadata[i].get("answer_texts", [])
            if answer_texts:
                self._squad_metrics(best_span_string, answer_texts)
        output_dict["question_tokens"] = question_tokens
        output_dict["passage_tokens"] = passage_tokens
        output_dict["token_offsets"] = token_offsets
    return output_dict
def forward(self,  # type: ignore
            metadata: Dict,
            tokens: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None
           ) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    BERT span-prediction head.

    Parameters
    ----------
    tokens : Dict[str, torch.LongTensor]
        From a ``TextField`` (that has a bert-pretrained token indexer).
    span_start : torch.IntTensor, optional (default = None)
        A tensor of shape (batch_size, 1) which contains the start_position
        of the answer in the passage, or 0 if impossible. This is an
        `inclusive` token index. If this is given, we will compute a loss
        that gets included in the output dictionary.
    span_end : torch.IntTensor, optional (default = None)
        A tensor of shape (batch_size, 1) which contains the end_position
        of the answer in the passage, or 0 if impossible. This is an
        `inclusive` token index. If this is given, we will compute a loss
        that gets included in the output dictionary.

    Returns
    -------
    An output dictionary consisting of:
    start_logits / end_logits : torch.FloatTensor
        Unnormalized per-position span-boundary scores.
    start_probs / end_probs : torch.FloatTensor
        Masked softmax of the corresponding logits.
    best_span : torch.IntTensor
        ``(batch_size, 2)`` inclusive token indices of the most probable span.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised (mean of start and end cross-entropy).
    """
    input_ids = tokens[self._index]
    token_type_ids = tokens[f"{self._index}-type-ids"]
    # BERT uses id 0 for padding, so non-zero ids form the attention mask.
    input_mask = (input_ids != 0).long()

    # Run BERT, project each position to two scores, and split them into
    # start and end logits of shape (batch_size, seq_length).
    bert_output, _ = self.bert_model(input_ids, token_type_ids,
                                     attention_mask=input_mask)
    linear_output = self.linear(bert_output)
    linear_dropped = self.drop(linear_output)
    start_logits, end_logits = linear_dropped.split(1, dim=-1)
    start_logits, end_logits = start_logits.squeeze(-1), end_logits.squeeze(-1)

    masked_soft_start = masked_softmax(start_logits, mask=input_mask)
    masked_soft_end = masked_softmax(end_logits, mask=input_mask)
    # BUGFIX: get_best_span ranks a span by the SUM of the two scores it is
    # given, so it must receive log-probabilities (sum = joint log-likelihood),
    # not raw probabilities (sum = p_start + p_end, which ranks differently).
    # The tiny epsilon guards log(0) at masked positions.
    best_span = get_best_span((masked_soft_start + 1e-32).log(),
                              (masked_soft_end + 1e-32).log())

    output_dict = {
        "start_logits": start_logits,
        "end_logits": end_logits,
        "start_probs": masked_soft_start,
        "end_probs": masked_soft_end,
        "best_span": best_span
    }

    if span_start is not None:
        # BUGFIX: the accuracy metrics were previously computed before this
        # guard and crashed at inference time, when span_start/span_end are
        # None.  They must only run when gold labels are available.
        self._span_start_accuracy(start_logits, span_start.squeeze(-1))
        self._span_end_accuracy(end_logits, span_end.squeeze(-1))
        # BUGFIX: stack along the LAST dim to get (batch_size, 2), matching
        # best_span; the previous dim-0 stack produced (2, batch_size).
        self._span_accuracy(
            best_span,
            torch.stack([span_start.squeeze(-1), span_end.squeeze(-1)], -1))

        # Clamp out-of-range gold positions (e.g. answers truncated away)
        # onto the sequence boundary before computing the loss.
        ignored_index = start_logits.size(1)
        span_start.clamp_(0, ignored_index)
        span_end.clamp_(0, ignored_index)
        start_loss = self.loss(start_logits, span_start.squeeze(-1))
        end_loss = self.loss(end_logits, span_end.squeeze(-1))
        output_dict["loss"] = (start_loss + end_loss) / 2

    # Official SQuAD metrics (EM/F1) could additionally be computed here via
    # self._squad_metrics, using decode() output and metadata[i]['answer_texts'].
    return output_dict
def forward(self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    question : Dict[str, torch.LongTensor]
        From a ``TextField``.
    passage : Dict[str, torch.LongTensor]
        From a ``TextField``.  The model assumes that this passage contains
        the answer to the question, and predicts the beginning and ending
        positions of the answer within the passage.
    span_start : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to
        predict - the beginning position of the answer with the passage.
        This is an `inclusive` token index.  If this is given, we will
        compute a loss that gets included in the output dictionary.
    span_end : ``torch.IntTensor``, optional
        From an ``IndexField``.  This is one of the things we are trying to
        predict - the ending position of the answer with the passage.  This
        is an `inclusive` token index.  If this is given, we will compute a
        loss that gets included in the output dictionary.
    metadata : ``List[Dict[str, Any]]``, optional
        If present, this should contain the question tokens, passage tokens,
        original passage text, and token offsets into the passage for each
        instance in the batch.  The length of this list should be the batch
        size, and each dictionary should have the keys ``question_tokens``,
        ``passage_tokens``, ``original_passage``, and ``token_offsets``.

    Returns
    -------
    An output dictionary consisting of:
    span_start_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized log probabilities of the span start position.
    span_start_probs : torch.FloatTensor
        The result of ``softmax(span_start_logits)``.
    span_end_logits : torch.FloatTensor
        A tensor of shape ``(batch_size, passage_length)`` representing
        unnormalized log probabilities of the span end position (inclusive).
    span_end_probs : torch.FloatTensor
        The result of ``softmax(span_end_logits)``.
    best_span : torch.IntTensor
        The result of a constrained inference over ``span_start_logits`` and
        ``span_end_logits`` to find the most probable span.  Shape is
        ``(batch_size, 2)`` and each offset is a token index.
    loss : torch.FloatTensor, optional
        A scalar loss to be optimised.
    best_span_str : List[str]
        If sufficient metadata was provided for the instances in the batch,
        we also return the string from the original passage that the model
        thinks is the best answer to the question.
    """
    question_mask = util.get_text_field_mask(question).float()
    passage_mask = util.get_text_field_mask(passage).float()

    # Embed both inputs, then project + highway before the phrase encoder.
    embedded_question = self._dropout(self._text_field_embedder(question))
    embedded_passage = self._dropout(self._text_field_embedder(passage))
    embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
    embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))
    batch_size = embedded_question.size(0)
    projected_embedded_question = self._encoding_proj_layer(embedded_question)
    projected_embedded_passage = self._encoding_proj_layer(embedded_passage)
    encoded_question = self._dropout(self._phrase_layer(projected_embedded_question, question_mask))
    encoded_passage = self._dropout(self._phrase_layer(projected_embedded_passage, passage_mask))

    # Shape: (batch_size, passage_length, question_length)
    passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
    # Passage-to-question attention, normalized over question tokens.
    # Shape: (batch_size, passage_length, question_length)
    passage_question_attention = masked_softmax(
        passage_question_similarity,
        question_mask,
        memory_efficient=True)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

    # Question-to-passage attention, normalized over passage tokens.
    # Shape: (batch_size, question_length, passage_length)
    question_passage_attention = masked_softmax(
        passage_question_similarity.transpose(1, 2),
        passage_mask,
        memory_efficient=True)
    # "Attention over attention": composes the two directions into a
    # passage-to-passage alignment.
    # Shape: (batch_size, passage_length, passage_length)
    attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
    # Shape: (batch_size, passage_length, encoding_dim)
    passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)

    # Concatenate the passage encoding with the attended vectors and their
    # element-wise interactions (QANet-style fused representation).
    # Shape: (batch_size, passage_length, encoding_dim * 4)
    merged_passage_attention_vectors = self._dropout(
        torch.cat([encoded_passage, passage_question_vectors,
                   encoded_passage * passage_question_vectors,
                   encoded_passage * passage_passage_vectors],
                  dim=-1)
    )

    # Three stacked modeling passes over the shared projection; the list
    # keeps each pass's output so the span predictors can mix them below.
    modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]
    for _ in range(3):
        modeled_passage = self._dropout(self._modeling_layer(modeled_passage_list[-1], passage_mask))
        modeled_passage_list.append(modeled_passage)

    # Start predictor sees passes 1+2; end predictor sees passes 1+3.
    # Shape: (batch_size, passage_length, modeling_dim * 2))
    span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
    # Shape: (batch_size, passage_length)
    span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)
    # Shape: (batch_size, passage_length, modeling_dim * 2)
    span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
    span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

    # Push masked positions to a huge negative value so they can never win
    # the argmax.  NOTE(review): sibling models in this file use -1e7 here;
    # -1e32 is still finite in float32 but inconsistent — confirm intended.
    span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
    span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

    # Shape: (batch_size, passage_length)
    span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
    span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
    # Constrained decoding over (start <= end) pairs; shape (batch_size, 2).
    best_span = get_best_span(span_start_logits, span_end_logits)

    output_dict = {
        "passage_question_attention": passage_question_attention,
        "span_start_logits": span_start_logits,
        "span_start_probs": span_start_probs,
        "span_end_logits": span_end_logits,
        "span_end_probs": span_end_probs,
        "best_span": best_span,
    }

    # Compute the loss for training (sum of start and end NLL), plus the
    # span-accuracy metrics, only when gold spans are available.
    if span_start is not None:
        loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
        self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
        loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
        self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
        self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
        output_dict["loss"] = loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        question_tokens = []
        passage_tokens = []
        for i in range(batch_size):
            question_tokens.append(metadata[i]['question_tokens'])
            passage_tokens.append(metadata[i]['passage_tokens'])
            passage_str = metadata[i]['original_passage']
            offsets = metadata[i]['token_offsets']
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            # Map token indices back to character offsets in the raw text;
            # the end index is inclusive, so take that token's end offset.
            start_offset = offsets[predicted_span[0]][0]
            end_offset = offsets[predicted_span[1]][1]
            best_span_string = passage_str[start_offset:end_offset]
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            if answer_texts:
                self._metrics(best_span_string, answer_texts)
        output_dict['question_tokens'] = question_tokens
        output_dict['passage_tokens'] = passage_tokens
    return output_dict
def forward(self,
            tokens: Dict[str, torch.LongTensor],
            segment_ids: torch.LongTensor = None,
            start_positions: torch.LongTensor = None,
            end_positions: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> torch.Tensor:
    """
    RoBERTa-based answer-span extraction.

    Runs the transformer over ``tokens``, projects each position to a
    (start, end) logit pair, decodes the best span, and — when gold
    ``start_positions``/``end_positions`` are given — adds a cross-entropy
    loss and updates the span-accuracy metrics.  ``segment_ids`` is accepted
    for interface compatibility but not passed to the model (RoBERTa does
    not use token-type ids; see the commented-out kwarg below).
    """
    # Debug countdown: prints are emitted only for the first few batches.
    self._debug -= 1
    input_ids = tokens['tokens']
    batch_size = input_ids.size(0)
    num_choices = input_ids.size(1)

    # Non-padding positions form the attention mask.
    tokens_mask = (input_ids != self._padding_value).long()

    if self._debug > 0:
        print(f"batch_size = {batch_size}")
        print(f"num_choices = {num_choices}")
        print(f"tokens_mask = {tokens_mask}")
        print(f"input_ids.size() = {input_ids.size()}")
        print(f"input_ids = {input_ids}")
        print(f"segment_ids = {segment_ids}")
        print(f"start_positions = {start_positions}")
        print(f"end_positions = {end_positions}")

    # Segment ids are not used by RoBERTa
    transformer_outputs = self._transformer_model(
        input_ids=input_ids,
        # token_type_ids=segment_ids,
        attention_mask=tokens_mask)
    sequence_output = transformer_outputs[0]

    # Project each token to a (start, end) score pair and split the pair
    # into two (batch_size, seq_length) tensors.
    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    # Mask out padding before decoding so it can never win the argmax.
    span_start_logits = util.replace_masked_values(start_logits, tokens_mask, -1e7)
    span_end_logits = util.replace_masked_values(end_logits, tokens_mask, -1e7)
    # Constrained decoding over (start <= end) pairs; shape (batch_size, 2).
    best_span = get_best_span(span_start_logits, span_end_logits)

    span_start_probs = util.masked_softmax(span_start_logits, tokens_mask)
    span_end_probs = util.masked_softmax(span_end_logits, tokens_mask)
    output_dict = {
        "start_logits": start_logits,
        "end_logits": end_logits,
        "best_span": best_span
    }
    output_dict["start_probs"] = span_start_probs
    output_dict["end_probs"] = span_end_probs

    if start_positions is not None and end_positions is not None:
        # If we are on multi-GPU, split adds a dimension; squeeze it back.
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # Sometimes the start/end positions are outside our model inputs;
        # we ignore these terms by clamping them onto ignored_index, which
        # CrossEntropyLoss below is told to skip.
        # NOTE(review): the clamped value (== seq_length) is out of range
        # for the accuracy metrics called next — confirm those gold indices
        # never actually exceed the input length.
        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        self._span_start_accuracy(span_start_logits, start_positions)
        self._span_end_accuracy(span_end_logits, end_positions)
        self._span_accuracy(
            best_span,
            torch.cat([
                start_positions.unsqueeze(-1),
                end_positions.unsqueeze(-1)
            ], -1))

        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
        # Should we mask out invalid positions here?
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2
        output_dict["loss"] = total_loss

    # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
    if metadata is not None:
        output_dict['best_span_str'] = []
        output_dict['exact_match'] = []
        output_dict['f1_score'] = []
        tokens_texts = []
        for i in range(batch_size):
            tokens_text = metadata[i]['tokens']
            tokens_texts.append(tokens_text)
            # best_span indices are inclusive, hence the +1 on the end slice.
            predicted_span = tuple(best_span[i].detach().cpu().numpy())
            predicted_start = predicted_span[0]
            predicted_end = predicted_span[1]
            predicted_tokens = tokens_text[predicted_start:(predicted_end + 1)]
            best_span_string = self.convert_tokens_to_string(
                predicted_tokens)
            output_dict['best_span_str'].append(best_span_string)
            answer_texts = metadata[i].get('answer_texts', [])
            exact_match = 0
            f1_score = 0
            if answer_texts:
                exact_match, f1_score = self._squad_metrics(
                    best_span_string, answer_texts)
            output_dict['exact_match'].append(exact_match)
            output_dict['f1_score'].append(f1_score)
        output_dict['tokens_texts'] = tokens_texts

    if self._debug > 0:
        print(f"output_dict = {output_dict}")

    return output_dict