def forward(
        self,  # type: ignore
        input_ids: torch.LongTensor,
        input_mask: torch.LongTensor,
        input_segments: torch.LongTensor,
        passage_mask: torch.LongTensor,
        question_mask: torch.LongTensor,
        number_indices: torch.LongTensor,
        passage_number_order: torch.LongTensor,
        question_number_order: torch.LongTensor,
        question_number_indices: torch.LongTensor,
        answer_as_passage_spans: torch.LongTensor = None,
        answer_as_question_spans: torch.LongTensor = None,
        answer_as_add_sub_expressions: torch.LongTensor = None,
        answer_as_counts: torch.LongTensor = None,
        span_num: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """Encode the input, score every answering ability, and optionally compute loss / decode answers.

    The method runs ``self.bert`` and keeps ``outputs[0]`` plus the last four
    entries of ``outputs[2]`` (presumably the per-layer hidden states -- TODO
    confirm the encoder is configured to return them).  When the number-GCN is
    enabled, the four layer outputs are refined with graph information about
    the numbers found in the passage and question.  It then evaluates the
    prediction heads for each entry of ``self.answering_abilities``
    (``"passage_span_extraction"``, ``"question_span_extraction"``,
    ``"addition_subtraction"``, ``"counting"``).

    Args:
        input_ids: Token ids fed to the encoder; ``size(0)`` is used as batch size.
        input_mask: Passed to the encoder as ``attention_mask``.
        input_segments: Passed to the encoder as ``token_type_ids``.
        passage_mask: Mask selecting passage positions in the sequence.
        question_mask: Mask selecting question positions in the sequence.
        number_indices: Positions of passage numbers; the code subtracts 1 before
            using them as gather indices and treats results ``<= -1`` as padding.
        passage_number_order: Order values of passage numbers (used to build the
            "greater than" graph for the GCN).
        question_number_order: Order values of question numbers.
        question_number_indices: Positions of question numbers; values ``<= -1``
            are treated as padding.
        answer_as_passage_spans: Gold passage spans, indexed ``[:, :, 0]`` (start)
            and ``[:, :, 1]`` (end); ``-1`` marks padding.
        answer_as_question_spans: Gold question spans, same layout.
        answer_as_add_sub_expressions: Gold sign assignments per combination;
            all-zero rows are treated as padding.
        answer_as_counts: Gold count labels; ``-1`` marks padding.
        span_num: Gold span-count labels used to index ``span_num_log_probs``.
        metadata: Per-instance dictionaries used for decoding predictions
            (keys read here: ``question_tokens``, ``original_passage``,
            ``passage_token_offsets``, ``original_question``,
            ``question_token_offsets``, ``original_numbers``,
            ``number_indices``, ``question_id``, ``answer_annotations``).

    Returns:
        A dictionary containing ``"loss"`` when any gold answer tensor is given,
        ``"question_id"`` and ``"answer"`` when ``metadata`` is given, and
        ``'clamped_number_indices'`` / ``'node_weight'`` when ``metadata`` is
        given and the GCN branch actually ran.

    Raises:
        ValueError: If an unsupported answering ability is encountered.
    """
    # sequence_output, _, other_sequence_output = self.bert(input_ids, input_segments, input_mask)
    outputs = self.bert(input_ids, attention_mask=input_mask, token_type_ids=input_segments)
    sequence_output = outputs[0]
    sequence_output_list = list(outputs[2][-4:])

    batch_size = input_ids.size(0)

    # The GCN is only run when a span ability is enabled.  The original gate tested
    # "question_span", a name no other part of this method recognises, so the
    # second clause could never be true; "question_span_extraction" is the ability
    # name used everywhere else.  The same flag guards the export of GCN outputs
    # at the end, which previously relied on ``self.use_gcn`` alone and could hit
    # an undefined ``d_node_weight``.
    run_gcn = bool(self.use_gcn) and (
        "passage_span_extraction" in self.answering_abilities
        or "question_span_extraction" in self.answering_abilities)

    if run_gcn:
        # M2, M3
        sequence_alg = self._gcn_input_proj(
            torch.cat([sequence_output_list[2], sequence_output_list[3]], dim=2))
        encoded_passage_for_numbers = sequence_alg
        encoded_question_for_numbers = sequence_alg

        # passage number extraction
        real_number_indices = number_indices - 1
        # NOTE(review): the original author flagged this mask with "??".
        number_mask = (real_number_indices > -1).long()
        clamped_number_indices = util.replace_masked_values(
            real_number_indices, number_mask, 0)
        encoded_numbers = torch.gather(
            encoded_passage_for_numbers, 1,
            clamped_number_indices.unsqueeze(-1).expand(
                -1, -1, encoded_passage_for_numbers.size(-1)))

        # question number extraction
        question_number_mask = (question_number_indices > -1).long()
        clamped_question_number_indices = util.replace_masked_values(
            question_number_indices, question_number_mask, 0)
        question_encoded_number = torch.gather(
            encoded_question_for_numbers, 1,
            clamped_question_number_indices.unsqueeze(-1).expand(
                -1, -1, encoded_question_for_numbers.size(-1)))

        # graph mask: entry [b, i, j] is 1 when order[j] > order[i] and both nodes are real.
        number_order = torch.cat((passage_number_order, question_number_order), -1)
        new_graph_mask = number_order.unsqueeze(1).expand(
            batch_size, number_order.size(-1), -1) > number_order.unsqueeze(-1).expand(
                batch_size, -1, number_order.size(-1))
        new_graph_mask = new_graph_mask.long()
        all_number_mask = torch.cat((number_mask, question_number_mask), dim=-1)
        new_graph_mask = all_number_mask.unsqueeze(1) * all_number_mask.unsqueeze(-1) * new_graph_mask

        # iteration
        d_node, _, d_node_weight, _ = self._gcn(
            d_node=encoded_numbers,
            q_node=question_encoded_number,
            d_node_mask=number_mask,
            q_node_mask=question_number_mask,
            graph=new_graph_mask)

        # One extra trailing slot collects the writes of padded numbers; it is
        # sliced off again right after the scatter.
        gcn_info_vec = torch.zeros(
            (batch_size, sequence_alg.size(1) + 1, sequence_output_list[-1].size(-1)),
            dtype=torch.float,
            device=d_node.device)
        clamped_number_indices = util.replace_masked_values(
            real_number_indices, number_mask, gcn_info_vec.size(1) - 1)
        gcn_info_vec.scatter_(
            1,
            clamped_number_indices.unsqueeze(-1).expand(-1, -1, d_node.size(-1)),
            d_node)
        gcn_info_vec = gcn_info_vec[:, :-1, :]

        sequence_output_list[2] = self._gcn_enc(
            self._proj_ln(sequence_output_list[2] + gcn_info_vec))
        sequence_output_list[0] = self._gcn_enc(
            self._proj_ln0(sequence_output_list[0] + gcn_info_vec))
        sequence_output_list[1] = self._gcn_enc(
            self._proj_ln1(sequence_output_list[1] + gcn_info_vec))
        sequence_output_list[3] = self._gcn_enc(
            self._proj_ln3(sequence_output_list[3] + gcn_info_vec))

    # passage hidden and question hidden
    sequence_h2_weight = self._proj_sequence_h(sequence_output_list[2]).squeeze(-1)
    passage_h2_weight = util.masked_softmax(sequence_h2_weight, passage_mask)
    passage_h2 = util.weighted_sum(sequence_output_list[2], passage_h2_weight)
    question_h2_weight = util.masked_softmax(sequence_h2_weight, question_mask)
    question_h2 = util.weighted_sum(sequence_output_list[2], question_h2_weight)

    # passage g0, g1, g2
    question_g0_weight = self._proj_sequence_g0(sequence_output_list[0]).squeeze(-1)
    question_g0_weight = util.masked_softmax(question_g0_weight, question_mask)
    question_g0 = util.weighted_sum(sequence_output_list[0], question_g0_weight)

    question_g1_weight = self._proj_sequence_g1(sequence_output_list[1]).squeeze(-1)
    question_g1_weight = util.masked_softmax(question_g1_weight, question_mask)
    question_g1 = util.weighted_sum(sequence_output_list[1], question_g1_weight)

    question_g2_weight = self._proj_sequence_g2(sequence_output_list[2]).squeeze(-1)
    question_g2_weight = util.masked_softmax(question_g2_weight, question_mask)
    question_g2 = util.weighted_sum(sequence_output_list[2], question_g2_weight)

    if len(self.answering_abilities) > 1:
        # Shape: (batch_size, number_of_abilities)
        answer_ability_logits = self._answer_ability_predictor(
            torch.cat([passage_h2, question_h2, sequence_output[:, 0]], 1))
        answer_ability_log_probs = F.log_softmax(answer_ability_logits, -1)
        best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

    # Number representations used by the counting and arithmetic heads.  These
    # assignments intentionally overwrite the variables of the GCN branch.
    real_number_indices = number_indices.squeeze(-1) - 1
    number_mask = (real_number_indices > -1).long()
    clamped_number_indices = util.replace_masked_values(
        real_number_indices, number_mask, 0)
    encoded_passage_for_numbers = torch.cat(
        [sequence_output_list[2], sequence_output_list[3]], dim=-1)
    encoded_numbers = torch.gather(
        encoded_passage_for_numbers, 1,
        clamped_number_indices.unsqueeze(-1).expand(
            -1, -1, encoded_passage_for_numbers.size(-1)))
    number_weight = self._proj_number(encoded_numbers).squeeze(-1)
    # NOTE(review): from here on the mask is derived from the un-shifted indices,
    # exactly as in the original code -- confirm this is intended.
    number_mask = (number_indices > -1).long()
    number_weight = util.masked_softmax(number_weight, number_mask)
    number_vector = util.weighted_sum(encoded_numbers, number_weight)

    if "counting" in self.answering_abilities:
        # Shape: (batch_size, 10)
        count_number_logits = self._count_number_predictor(
            torch.cat([number_vector, passage_h2, question_h2, sequence_output[:, 0]], dim=1))
        count_number_log_probs = torch.nn.functional.log_softmax(count_number_logits, -1)
        # Info about the best count number prediction
        # Shape: (batch_size,)
        best_count_number = torch.argmax(count_number_log_probs, -1)
        best_count_log_prob = torch.gather(
            count_number_log_probs, 1, best_count_number.unsqueeze(-1)).squeeze(-1)
        if len(self.answering_abilities) > 1:
            best_count_log_prob += answer_ability_log_probs[:, self._counting_index]

    if "passage_span_extraction" in self.answering_abilities \
            or "question_span_extraction" in self.answering_abilities:
        # start 0, 2
        sequence_for_span_start = torch.cat([
            sequence_output_list[2],
            sequence_output_list[0],
            sequence_output_list[2] * question_g2.unsqueeze(1),
            sequence_output_list[0] * question_g0.unsqueeze(1)
        ], dim=2)
        sequence_span_start_logits = self._span_start_predictor(
            sequence_for_span_start).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        sequence_for_span_end = torch.cat([
            sequence_output_list[2],
            sequence_output_list[1],
            sequence_output_list[2] * question_g2.unsqueeze(1),
            sequence_output_list[1] * question_g1.unsqueeze(1)
        ], dim=2)
        # Shape: (batch_size, passage_length)
        sequence_span_end_logits = self._span_end_predictor(
            sequence_for_span_end).squeeze(-1)
        # Shape: (batch_size, passage_length)

        # span number prediction
        span_num_logits = self._proj_span_num(
            torch.cat([passage_h2, question_h2, sequence_output[:, 0]], dim=1))
        span_num_log_probs = torch.nn.functional.log_softmax(span_num_logits, -1)
        best_span_number = torch.argmax(span_num_log_probs, dim=-1)

        if "passage_span_extraction" in self.answering_abilities:
            passage_span_start_log_probs = util.masked_log_softmax(
                sequence_span_start_logits, passage_mask)
            passage_span_end_log_probs = util.masked_log_softmax(
                sequence_span_end_logits, passage_mask)

            # Info about the best passage span prediction
            passage_span_start_logits = util.replace_masked_values(
                sequence_span_start_logits, passage_mask, -1e7)
            passage_span_end_logits = util.replace_masked_values(
                sequence_span_end_logits, passage_mask, -1e7)
            # Shape: (batch_size, topk, 2)
            best_passage_span = get_best_span(passage_span_start_logits,
                                              passage_span_end_logits)

        if "question_span_extraction" in self.answering_abilities:
            question_span_start_log_probs = util.masked_log_softmax(
                sequence_span_start_logits, question_mask)
            question_span_end_log_probs = util.masked_log_softmax(
                sequence_span_end_logits, question_mask)

            # Info about the best question span prediction
            question_span_start_logits = util.replace_masked_values(
                sequence_span_start_logits, question_mask, -1e7)
            question_span_end_logits = util.replace_masked_values(
                sequence_span_end_logits, question_mask, -1e7)
            # Shape: (batch_size, topk, 2)
            best_question_span = get_best_span(question_span_start_logits,
                                               question_span_end_logits)

    if "addition_subtraction" in self.answering_abilities:
        alg_encoded_numbers = torch.cat([
            encoded_numbers,
            question_h2.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1),
            passage_h2.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1),
            sequence_output[:, 0].unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)
        ], 2)

        # Shape: (batch_size, # of numbers in the passage, 3)
        number_sign_logits = self._number_sign_predictor(alg_encoded_numbers)
        number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)

        # Shape: (batch_size, # of numbers in passage).
        best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
        # For padding numbers, the best sign masked as 0 (not included).
        best_signs_for_numbers = util.replace_masked_values(
            best_signs_for_numbers, number_mask, 0)
        # Shape: (batch_size, # of numbers in passage)
        best_signs_log_probs = torch.gather(
            number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
        # the probs of the masked positions should be 1 so that it will not affect the joint probability
        # TODO: this is not quite right, since if there are many numbers in the passage,
        # TODO: the joint probability would be very small.
        best_signs_log_probs = util.replace_masked_values(
            best_signs_log_probs, number_mask, 0)
        # Shape: (batch_size,)
        best_combination_log_prob = best_signs_log_probs.sum(-1)
        if len(self.answering_abilities) > 1:
            best_combination_log_prob += answer_ability_log_probs[:, self._addition_subtraction_index]

    output_dict = {}

    # If answer is given, compute the loss.
    if answer_as_passage_spans is not None or answer_as_question_spans is not None \
            or answer_as_add_sub_expressions is not None or answer_as_counts is not None:

        log_marginal_likelihood_list = []

        for answering_ability in self.answering_abilities:
            if answering_ability == "passage_span_extraction":
                # Shape: (batch_size, # of answer spans)
                gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                # Some spans are padded with index -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                gold_passage_span_mask = (gold_passage_span_starts != -1).long()
                clamped_gold_passage_span_starts = util.replace_masked_values(
                    gold_passage_span_starts, gold_passage_span_mask, 0)
                clamped_gold_passage_span_ends = util.replace_masked_values(
                    gold_passage_span_ends, gold_passage_span_mask, 0)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_passage_span_starts = torch.gather(
                    passage_span_start_log_probs, 1, clamped_gold_passage_span_starts)
                log_likelihood_for_passage_span_ends = torch.gather(
                    passage_span_end_log_probs, 1, clamped_gold_passage_span_ends)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_passage_spans = \
                    log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_passage_spans = util.replace_masked_values(
                    log_likelihood_for_passage_spans, gold_passage_span_mask, -1e7)
                # Shape: (batch_size, )
                log_marginal_likelihood_for_passage_span = util.logsumexp(
                    log_likelihood_for_passage_spans)

                # span log probabilities
                log_likelihood_for_passage_span_nums = torch.gather(
                    span_num_log_probs, 1, span_num)
                log_likelihood_for_passage_span_nums = util.replace_masked_values(
                    log_likelihood_for_passage_span_nums, gold_passage_span_mask[:, :1], -1e7)
                log_marginal_likelihood_for_passage_span_nums = util.logsumexp(
                    log_likelihood_for_passage_span_nums)
                log_marginal_likelihood_list.append(
                    log_marginal_likelihood_for_passage_span
                    + log_marginal_likelihood_for_passage_span_nums)

            elif answering_ability == "question_span_extraction":
                # Shape: (batch_size, # of answer spans)
                gold_question_span_starts = answer_as_question_spans[:, :, 0]
                gold_question_span_ends = answer_as_question_spans[:, :, 1]
                # Some spans are padded with index -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                gold_question_span_mask = (gold_question_span_starts != -1).long()
                clamped_gold_question_span_starts = util.replace_masked_values(
                    gold_question_span_starts, gold_question_span_mask, 0)
                clamped_gold_question_span_ends = util.replace_masked_values(
                    gold_question_span_ends, gold_question_span_mask, 0)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_question_span_starts = torch.gather(
                    question_span_start_log_probs, 1, clamped_gold_question_span_starts)
                log_likelihood_for_question_span_ends = torch.gather(
                    question_span_end_log_probs, 1, clamped_gold_question_span_ends)
                # Shape: (batch_size, # of answer spans)
                log_likelihood_for_question_spans = \
                    log_likelihood_for_question_span_starts + log_likelihood_for_question_span_ends
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_question_spans = util.replace_masked_values(
                    log_likelihood_for_question_spans, gold_question_span_mask, -1e7)
                # Shape: (batch_size, )
                # pylint: disable=invalid-name
                log_marginal_likelihood_for_question_span = util.logsumexp(
                    log_likelihood_for_question_spans)

                # question multi span prediction
                log_likelihood_for_question_span_nums = torch.gather(
                    span_num_log_probs, 1, span_num)
                log_marginal_likelihood_for_question_span_nums = util.logsumexp(
                    log_likelihood_for_question_span_nums)
                log_marginal_likelihood_list.append(
                    log_marginal_likelihood_for_question_span
                    + log_marginal_likelihood_for_question_span_nums)

            elif answering_ability == "addition_subtraction":
                # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                # Shape: (batch_size, # of combinations)
                gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1) > 0).float()
                # Shape: (batch_size, # of numbers in the passage, # of combinations)
                gold_add_sub_signs = answer_as_add_sub_expressions.transpose(1, 2)
                # Shape: (batch_size, # of numbers in the passage, # of combinations)
                log_likelihood_for_number_signs = torch.gather(
                    number_sign_log_probs, 2, gold_add_sub_signs)
                # the log likelihood of the masked positions should be 0
                # so that it will not affect the joint probability
                log_likelihood_for_number_signs = util.replace_masked_values(
                    log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0)
                # Shape: (batch_size, # of combinations)
                log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(1)
                # For those padded combinations, we set their log probabilities to be very small negative value
                log_likelihood_for_add_subs = util.replace_masked_values(
                    log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
                # Shape: (batch_size, )
                log_marginal_likelihood_for_add_sub = util.logsumexp(
                    log_likelihood_for_add_subs)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_add_sub)

            elif answering_ability == "counting":
                # Count answers are padded with label -1,
                # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                # Shape: (batch_size, # of count answers)
                gold_count_mask = (answer_as_counts != -1).long()
                # Shape: (batch_size, # of count answers)
                clamped_gold_counts = util.replace_masked_values(
                    answer_as_counts, gold_count_mask, 0)
                log_likelihood_for_counts = torch.gather(
                    count_number_log_probs, 1, clamped_gold_counts)
                # For those padded spans, we set their log probabilities to be very small negative value
                log_likelihood_for_counts = util.replace_masked_values(
                    log_likelihood_for_counts, gold_count_mask, -1e7)
                # Shape: (batch_size, )
                log_marginal_likelihood_for_count = util.logsumexp(log_likelihood_for_counts)
                log_marginal_likelihood_list.append(log_marginal_likelihood_for_count)

            else:
                raise ValueError(
                    f"Unsupported answering ability: {answering_ability}")

        # print(log_marginal_likelihood_list)
        if len(self.answering_abilities) > 1:
            # Add the ability probabilities if there are more than one abilities
            all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
            all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
            marginal_log_likelihood = util.logsumexp(all_log_marginal_likelihoods)
        else:
            marginal_log_likelihood = log_marginal_likelihood_list[0]
        output_dict["loss"] = -marginal_log_likelihood.mean()

    if metadata is not None:
        output_dict["question_id"] = []
        output_dict["answer"] = []
        for i in range(batch_size):
            if len(self.answering_abilities) > 1:
                predicted_ability_str = self.answering_abilities[
                    best_answer_ability[i].detach().cpu().numpy()]
            else:
                predicted_ability_str = self.answering_abilities[0]

            answer_json: Dict[str, Any] = {}

            question_start = 1
            passage_start = len(metadata[i]["question_tokens"]) + 2
            # We did not consider multi-mention answers here
            if predicted_ability_str == "passage_span_extraction":
                answer_json["answer_type"] = "passage_span"
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['passage_token_offsets']
                predicted_answer, predicted_spans = best_answers_extraction(
                    best_passage_span[i], best_span_number[i], passage_str,
                    offsets, passage_start)
                answer_json["value"] = predicted_answer
                answer_json["spans"] = predicted_spans
            elif predicted_ability_str == "question_span_extraction":
                answer_json["answer_type"] = "question_span"
                question_str = metadata[i]['original_question']
                offsets = metadata[i]['question_token_offsets']
                predicted_answer, predicted_spans = best_answers_extraction(
                    best_question_span[i], best_span_number[i], question_str,
                    offsets, question_start)
                answer_json["value"] = predicted_answer
                answer_json["spans"] = predicted_spans
            elif predicted_ability_str == "addition_subtraction":
                answer_json["answer_type"] = "arithmetic"
                original_numbers = metadata[i]['original_numbers']
                # Head output index -> arithmetic sign (0: unused, 1: plus, 2: minus).
                sign_remap = {0: 0, 1: 1, 2: -1}
                predicted_signs = [
                    sign_remap[it]
                    for it in best_signs_for_numbers[i].detach().cpu().numpy()
                ]
                result = sum(
                    sign * number
                    for sign, number in zip(predicted_signs, original_numbers))
                predicted_answer = convert_number_to_str(result)
                offsets = metadata[i]['passage_token_offsets']
                # Kept under its own name so the ``number_indices`` tensor parameter
                # is not rebound to a per-instance metadata value.
                meta_number_indices = metadata[i]['number_indices']
                number_positions = [
                    offsets[index - 1] for index in meta_number_indices
                ]
                answer_json['numbers'] = []
                for offset, value, sign in zip(number_positions,
                                               original_numbers,
                                               predicted_signs):
                    answer_json['numbers'].append({
                        'span': offset,
                        'value': value,
                        'sign': sign
                    })
                # ``len(...) > 0`` guards the ``[-1]`` lookup for instances without numbers.
                if len(meta_number_indices) > 0 and meta_number_indices[-1] == -1:
                    # There is a dummy 0 number at position -1 added in some cases; we are
                    # removing that here.
                    answer_json["numbers"].pop()
                answer_json["value"] = result
                answer_json['number_sign_log_probs'] = \
                    number_sign_log_probs[i, :, :].detach().cpu().numpy()
            elif predicted_ability_str == "counting":
                answer_json["answer_type"] = "count"
                predicted_count = best_count_number[i].detach().cpu().numpy()
                predicted_answer = str(predicted_count)
                answer_json["count"] = predicted_count
            else:
                raise ValueError(
                    f"Unsupported answer ability: {predicted_ability_str}")

            answer_json["predicted_answer"] = predicted_answer
            output_dict["question_id"].append(metadata[i]["question_id"])
            output_dict["answer"].append(answer_json)
            answer_annotations = metadata[i].get('answer_annotations', [])
            if answer_annotations:
                self._drop_metrics(predicted_answer, answer_annotations)

        # NOTE(review): the nesting of this export (inside the ``metadata`` block)
        # was reconstructed from flattened source -- confirm against the upstream
        # layout.  It is guarded by ``run_gcn`` because ``d_node_weight`` only
        # exists when the GCN branch above was executed.
        if run_gcn:
            output_dict['clamped_number_indices'] = clamped_number_indices
            output_dict['node_weight'] = d_node_weight
    return output_dict
