def take_step(
        self, last_predictions: torch.Tensor,
        state: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Return the scores for the current decoding step with illegal transitions masked.

    The scores are read from ``state['log_probs']`` at the position stored in
    ``state['step_num'][0]``.  Classes that may not follow ``last_predictions``
    are overwritten with ``-1e7``.  ``state['step_num']`` is replaced by an
    incremented copy, and the same ``state`` dict is returned alongside the
    masked scores.
    """
    step = state['step_num'][0]
    step_log_probs = state['log_probs'][:, step, :]
    # 1 wherever the wordpiece mask is 0 at this step.
    is_wordpiece = (1 - state['wordpiece_mask'][:, step]).byte()
    # True where the previous tag was 1 or 2
    # (presumably the "inside a span" tags of the BIO scheme -- confirm against the tag vocabulary).
    previous_in_span = (last_predictions == 1) | (last_predictions == 2)
    # False for rows whose first three scores are all exactly 0.0.
    has_scores = (step_log_probs[:, :3] == 0.0).sum(-1) != 3

    # mask illegal BIO transitions: start with the first three classes allowed
    # and the last two classes forbidden.
    allowed = torch.cat(
        (torch.ones_like(step_log_probs[:, :3]),
         torch.zeros_like(step_log_probs[:, -2:])),
        dim=-1,
    ).byte()
    allowed[:, 2] &= previous_in_span
    allowed[:, 1:3] &= has_scores.unsqueeze(-1).repeat(1, 2)
    # assuming the wordpiece mask doesn't intersect with the other masks (pad, cls/sep)
    allowed[:, 2] |= is_wordpiece & previous_in_span

    masked_log_probs = replace_masked_values(step_log_probs, allowed, -1e7)
    state['step_num'] = state['step_num'].clone() + 1
    return masked_log_probs, state
def _get_combined_likelihood(self, answer_as_list_of_bios, log_probs):
    """Marginalise the log-likelihood over all candidate tag sequences.

    answer_as_list_of_bios - Shape: (batch_size, # of correct sequences, seq_length)
    log_probs - Shape: (batch_size, seq_length, 3)

    Candidate sequences whose tags sum to zero are treated as padding and get
    a log-likelihood of ``-1e7`` before the final log-sum-exp over candidates.
    """
    num_sequences = answer_as_list_of_bios.size(1)

    # Shape: (batch_size, # of correct sequences, seq_length, 3)
    # one view of log_probs per gold sequence
    log_probs_per_sequence = log_probs.unsqueeze(1).expand(
        -1, num_sequences, -1, -1)

    # Shape: (batch_size, # of correct sequences, seq_length)
    # pick the log-probability of the gold tag at every position
    gold_tag_index = answer_as_list_of_bios.unsqueeze(-1)
    token_log_likelihoods = torch.gather(
        log_probs_per_sequence, dim=-1, index=gold_tag_index).squeeze(-1)

    # Shape: (batch_size, # of correct sequences)
    is_real_sequence = (answer_as_list_of_bios.sum(-1) > 0).long()

    # Shape: (batch_size, # of correct sequences)
    # log-likelihood of a sequence = sum of its per-token log-likelihoods
    sequence_log_likelihoods = replace_masked_values(
        token_log_likelihoods.sum(dim=-1), is_real_sequence, -1e7)

    # marginalised log-likelihood over the correct answers
    return logsumexp(sequence_log_likelihoods, dim=-1)
def module(self, bert_out, seq_mask=None):
    """Project encoder output to tag logits and tag log-probabilities.

    Args:
        bert_out: encoder output handed to ``self.predictor``.
        seq_mask: optional mask; only used when ``self._use_crf`` is false.
            Masked positions get log-probability ``0.0`` and logit ``-1e7``.

    Returns:
        ``(log_probs, logits)``; ``log_probs`` is always normalised over the
        last dimension.
    """
    logits = self.predictor(bert_out)
    # Always normalise over the class axis.  The implicit-dimension form of
    # log_softmax picks dim 0 for 3-D input, i.e. it would normalise across
    # the batch instead of across the tags.
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # The mask should not be applied here when using CRF, but should be passed to the CRF
    if not self._use_crf and seq_mask is not None:
        token_mask = seq_mask.unsqueeze(-1)
        log_probs = replace_masked_values(log_probs, token_mask, 0.0)
        logits = replace_masked_values(logits, token_mask, -1e7)

    return log_probs, logits
def log_likelihood(self, gold_labels, log_probs, seq_mask, is_bio_mask, **kwargs):
    """Sum the log-probabilities of the gold labels over each sequence.

    Positions masked out by ``seq_mask`` contribute ``0.0`` to the sum.
    Rows masked out by ``is_bio_mask`` are set to ``-1e7``.
    """
    # we only want the log probabilities of the gold labels:
    # token_log_likelihoods[i, j] == log_probs[i, j, gold_labels[i, j]]
    gold_index = gold_labels.unsqueeze(-1)
    token_log_likelihoods = torch.gather(
        log_probs, dim=-1, index=gold_index).squeeze(-1)

    # padding tokens are ignored in the sum
    token_log_likelihoods = replace_masked_values(
        token_log_likelihoods, seq_mask, 0.0)
    sequence_log_likelihood = token_log_likelihoods.sum(dim=-1)

    # For questions without spans, we set their log probabilities to be very small negative value
    return replace_masked_values(sequence_log_likelihood, is_bio_mask, -1e7)
def prediction(self, log_probs, logits, qp_tokens, p_text, q_text, seq_mask,
               wordpiece_mask, use_beam_search):
    """Choose a tag per token and decode the tagged spans.

    With ``use_beam_search`` the tags are the first sequence returned by
    ``self._get_top_k_sequences`` for a batch of one; otherwise they are the
    argmax of ``logits``.  Positions masked out by ``seq_mask`` are set to
    tag 0 before decoding.
    """
    if not use_beam_search:
        tags = torch.argmax(logits, dim=-1)
    else:
        # the beam search works on batches, so add a batch dimension of size 1
        beams = self._get_top_k_sequences(
            log_probs.unsqueeze(0),
            wordpiece_mask.unsqueeze(0),
            self._prediction_beam_size)
        tags = beams[0, 0, :]

    tags = replace_masked_values(tags, seq_mask, 0)
    return MultiSpanHead.decode_spans_from_tags(tags, qp_tokens, p_text, q_text)
def log_likelihood(self, gold_labels, log_probs, seq_mask, is_bio_mask, **kwargs):
    """Log-likelihood of the gold tag sequence under the CRF.

    Args:
        gold_labels: gold tags, or ``None`` when no gold labels are available.
        log_probs: unused here; kept for interface parity with the other heads.
        seq_mask: mask passed to the CRF.
        is_bio_mask: rows masked out by it are set to ``-1e7``.
        **kwargs: must contain ``'logits'``.

    Returns:
        The masked log-likelihood, or ``None`` when ``gold_labels`` is ``None``.

    Raises:
        KeyError: if ``'logits'`` is missing from ``kwargs``.
    """
    logits = kwargs['logits']

    # Nothing to score without gold labels; say so explicitly instead of
    # falling off the end of the function.
    if gold_labels is None:
        return None

    # log p(gold) = joint score of the gold path - log partition function
    log_denominator = self.crf._input_likelihood(logits, seq_mask)
    log_numerator = self.crf._joint_likelihood(logits, gold_labels, seq_mask)
    sequence_log_likelihood = log_numerator - log_denominator

    return replace_masked_values(sequence_log_likelihood, is_bio_mask, -1e7)
def forward(
        self,  # type: ignore
        input_ids: torch.LongTensor,
        input_mask: torch.LongTensor,
        input_segments: torch.LongTensor,
        passage_mask: torch.LongTensor,
        question_mask: torch.LongTensor,
        number_indices: torch.LongTensor,
        passage_number_order: torch.LongTensor,
        question_number_order: torch.LongTensor,
        question_number_indices: torch.LongTensor,
        answer_as_passage_spans: torch.LongTensor = None,
        answer_as_question_spans: torch.LongTensor = None,
        answer_as_add_sub_expressions: torch.LongTensor = None,
        answer_as_counts: torch.LongTensor = None,
        span_num: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """Encode the input, score every configured answering ability and decode answers.

    The method
      1. runs ``self.bert`` and keeps the last four entries of ``outputs[2]``;
      2. optionally (``self.use_gcn`` and a span ability configured) refines
         them with the number graph network ``self._gcn``;
      3. computes the predictions of each ability in ``self.answering_abilities``
         ("passage_span_extraction", "question_span_extraction",
         "addition_subtraction", "counting");
      4. if any gold answer tensor is given, stores the negative mean marginal
         log-likelihood under ``"loss"``;
      5. if ``metadata`` is given, decodes one answer per example into
         ``"question_id"`` / ``"answer"`` and updates ``self._drop_metrics``
         when answer annotations are present.

    Padding conventions visible in the code: gold spans and counts use ``-1``
    as padding, number indices are shifted by one (``number_indices - 1``)
    and masked where the shifted index is negative.

    Returns:
        The output dictionary described above.  When the graph network ran it
        also contains ``"clamped_number_indices"`` and ``"node_weight"``.

    Raises:
        ValueError: for an ability name this method does not support.
    """
    outputs = self.bert(input_ids,
                        attention_mask=input_mask,
                        token_type_ids=input_segments)
    sequence_output = outputs[0]
    # NOTE(review): outputs[2] is presumably the tuple of per-layer hidden
    # states of the encoder -- confirm against the encoder configuration.
    sequence_output_list = list(outputs[2][-4:])

    batch_size = input_ids.size(0)

    has_span_ability = (
        "passage_span_extraction" in self.answering_abilities
        or "question_span_extraction" in self.answering_abilities)
    # The graph network is only run for span abilities.  The same flag guards
    # the GCN outputs at the end of the method, so `d_node_weight` is never
    # read when this branch was skipped.
    run_gcn = bool(self.use_gcn) and has_span_ability

    if run_gcn:
        # M2, M3
        sequence_alg = self._gcn_input_proj(
            torch.cat([sequence_output_list[2], sequence_output_list[3]],
                      dim=2))
        encoded_passage_for_numbers = sequence_alg
        encoded_question_for_numbers = sequence_alg

        # passage number extraction
        real_number_indices = number_indices - 1
        number_mask = (real_number_indices > -1).long()
        clamped_number_indices = util.replace_masked_values(
            real_number_indices, number_mask, 0)
        encoded_numbers = torch.gather(
            encoded_passage_for_numbers, 1,
            clamped_number_indices.unsqueeze(-1).expand(
                -1, -1, encoded_passage_for_numbers.size(-1)))

        # question number extraction
        question_number_mask = (question_number_indices > -1).long()
        clamped_question_number_indices = util.replace_masked_values(
            question_number_indices, question_number_mask, 0)
        question_encoded_number = torch.gather(
            encoded_question_for_numbers, 1,
            clamped_question_number_indices.unsqueeze(-1).expand(
                -1, -1, encoded_question_for_numbers.size(-1)))

        # graph mask: entry [b, i, j] is 1 when order[b, j] > order[b, i]
        # and both numbers are real (not padding)
        number_order = torch.cat(
            (passage_number_order, question_number_order), -1)
        new_graph_mask = number_order.unsqueeze(1).expand(
            batch_size, number_order.size(-1),
            -1) > number_order.unsqueeze(-1).expand(
                batch_size, -1, number_order.size(-1))
        new_graph_mask = new_graph_mask.long()
        all_number_mask = torch.cat((number_mask, question_number_mask),
                                    dim=-1)
        new_graph_mask = all_number_mask.unsqueeze(
            1) * all_number_mask.unsqueeze(-1) * new_graph_mask

        # iteration
        d_node, q_node, d_node_weight, _ = self._gcn(
            d_node=encoded_numbers,
            q_node=question_encoded_number,
            d_node_mask=number_mask,
            q_node_mask=question_number_mask,
            graph=new_graph_mask)

        # Scatter the refined number nodes back to their token positions.
        # One extra trailing slot absorbs the padded numbers and is dropped
        # again right after the scatter.
        gcn_info_vec = torch.zeros(
            (batch_size, sequence_alg.size(1) + 1,
             sequence_output_list[-1].size(-1)),
            dtype=torch.float,
            device=d_node.device)
        clamped_number_indices = util.replace_masked_values(
            real_number_indices, number_mask,
            gcn_info_vec.size(1) - 1)
        gcn_info_vec.scatter_(
            1,
            clamped_number_indices.unsqueeze(-1).expand(
                -1, -1, d_node.size(-1)), d_node)
        gcn_info_vec = gcn_info_vec[:, :-1, :]

        sequence_output_list[2] = self._gcn_enc(
            self._proj_ln(sequence_output_list[2] + gcn_info_vec))
        sequence_output_list[0] = self._gcn_enc(
            self._proj_ln0(sequence_output_list[0] + gcn_info_vec))
        sequence_output_list[1] = self._gcn_enc(
            self._proj_ln1(sequence_output_list[1] + gcn_info_vec))
        sequence_output_list[3] = self._gcn_enc(
            self._proj_ln3(sequence_output_list[3] + gcn_info_vec))

    # passage hidden and question hidden
    sequence_h2_weight = self._proj_sequence_h(
        sequence_output_list[2]).squeeze(-1)
    passage_h2_weight = util.masked_softmax(sequence_h2_weight,
                                            passage_mask)
    passage_h2 = util.weighted_sum(sequence_output_list[2],
                                   passage_h2_weight)
    question_h2_weight = util.masked_softmax(sequence_h2_weight,
                                             question_mask)
    question_h2 = util.weighted_sum(sequence_output_list[2],
                                    question_h2_weight)

    # passage g0, g1, g2
    question_g0_weight = self._proj_sequence_g0(
        sequence_output_list[0]).squeeze(-1)
    question_g0_weight = util.masked_softmax(question_g0_weight,
                                             question_mask)
    question_g0 = util.weighted_sum(sequence_output_list[0],
                                    question_g0_weight)

    question_g1_weight = self._proj_sequence_g1(
        sequence_output_list[1]).squeeze(-1)
    question_g1_weight = util.masked_softmax(question_g1_weight,
                                             question_mask)
    question_g1 = util.weighted_sum(sequence_output_list[1],
                                    question_g1_weight)

    question_g2_weight = self._proj_sequence_g2(
        sequence_output_list[2]).squeeze(-1)
    question_g2_weight = util.masked_softmax(question_g2_weight,
                                             question_mask)
    question_g2 = util.weighted_sum(sequence_output_list[2],
                                    question_g2_weight)

    if len(self.answering_abilities) > 1:
        # Shape: (batch_size, number_of_abilities)
        answer_ability_logits = self._answer_ability_predictor(
            torch.cat([passage_h2, question_h2, sequence_output[:, 0]], 1))
        answer_ability_log_probs = F.log_softmax(answer_ability_logits, -1)
        best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

    real_number_indices = number_indices.squeeze(-1) - 1
    number_mask = (real_number_indices > -1).long()
    clamped_number_indices = util.replace_masked_values(
        real_number_indices, number_mask, 0)
    encoded_passage_for_numbers = torch.cat(
        [sequence_output_list[2], sequence_output_list[3]], dim=-1)

    encoded_numbers = torch.gather(
        encoded_passage_for_numbers, 1,
        clamped_number_indices.unsqueeze(-1).expand(
            -1, -1, encoded_passage_for_numbers.size(-1)))
    number_weight = self._proj_number(encoded_numbers).squeeze(-1)
    # NOTE(review): from here on the mask is built from the unshifted
    # indices, exactly as in the original code -- confirm this is intended.
    number_mask = (number_indices > -1).long()
    number_weight = util.masked_softmax(number_weight, number_mask)
    number_vector = util.weighted_sum(encoded_numbers, number_weight)

    if "counting" in self.answering_abilities:
        # Shape: (batch_size, 10)
        count_number_logits = self._count_number_predictor(
            torch.cat([
                number_vector, passage_h2, question_h2,
                sequence_output[:, 0]
            ],
                      dim=1))
        count_number_log_probs = torch.nn.functional.log_softmax(
            count_number_logits, -1)
        # Info about the best count number prediction
        # Shape: (batch_size,)
        best_count_number = torch.argmax(count_number_log_probs, -1)
        best_count_log_prob = torch.gather(
            count_number_log_probs, 1,
            best_count_number.unsqueeze(-1)).squeeze(-1)
        if len(self.answering_abilities) > 1:
            best_count_log_prob += answer_ability_log_probs[:, self.
                                                            _counting_index]

    if has_span_ability:
        # start 0, 2
        sequence_for_span_start = torch.cat([
            sequence_output_list[2], sequence_output_list[0],
            sequence_output_list[2] * question_g2.unsqueeze(1),
            sequence_output_list[0] * question_g0.unsqueeze(1)
        ],
                                            dim=2)
        sequence_span_start_logits = self._span_start_predictor(
            sequence_for_span_start).squeeze(-1)
        # Shape: (batch_size, passage_length, modeling_dim * 2)
        sequence_for_span_end = torch.cat([
            sequence_output_list[2], sequence_output_list[1],
            sequence_output_list[2] * question_g2.unsqueeze(1),
            sequence_output_list[1] * question_g1.unsqueeze(1)
        ],
                                          dim=2)
        # Shape: (batch_size, passage_length)
        sequence_span_end_logits = self._span_end_predictor(
            sequence_for_span_end).squeeze(-1)

        # span number prediction
        span_num_logits = self._proj_span_num(
            torch.cat([passage_h2, question_h2, sequence_output[:, 0]],
                      dim=1))
        span_num_log_probs = torch.nn.functional.log_softmax(
            span_num_logits, -1)
        best_span_number = torch.argmax(span_num_log_probs, dim=-1)

        if "passage_span_extraction" in self.answering_abilities:
            passage_span_start_log_probs = util.masked_log_softmax(
                sequence_span_start_logits, passage_mask)
            passage_span_end_log_probs = util.masked_log_softmax(
                sequence_span_end_logits, passage_mask)

            # Info about the best passage span prediction
            passage_span_start_logits = util.replace_masked_values(
                sequence_span_start_logits, passage_mask, -1e7)
            passage_span_end_logits = util.replace_masked_values(
                sequence_span_end_logits, passage_mask, -1e7)
            # Shape: (batch_size, topk, 2)
            best_passage_span = get_best_span(passage_span_start_logits,
                                              passage_span_end_logits)

        if "question_span_extraction" in self.answering_abilities:
            question_span_start_log_probs = util.masked_log_softmax(
                sequence_span_start_logits, question_mask)
            question_span_end_log_probs = util.masked_log_softmax(
                sequence_span_end_logits, question_mask)

            # Info about the best question span prediction
            question_span_start_logits = util.replace_masked_values(
                sequence_span_start_logits, question_mask, -1e7)
            question_span_end_logits = util.replace_masked_values(
                sequence_span_end_logits, question_mask, -1e7)
            # Shape: (batch_size, topk, 2)
            best_question_span = get_best_span(question_span_start_logits,
                                               question_span_end_logits)

    if "addition_subtraction" in self.answering_abilities:
        alg_encoded_numbers = torch.cat([
            encoded_numbers,
            question_h2.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1),
            passage_h2.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1),
            sequence_output[:, 0].unsqueeze(1).repeat(
                1, encoded_numbers.size(1), 1)
        ], 2)

        # Shape: (batch_size, # of numbers in the passage, 3)
        number_sign_logits = self._number_sign_predictor(
            alg_encoded_numbers)
        number_sign_log_probs = torch.nn.functional.log_softmax(
            number_sign_logits, -1)

        # Shape: (batch_size, # of numbers in passage).
        best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
        # For padding numbers, the best sign masked as 0 (not included).
        best_signs_for_numbers = util.replace_masked_values(
            best_signs_for_numbers, number_mask, 0)
        # Shape: (batch_size, # of numbers in passage)
        best_signs_log_probs = torch.gather(
            number_sign_log_probs, 2,
            best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
        # the probs of the masked positions should be 1 so that it will not affect the joint probability
        # TODO: this is not quite right, since if there are many numbers in the passage,
        # TODO: the joint probability would be very small.
        best_signs_log_probs = util.replace_masked_values(
            best_signs_log_probs, number_mask, 0)
        # Shape: (batch_size,)
        best_combination_log_prob = best_signs_log_probs.sum(-1)
        if len(self.answering_abilities) > 1:
            best_combination_log_prob += answer_ability_log_probs[:, self.
                                                                  _addition_subtraction_index]

    output_dict = {}

    # If answer is given, compute the loss.
    if (answer_as_passage_spans is not None
            or answer_as_question_spans is not None
            or answer_as_add_sub_expressions is not None
            or answer_as_counts is not None):

        log_marginal_likelihood_list = []

        for answering_ability in self.answering_abilities:
            if answering_ability == "passage_span_extraction":
                # Shape: (batch_size, # of answer spans)
                gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                # Some spans are padded with index -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                gold_passage_span_mask = (gold_passage_span_starts !=
                                          -1).long()
                clamped_gold_passage_span_starts = util.replace_masked_values(
                    gold_passage_span_starts, gold_passage_span_mask, 0)
                clamped_gold_passage_span_ends = util.replace_masked_values(
                    gold_passage_span_ends, gold_passage_span_mask, 0)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_passage_span_starts = torch.gather(
                    passage_span_start_log_probs, 1,
                    clamped_gold_passage_span_starts)
                log_likelihood_for_passage_span_ends = torch.gather(
                    passage_span_end_log_probs, 1,
                    clamped_gold_passage_span_ends)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_passage_spans = (
                    log_likelihood_for_passage_span_starts +
                    log_likelihood_for_passage_span_ends)
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_passage_spans = util.replace_masked_values(
                    log_likelihood_for_passage_spans,
                    gold_passage_span_mask, -1e7)
                # Shape: (batch_size, )
                log_marginal_likelihood_for_passage_span = util.logsumexp(
                    log_likelihood_for_passage_spans)

                # span log probabilities
                log_likelihood_for_passage_span_nums = torch.gather(
                    span_num_log_probs, 1, span_num)
                log_likelihood_for_passage_span_nums = util.replace_masked_values(
                    log_likelihood_for_passage_span_nums,
                    gold_passage_span_mask[:, :1], -1e7)
                log_marginal_likelihood_for_passage_span_nums = util.logsumexp(
                    log_likelihood_for_passage_span_nums)
                log_marginal_likelihood_list.append(
                    log_marginal_likelihood_for_passage_span +
                    log_marginal_likelihood_for_passage_span_nums)

            elif answering_ability == "question_span_extraction":
                # Shape: (batch_size, # of answer spans)
                gold_question_span_starts = answer_as_question_spans[:, :, 0]
                gold_question_span_ends = answer_as_question_spans[:, :, 1]
                # Some spans are padded with index -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                gold_question_span_mask = (gold_question_span_starts !=
                                           -1).long()
                clamped_gold_question_span_starts = util.replace_masked_values(
                    gold_question_span_starts, gold_question_span_mask, 0)
                clamped_gold_question_span_ends = util.replace_masked_values(
                    gold_question_span_ends, gold_question_span_mask, 0)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_question_span_starts = torch.gather(
                    question_span_start_log_probs, 1,
                    clamped_gold_question_span_starts)
                log_likelihood_for_question_span_ends = torch.gather(
                    question_span_end_log_probs, 1,
                    clamped_gold_question_span_ends)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_question_spans = (
                    log_likelihood_for_question_span_starts +
                    log_likelihood_for_question_span_ends)
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_question_spans = util.replace_masked_values(
                    log_likelihood_for_question_spans,
                    gold_question_span_mask, -1e7)
                # Shape: (batch_size, )
                # pylint: disable=invalid-name
                log_marginal_likelihood_for_question_span = util.logsumexp(
                    log_likelihood_for_question_spans)

                # question multi span prediction
                log_likelihood_for_question_span_nums = torch.gather(
                    span_num_log_probs, 1, span_num)
                log_marginal_likelihood_for_question_span_nums = util.logsumexp(
                    log_likelihood_for_question_span_nums)
                log_marginal_likelihood_list.append(
                    log_marginal_likelihood_for_question_span +
                    log_marginal_likelihood_for_question_span_nums)

            elif answering_ability == "addition_subtraction":
                # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                # Shape: (batch_size, # of combinations)
                gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1) >
                                     0).float()
                # Shape: (batch_size, # of numbers in the passage, # of combinations)
                gold_add_sub_signs = answer_as_add_sub_expressions.transpose(
                    1, 2)
                # Shape: (batch_size, # of numbers in the passage, # of combinations)
                log_likelihood_for_number_signs = torch.gather(
                    number_sign_log_probs, 2, gold_add_sub_signs)
                # the log likelihood of the masked positions should be 0
                # so that it will not affect the joint probability
                log_likelihood_for_number_signs = util.replace_masked_values(
                    log_likelihood_for_number_signs,
                    number_mask.unsqueeze(-1), 0)
                # Shape: (batch_size, # of combinations)
                log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(
                    1)
                # For those padded combinations, we set their log probabilities to be very small negative value
                log_likelihood_for_add_subs = util.replace_masked_values(
                    log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
                # Shape: (batch_size, )
                log_marginal_likelihood_for_add_sub = util.logsumexp(
                    log_likelihood_for_add_subs)
                log_marginal_likelihood_list.append(
                    log_marginal_likelihood_for_add_sub)

            elif answering_ability == "counting":
                # Count answers are padded with label -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                # Shape: (batch_size, # of count answers)
                gold_count_mask = (answer_as_counts != -1).long()
                # Shape: (batch_size, # of count answers)
                clamped_gold_counts = util.replace_masked_values(
                    answer_as_counts, gold_count_mask, 0)
                log_likelihood_for_counts = torch.gather(
                    count_number_log_probs, 1, clamped_gold_counts)
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_counts = util.replace_masked_values(
                    log_likelihood_for_counts, gold_count_mask, -1e7)
                # Shape: (batch_size, )
                log_marginal_likelihood_for_count = util.logsumexp(
                    log_likelihood_for_counts)
                log_marginal_likelihood_list.append(
                    log_marginal_likelihood_for_count)

            else:
                raise ValueError(
                    f"Unsupported answering ability: {answering_ability}")

        if len(self.answering_abilities) > 1:
            # Add the ability probabilities if there are more than one abilities
            all_log_marginal_likelihoods = torch.stack(
                log_marginal_likelihood_list, dim=-1)
            all_log_marginal_likelihoods = (all_log_marginal_likelihoods +
                                            answer_ability_log_probs)
            marginal_log_likelihood = util.logsumexp(
                all_log_marginal_likelihoods)
        else:
            marginal_log_likelihood = log_marginal_likelihood_list[0]
        output_dict["loss"] = -marginal_log_likelihood.mean()

    if metadata is not None:
        output_dict["question_id"] = []
        output_dict["answer"] = []
        for i in range(batch_size):
            if len(self.answering_abilities) > 1:
                predicted_ability_str = self.answering_abilities[
                    best_answer_ability[i].detach().cpu().numpy()]
            else:
                predicted_ability_str = self.answering_abilities[0]

            answer_json: Dict[str, Any] = {}

            # Token offsets of the question and the passage inside the
            # encoded sequence.
            question_start = 1
            passage_start = len(metadata[i]["question_tokens"]) + 2
            # We did not consider multi-mention answers here
            if predicted_ability_str == "passage_span_extraction":
                answer_json["answer_type"] = "passage_span"
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['passage_token_offsets']
                predicted_answer, predicted_spans = best_answers_extraction(
                    best_passage_span[i], best_span_number[i], passage_str,
                    offsets, passage_start)
                answer_json["value"] = predicted_answer
                answer_json["spans"] = predicted_spans
            elif predicted_ability_str == "question_span_extraction":
                answer_json["answer_type"] = "question_span"
                question_str = metadata[i]['original_question']
                offsets = metadata[i]['question_token_offsets']
                predicted_answer, predicted_spans = best_answers_extraction(
                    best_question_span[i], best_span_number[i],
                    question_str, offsets, question_start)
                answer_json["value"] = predicted_answer
                answer_json["spans"] = predicted_spans
            elif predicted_ability_str == "addition_subtraction":
                answer_json["answer_type"] = "arithmetic"
                original_numbers = metadata[i]['original_numbers']
                # class 0 -> number not used, 1 -> added, 2 -> subtracted
                sign_remap = {0: 0, 1: 1, 2: -1}
                predicted_signs = [
                    sign_remap[it] for it in
                    best_signs_for_numbers[i].detach().cpu().numpy()
                ]
                result = sum(
                    sign * number
                    for sign, number in zip(predicted_signs,
                                            original_numbers))
                predicted_answer = convert_number_to_str(result)
                offsets = metadata[i]['passage_token_offsets']
                # Separate name so the tensor parameter `number_indices`
                # is not shadowed by the per-example list.
                example_number_indices = metadata[i]['number_indices']
                number_positions = [
                    offsets[index - 1] for index in example_number_indices
                ]
                answer_json['numbers'] = []
                for offset, value, sign in zip(number_positions,
                                               original_numbers,
                                               predicted_signs):
                    answer_json['numbers'].append({
                        'span': offset,
                        'value': value,
                        'sign': sign
                    })
                if example_number_indices[-1] == -1:
                    # There is a dummy 0 number at position -1 added in some cases; we are
                    # removing that here.
                    answer_json["numbers"].pop()
                answer_json["value"] = result
                answer_json['number_sign_log_probs'] = number_sign_log_probs[
                    i, :, :].detach().cpu().numpy()

            elif predicted_ability_str == "counting":
                answer_json["answer_type"] = "count"
                predicted_count = best_count_number[i].detach().cpu().numpy()
                predicted_answer = str(predicted_count)
                answer_json["count"] = predicted_count
            else:
                raise ValueError(
                    f"Unsupported answer ability: {predicted_ability_str}")

            answer_json["predicted_answer"] = predicted_answer
            output_dict["question_id"].append(metadata[i]["question_id"])
            output_dict["answer"].append(answer_json)
            answer_annotations = metadata[i].get('answer_annotations', [])
            if answer_annotations:
                self._drop_metrics(predicted_answer, answer_annotations)

    if run_gcn:
        output_dict['clamped_number_indices'] = clamped_number_indices
        output_dict['node_weight'] = d_node_weight
    return output_dict
def forward(self, d_node, q_node, d_node_mask, q_node_mask, graph, extra_factor=None):
    """Run ``self.iteration_steps`` rounds of message passing between two node sets.

    ``graph`` is sliced into four quadrants using the number of ``d`` nodes:
    d-d, q-q, d-q and q-d.  Inside every quadrant the valid node pairs are
    split into a "left" relation (pairs flagged in ``graph``) and a "right"
    relation (all other valid pairs); each relation has its own projection
    layer.  Incoming messages are weighted by a per-node sigmoid gate and
    averaged over the number of neighbours (at least 1).

    Returns ``(d_node, q_node, all_d_weight, all_q_weight)`` where the weights
    of all steps are stacked along dim 1.
    """
    d_node_len = d_node.size(1)
    q_node_len = q_node.size(1)

    def _off_diagonal(nodes):
        # 0 on the diagonal and 1 elsewhere, one copy per batch row
        eye = torch.diagflat(
            torch.ones(nodes.size(1), dtype=torch.long, device=nodes.device))
        return 1 - eye.unsqueeze(0).expand(nodes.size(0), -1, -1)

    def _split(valid_pairs, edges):
        # (pairs flagged in `edges`, remaining valid pairs)
        return valid_pairs * edges, valid_pairs * (1 - edges)

    def _aggregate(sender_weight, edges, messages):
        # zero the gate of every non-neighbour, then sum the gated messages
        gated = util.replace_masked_values(sender_weight, edges, 0)
        return torch.matmul(gated, messages)

    # valid (non-padding, non-self) node pairs per quadrant
    dd_graph = d_node_mask.unsqueeze(1) * d_node_mask.unsqueeze(-1) * _off_diagonal(d_node)
    qq_graph = q_node_mask.unsqueeze(1) * q_node_mask.unsqueeze(-1) * _off_diagonal(q_node)
    dq_graph = d_node_mask.unsqueeze(-1) * q_node_mask.unsqueeze(1)
    qd_graph = q_node_mask.unsqueeze(-1) * d_node_mask.unsqueeze(1)

    dd_graph_left, dd_graph_right = _split(
        dd_graph, graph[:, :d_node_len, :d_node_len])
    qq_graph_left, qq_graph_right = _split(
        qq_graph, graph[:, d_node_len:, d_node_len:])
    dq_graph_left, dq_graph_right = _split(
        dq_graph, graph[:, :d_node_len, d_node_len:])
    qd_graph_left, qd_graph_right = _split(
        qd_graph, graph[:, d_node_len:, :d_node_len])

    # neighbour counts used for averaging; nodes without neighbours divide by 1
    d_neighbor_num = (dd_graph_left.sum(-1) + dd_graph_right.sum(-1) +
                      dq_graph_left.sum(-1) + dq_graph_right.sum(-1))
    d_has_neighbor = (d_neighbor_num >= 1).long()
    d_neighbor_num = util.replace_masked_values(
        d_neighbor_num.float(), d_has_neighbor, 1)

    q_neighbor_num = (qq_graph_left.sum(-1) + qq_graph_right.sum(-1) +
                      qd_graph_left.sum(-1) + qd_graph_right.sum(-1))
    q_has_neighbor = (q_neighbor_num >= 1).long()
    q_neighbor_num = util.replace_masked_values(
        q_neighbor_num.float(), q_has_neighbor, 1)

    d_weight_per_step, q_weight_per_step = [], []
    for _ in range(self.iteration_steps):
        # per-node gate
        if extra_factor is None:
            d_gate_input, q_gate_input = d_node, q_node
        else:
            d_gate_input = torch.cat((d_node, extra_factor), dim=-1)
            q_gate_input = torch.cat((q_node, extra_factor), dim=-1)
        d_node_weight = torch.sigmoid(
            self._node_weight_fc(d_gate_input)).squeeze(-1)
        q_node_weight = torch.sigmoid(
            self._node_weight_fc(q_gate_input)).squeeze(-1)
        d_weight_per_step.append(d_node_weight)
        q_weight_per_step.append(q_node_weight)

        self_d_node_info = self._self_node_fc(d_node)
        self_q_node_info = self._self_node_fc(q_node)

        # gate of the sending node, repeated for every receiving node
        d_gate_for_d = d_node_weight.unsqueeze(1).expand(-1, d_node_len, -1)
        d_gate_for_q = d_node_weight.unsqueeze(1).expand(-1, q_node_len, -1)
        q_gate_for_q = q_node_weight.unsqueeze(1).expand(-1, q_node_len, -1)
        q_gate_for_d = q_node_weight.unsqueeze(1).expand(-1, d_node_len, -1)

        # "left" relation
        dd_messages = self._dd_node_fc_left(d_node)
        qd_messages = self._qd_node_fc_left(d_node)
        qq_messages = self._qq_node_fc_left(q_node)
        dq_messages = self._dq_node_fc_left(q_node)
        dd_node_info_left = _aggregate(d_gate_for_d, dd_graph_left, dd_messages)
        qd_node_info_left = _aggregate(d_gate_for_q, qd_graph_left, qd_messages)
        qq_node_info_left = _aggregate(q_gate_for_q, qq_graph_left, qq_messages)
        dq_node_info_left = _aggregate(q_gate_for_d, dq_graph_left, dq_messages)

        # "right" relation
        dd_messages = self._dd_node_fc_right(d_node)
        qd_messages = self._qd_node_fc_right(d_node)
        qq_messages = self._qq_node_fc_right(q_node)
        dq_messages = self._dq_node_fc_right(q_node)
        dd_node_info_right = _aggregate(d_gate_for_d, dd_graph_right, dd_messages)
        qd_node_info_right = _aggregate(d_gate_for_q, qd_graph_right, qd_messages)
        qq_node_info_right = _aggregate(q_gate_for_q, qq_graph_right, qq_messages)
        dq_node_info_right = _aggregate(q_gate_for_d, dq_graph_right, dq_messages)

        # average the incoming messages and update the nodes
        agg_d_node_info = (
            dd_node_info_left + dd_node_info_right + dq_node_info_left +
            dq_node_info_right) / d_neighbor_num.unsqueeze(-1)
        agg_q_node_info = (
            qq_node_info_left + qq_node_info_right + qd_node_info_left +
            qd_node_info_right) / q_neighbor_num.unsqueeze(-1)

        d_node = F.relu(self_d_node_info + agg_d_node_info)
        q_node = F.relu(self_q_node_info + agg_q_node_info)

    # one slice per step along dim 1
    all_d_weight = torch.stack(d_weight_per_step, dim=1)
    all_q_weight = torch.stack(q_weight_per_step, dim=1)

    return d_node, q_node, all_d_weight, all_q_weight
def forward(self, node, node_mask, argument_graph, punctuation_graph, extra_factor=None):
    """Run ``self.iteration_steps`` rounds of message passing over two edge types.

    Current: 2 relation patterns.
        - argument edge. (most of them are causal relations)
        - punctuation edges. (including periods and commas)

    Returns ``(node, all_weight)``: the updated node representations and the
    per-step node weights stacked along dim 1.
    """
    batch_size, node_len = node.size(0), node.size(1)

    # valid node pairs: both nodes unmasked and no self loops
    eye = torch.diagflat(
        torch.ones(node_len, dtype=torch.long, device=node.device))
    off_diagonal = 1 - eye.unsqueeze(0).expand(batch_size, -1, -1)
    valid_pairs = node_mask.unsqueeze(1) * node_mask.unsqueeze(-1) * off_diagonal

    graph_argument = valid_pairs * argument_graph
    graph_punctuation = valid_pairs * punctuation_graph

    # number of neighbours per node; nodes without neighbours divide by 1
    neighbor_num = graph_argument.sum(-1) + graph_punctuation.sum(-1)
    has_neighbor = (neighbor_num >= 1).long()
    neighbor_num = util.replace_masked_values(
        neighbor_num.float(), has_neighbor, 1)
    normalizer = neighbor_num.unsqueeze(-1)

    weight_per_step = []
    for _ in range(self.iteration_steps):
        # (1) Node Relatedness Measure
        if extra_factor is None:
            gate_input = node
        else:
            gate_input = torch.cat((node, extra_factor), dim=-1)
        d_node_weight = torch.sigmoid(
            self._node_weight_fc(gate_input)).squeeze(-1)
        weight_per_step.append(d_node_weight)

        self_node_info = self._self_node_fc(node)

        # (2) Message Propagation (each relation type)
        # gate of the sending node, repeated for every receiving node
        sender_gate = d_node_weight.unsqueeze(1).expand(-1, node_len, -1)

        argument_messages = self._node_fc_argument(node)
        argument_gate = util.replace_masked_values(
            sender_gate, graph_argument, 0)
        node_info_argument = torch.matmul(argument_gate, argument_messages)

        punctuation_messages = self._node_fc_punctuation(node)
        punctuation_gate = util.replace_masked_values(
            sender_gate, graph_punctuation, 0)
        node_info_punctuation = torch.matmul(
            punctuation_gate, punctuation_messages)

        agg_node_info = (node_info_argument + node_info_punctuation) / normalizer

        # (3) Node Representation Update
        node = F.relu(self_node_info + agg_node_info)

    all_weight = torch.stack(weight_per_step, dim=1)

    return node, all_weight
def log_likelihood(self, answer_as_text_to_disjoint_bios, answer_as_list_of_bios,
                   span_bio_labels, log_probs, logits, seq_mask,
                   wordpiece_mask, is_bio_mask, **kwargs):
    """Marginal log-likelihood over every tag sequence that yields a correct answer.

    answer_as_text_to_disjoint_bios - Shape: (batch_size, # of text answers, # of spans a for text answer, seq_length)
    answer_as_list_of_bios - Shape: (batch_size, # of correct sequences, seq_length)
    log_probs - Shape: (batch_size, seq_length, 3)
    seq_mask - Shape: (batch_size, seq_length)

    The candidate list is built without gradient tracking:
      * the pre-generated sequences in ``answer_as_list_of_bios`` (masked by
        ``seq_mask``) are always kept;
      * when ``answer_as_text_to_disjoint_bios`` is non-empty, rows without
        any pre-generated sequence additionally receive either the correct
        sequences found by the top-k search plus ``span_bio_labels``
        (``self._generation_top_k > 0``) or just ``span_bio_labels``.

    Rows masked out by ``is_bio_mask`` are set to ``-1e7``.

    Raises:
        NotImplementedError: when ``self._use_crf`` is set.
    """
    # The CRF variant is not implemented.  A single guard replaces the two
    # identical checks of the original, the second of which was unreachable.
    if self._use_crf:
        raise NotImplementedError

    # Generate most likely correct predictions
    with torch.no_grad():
        answer_as_list_of_bios = answer_as_list_of_bios * seq_mask.unsqueeze(1)
        if answer_as_text_to_disjoint_bios.sum() > 0:
            full_bio = span_bio_labels

            if self._generation_top_k > 0:
                most_likely_predictions = self._get_top_k_sequences(
                    log_probs, wordpiece_mask, self._generation_top_k)
                most_likely_predictions = most_likely_predictions * seq_mask.unsqueeze(1)

                generated_list_of_bios = self._filter_correct_predictions(
                    most_likely_predictions,
                    answer_as_text_to_disjoint_bios, full_bio)

                # 1 for rows that already have at least one pre-generated
                # sequence; generated sequences are zeroed for those rows.
                is_pregenerated_answer_format_mask = (
                    answer_as_list_of_bios.sum((1, 2)) >
                    0).unsqueeze(-1).unsqueeze(-1).long()
                list_of_bios = torch.cat(
                    (answer_as_list_of_bios,
                     generated_list_of_bios *
                     (1 - is_pregenerated_answer_format_mask)),
                    dim=1)

                list_of_bios = self._add_full_bio(list_of_bios, full_bio)
            else:
                is_pregenerated_answer_format_mask = (
                    answer_as_list_of_bios.sum((1, 2)) > 0).long()
                list_of_bios = torch.cat(
                    (answer_as_list_of_bios,
                     (full_bio *
                      (1 - is_pregenerated_answer_format_mask
                       ).unsqueeze(-1)).unsqueeze(1)),
                    dim=1)
        else:
            list_of_bios = answer_as_list_of_bios

    # Calculate log-likelihood from list_of_bios
    log_marginal_likelihood_for_multispan = self._get_combined_likelihood(
        list_of_bios, log_probs)

    # For questions without spans, we set their log probabilities to be very small negative value
    return replace_masked_values(
        log_marginal_likelihood_for_multispan, is_bio_mask, -1e7)
def prediction(self, log_probs, logits, qp_tokens, p_text, q_text, mask):
    """Decode spans from the highest-scoring tag of every token.

    The tag of a position is the argmax of ``logits`` over the last
    dimension; positions masked out by ``mask`` are set to tag 0 before the
    tags are handed to ``MultiSpanHead.decode_spans_from_tags``.
    """
    best_tags = replace_masked_values(torch.argmax(logits, dim=-1), mask, 0)
    return MultiSpanHead.decode_spans_from_tags(
        best_tags, qp_tokens, p_text, q_text)