def forward(self,
            input_ids: torch.LongTensor,
            attention_mask: torch.LongTensor,
            passage_mask: torch.LongTensor,
            question_mask: torch.LongTensor,
            argument_bpe_ids: torch.LongTensor,
            domain_bpe_ids: torch.LongTensor,
            punct_bpe_ids: torch.LongTensor,
            labels: torch.LongTensor = None,
            token_type_ids: torch.LongTensor = None,
            ) -> Tuple:
    """Score every answer choice and optionally compute the classification loss.

    All per-choice inputs are flattened from ``(..., num_choices, last_dim)`` to
    ``(-1, last_dim)`` before being passed to ``self.roberta``; the resulting
    logits are reshaped back to ``(-1, num_choices)``.

    Depending on ``self.use_gcn`` / ``self.use_pool`` the logits come from the
    GCN branch, the pooled-output baseline branch, or a merge of both selected
    by ``self.merge_type`` (1-5).

    Args:
        input_ids: Token ids; ``shape[1]`` is taken as the number of choices.
        attention_mask: Attention mask for the encoder (may be ``None``).
        passage_mask: Mask of passage positions, used by the GCN branch.
        question_mask: Mask of question positions, used by the GCN branch.
        argument_bpe_ids: Per-token argument-relation ids, used by the GCN branch.
        domain_bpe_ids: Accepted for interface compatibility; not read here.
        punct_bpe_ids: Per-token punctuation flags, used by the GCN branch.
        labels: Gold choice indices. When ``None`` (the default) no loss is computed.
        token_type_ids: Accepted for interface compatibility; the encoder is
            always called with ``token_type_ids=None``.

    Returns:
        ``(reshaped_logits,)`` when ``labels`` is ``None``, otherwise
        ``(loss, reshaped_logits)``.

    Raises:
        ValueError: If neither ``use_gcn`` nor ``use_pool`` is enabled, or if
            both are enabled and ``merge_type`` is not one of 1-5.
    """

    def _flatten_choices(tensor):
        # Collapse all leading dimensions so each choice becomes its own row.
        return None if tensor is None else tensor.view(-1, tensor.size(-1))

    # ``input_ids`` is dereferenced here, so it can never be ``None`` below.
    num_choices = input_ids.shape[1]

    flat_input_ids = input_ids.view(-1, input_ids.size(-1))
    flat_attention_mask = _flatten_choices(attention_mask)
    flat_passage_mask = _flatten_choices(passage_mask)
    flat_question_mask = _flatten_choices(question_mask)
    flat_argument_bpe_ids = _flatten_choices(argument_bpe_ids)
    flat_punct_bpe_ids = _flatten_choices(punct_bpe_ids)

    bert_outputs = self.roberta(flat_input_ids, attention_mask=flat_attention_mask, token_type_ids=None)
    sequence_output = bert_outputs[0]
    pooled_output = bert_outputs[1]

    if self.use_gcn:
        # The GCN branch. Suppose to go back to baseline once remove.
        new_punct_id = self.max_rel_id + 1
        # punct_id: 1 -> 4. for incorporating with argument_bpe_ids.
        new_punct_bpe_ids = new_punct_id * flat_punct_bpe_ids
        # -1:padding, 0:non, 1-3: arg, 4:punct.
        _flat_all_bpe_ids = flat_argument_bpe_ids + new_punct_bpe_ids
        # Where a token is both punctuation and argument, the argument id wins.
        overlapped_punct_argument_mask = (_flat_all_bpe_ids > new_punct_id).long()
        flat_all_bpe_ids = _flat_all_bpe_ids * (1 - overlapped_punct_argument_mask) \
            + flat_argument_bpe_ids * overlapped_punct_argument_mask
        assert flat_argument_bpe_ids.max().item() <= new_punct_id

        # encoded_spans: (bsz x n_choices, n_nodes, embed_size)
        # span_mask: (bsz x n_choices, n_nodes)
        # edges: list[list[int]]
        # node_in_seq_indices: list[list[list[int]]]
        encoded_spans, span_mask, edges, node_in_seq_indices = self.split_into_spans_9(
            sequence_output, flat_attention_mask, flat_all_bpe_ids)

        argument_graph, punctuation_graph = self.get_adjacency_matrices_2(
            edges, n_nodes=encoded_spans.size(1), device=encoded_spans.device)

        node, node_weight = self._gcn(node=encoded_spans, node_mask=span_mask,
                                      argument_graph=argument_graph,
                                      punctuation_graph=punctuation_graph)

        gcn_info_vec = self.get_gcn_info_vector(node_in_seq_indices, node,
                                                size=sequence_output.size(),
                                                device=sequence_output.device)

        gcn_updated_sequence_output = self._gcn_enc(self._gcn_prj_ln(sequence_output + gcn_info_vec))

        # passage hidden and question hidden
        sequence_h2_weight = self._proj_sequence_h(gcn_updated_sequence_output).squeeze(-1)
        passage_h2_weight = util.masked_softmax(sequence_h2_weight.float(), flat_passage_mask.float())
        passage_h2 = util.weighted_sum(gcn_updated_sequence_output, passage_h2_weight)
        question_h2_weight = util.masked_softmax(sequence_h2_weight.float(), flat_question_mask.float())
        question_h2 = util.weighted_sum(gcn_updated_sequence_output, question_h2_weight)

        gcn_output_feats = torch.cat([passage_h2, question_h2, gcn_updated_sequence_output[:, 0]], dim=1)
        gcn_logits = self._proj_span_num(gcn_output_feats)

    if self.use_pool:
        # The baseline branch. The output.
        pooled_output = self.dropout(pooled_output)
        baseline_logits = self.classifier(pooled_output)

    if self.use_gcn and self.use_pool:
        # Merge gcn_logits & baseline_logits. TODO: different way of merging.
        # NOTE(review): for merge types 2-5 dropout is applied to ``pooled_output``
        # a second time (it was already applied in the baseline branch above).
        # Kept as-is to preserve training behaviour -- confirm whether intended.
        if self.merge_type == 1:
            logits = gcn_logits + baseline_logits

        elif self.merge_type == 2:
            pooled_output = self.dropout(pooled_output)
            merged_feats = torch.cat([gcn_updated_sequence_output[:, 0], pooled_output], dim=1)
            logits = self._proj_gcn_pool_3(merged_feats)

        elif self.merge_type == 3:
            pooled_output = self.dropout(pooled_output)
            merged_feats = torch.cat([gcn_updated_sequence_output[:, 0], pooled_output,
                                      gcn_updated_sequence_output[:, 0], pooled_output], dim=1)
            logits = self._proj_gcn_pool_4(merged_feats)

        elif self.merge_type == 4:
            pooled_output = self.dropout(pooled_output)
            merged_feats = torch.cat([passage_h2, question_h2, pooled_output], dim=1)
            logits = self._proj_gcn_pool(merged_feats)

        elif self.merge_type == 5:
            pooled_output = self.dropout(pooled_output)
            merged_feats = torch.cat([passage_h2, question_h2,
                                      gcn_updated_sequence_output[:, 0], pooled_output], dim=1)
            logits = self._proj_gcn_pool_4(merged_feats)

        else:
            # Previously fell through and failed later on an unbound ``logits``.
            raise ValueError(f"Unsupported merge_type: {self.merge_type}")

    elif self.use_gcn:
        logits = gcn_logits
    elif self.use_pool:
        logits = baseline_logits
    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError("At least one of use_gcn / use_pool must be enabled.")

    reshaped_logits = logits.squeeze(-1).view(-1, num_choices)

    outputs = (reshaped_logits, )

    if labels is not None:
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(reshaped_logits, labels)
        outputs = (loss,) + outputs
    return outputs