Code example #1
    def forward(
            self,  # type: ignore
            input_ids: torch.LongTensor,
            input_mask: torch.LongTensor,
            input_segments: torch.LongTensor,
            passage_mask: torch.LongTensor,
            question_mask: torch.LongTensor,
            number_indices: torch.LongTensor,
            passage_number_order: torch.LongTensor,
            question_number_order: torch.LongTensor,
            question_number_indices: torch.LongTensor,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            answer_as_add_sub_expressions: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            span_num: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # sequence_output, _, other_sequence_output = self.bert(input_ids, input_segments, input_mask)
        outputs = self.bert(input_ids,
                            attention_mask=input_mask,
                            token_type_ids=input_segments)
        sequence_output = outputs[0]
        sequence_output_list = [item for item in outputs[2][-4:]]

        batch_size = input_ids.size(0)
        if ("passage_span_extraction" in self.answering_abilities or
                "question_span" in self.answering_abilities) and self.use_gcn:
            # M2, M3
            sequence_alg = self._gcn_input_proj(
                torch.cat([sequence_output_list[2], sequence_output_list[3]],
                          dim=2))
            encoded_passage_for_numbers = sequence_alg
            encoded_question_for_numbers = sequence_alg
            # passage number extraction
            real_number_indices = number_indices - 1
            number_mask = (real_number_indices > -1).long()  # 1 for real number positions, 0 for padding
            clamped_number_indices = util.replace_masked_values(
                real_number_indices, number_mask, 0)
            encoded_numbers = torch.gather(
                encoded_passage_for_numbers, 1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, encoded_passage_for_numbers.size(-1)))

            # question number extraction
            question_number_mask = (question_number_indices > -1).long()
            clamped_question_number_indices = util.replace_masked_values(
                question_number_indices, question_number_mask, 0)
            question_encoded_number = torch.gather(
                encoded_question_for_numbers, 1,
                clamped_question_number_indices.unsqueeze(-1).expand(
                    -1, -1, encoded_question_for_numbers.size(-1)))

            # graph mask
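            # new_graph_mask[b, i, j] == 1 iff number_order[j] > number_order[i]
            # and both the i-th and j-th positions hold real (non-padded) numbers.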
            number_order = torch.cat(
                (passage_number_order, question_number_order), -1)
            new_graph_mask = number_order.unsqueeze(1).expand(
                batch_size, number_order.size(-1),
                -1) > number_order.unsqueeze(-1).expand(
                    batch_size, -1, number_order.size(-1))
            new_graph_mask = new_graph_mask.long()
            all_number_mask = torch.cat((number_mask, question_number_mask),
                                        dim=-1)
            new_graph_mask = all_number_mask.unsqueeze(
                1) * all_number_mask.unsqueeze(-1) * new_graph_mask

            # iteration
            d_node, q_node, d_node_weight, _ = self._gcn(
                d_node=encoded_numbers,
                q_node=question_encoded_number,
                d_node_mask=number_mask,
                q_node_mask=question_number_mask,
                graph=new_graph_mask)
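            # Scatter the updated number-node vectors back to their token positions;
            # padded numbers are routed to an extra dummy row that is sliced off below.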
            gcn_info_vec = torch.zeros((batch_size, sequence_alg.size(1) + 1,
                                        sequence_output_list[-1].size(-1)),
                                       dtype=torch.float,
                                       device=d_node.device)
            clamped_number_indices = util.replace_masked_values(
                real_number_indices, number_mask,
                gcn_info_vec.size(1) - 1)
            gcn_info_vec.scatter_(
                1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, d_node.size(-1)), d_node)
            gcn_info_vec = gcn_info_vec[:, :-1, :]

            sequence_output_list[2] = self._gcn_enc(
                self._proj_ln(sequence_output_list[2] + gcn_info_vec))
            sequence_output_list[0] = self._gcn_enc(
                self._proj_ln0(sequence_output_list[0] + gcn_info_vec))
            sequence_output_list[1] = self._gcn_enc(
                self._proj_ln1(sequence_output_list[1] + gcn_info_vec))
            sequence_output_list[3] = self._gcn_enc(
                self._proj_ln3(sequence_output_list[3] + gcn_info_vec))

        # passage hidden and question hidden
        sequence_h2_weight = self._proj_sequence_h(
            sequence_output_list[2]).squeeze(-1)
        passage_h2_weight = util.masked_softmax(sequence_h2_weight,
                                                passage_mask)
        passage_h2 = util.weighted_sum(sequence_output_list[2],
                                       passage_h2_weight)
        question_h2_weight = util.masked_softmax(sequence_h2_weight,
                                                 question_mask)
        question_h2 = util.weighted_sum(sequence_output_list[2],
                                        question_h2_weight)

        # passage g0, g1, g2
        question_g0_weight = self._proj_sequence_g0(
            sequence_output_list[0]).squeeze(-1)
        question_g0_weight = util.masked_softmax(question_g0_weight,
                                                 question_mask)
        question_g0 = util.weighted_sum(sequence_output_list[0],
                                        question_g0_weight)

        question_g1_weight = self._proj_sequence_g1(
            sequence_output_list[1]).squeeze(-1)
        question_g1_weight = util.masked_softmax(question_g1_weight,
                                                 question_mask)
        question_g1 = util.weighted_sum(sequence_output_list[1],
                                        question_g1_weight)

        question_g2_weight = self._proj_sequence_g2(
            sequence_output_list[2]).squeeze(-1)
        question_g2_weight = util.masked_softmax(question_g2_weight,
                                                 question_mask)
        question_g2 = util.weighted_sum(sequence_output_list[2],
                                        question_g2_weight)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = self._answer_ability_predictor(
                torch.cat([passage_h2, question_h2, sequence_output[:, 0]], 1))
            answer_ability_log_probs = F.log_softmax(answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        real_number_indices = number_indices.squeeze(-1) - 1
        number_mask = (real_number_indices > -1).long()
        clamped_number_indices = util.replace_masked_values(
            real_number_indices, number_mask, 0)
        encoded_passage_for_numbers = torch.cat(
            [sequence_output_list[2], sequence_output_list[3]], dim=-1)
        encoded_numbers = torch.gather(
            encoded_passage_for_numbers, 1,
            clamped_number_indices.unsqueeze(-1).expand(
                -1, -1, encoded_passage_for_numbers.size(-1)))
        number_weight = self._proj_number(encoded_numbers).squeeze(-1)
        number_mask = (number_indices > -1).long()
        number_weight = util.masked_softmax(number_weight, number_mask)
        number_vector = util.weighted_sum(encoded_numbers, number_weight)

        if "counting" in self.answering_abilities:
            # Shape: (batch_size, 10)
            count_number_logits = self._count_number_predictor(
                torch.cat(
                    [number_vector, passage_h2, question_h2, sequence_output[:, 0]],
                    dim=1))
            count_number_log_probs = torch.nn.functional.log_softmax(
                count_number_logits, -1)
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1)
            best_count_log_prob = torch.gather(
                count_number_log_probs, 1,
                best_count_number.unsqueeze(-1)).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[
                    :, self._counting_index]

        if "passage_span_extraction" in self.answering_abilities or "question_span_extraction" in self.answering_abilities:
            # start 0, 2
            sequence_for_span_start = torch.cat([
                sequence_output_list[2], sequence_output_list[0],
                sequence_output_list[2] * question_g2.unsqueeze(1),
                sequence_output_list[0] * question_g0.unsqueeze(1)
            ],
                                                dim=2)
            sequence_span_start_logits = self._span_start_predictor(
                sequence_for_span_start).squeeze(-1)
            # Shape: (batch_size, passage_length, modeling_dim * 2)
            sequence_for_span_end = torch.cat([
                sequence_output_list[2], sequence_output_list[1],
                sequence_output_list[2] * question_g2.unsqueeze(1),
                sequence_output_list[1] * question_g1.unsqueeze(1)
            ],
                                              dim=2)
            # Shape: (batch_size, passage_length)
            sequence_span_end_logits = self._span_end_predictor(
                sequence_for_span_end).squeeze(-1)
            # Shape: (batch_size, passage_length)

            # span number prediction
            span_num_logits = self._proj_span_num(
                torch.cat([passage_h2, question_h2, sequence_output[:, 0]],
                          dim=1))
            span_num_log_probs = torch.nn.functional.log_softmax(
                span_num_logits, -1)

            best_span_number = torch.argmax(span_num_log_probs, dim=-1)

            if "passage_span_extraction" in self.answering_abilities:
                passage_span_start_log_probs = util.masked_log_softmax(
                    sequence_span_start_logits, passage_mask)
                passage_span_end_log_probs = util.masked_log_softmax(
                    sequence_span_end_logits, passage_mask)

                # Info about the best passage span prediction
                passage_span_start_logits = util.replace_masked_values(
                    sequence_span_start_logits, passage_mask, -1e7)
                passage_span_end_logits = util.replace_masked_values(
                    sequence_span_end_logits, passage_mask, -1e7)
                # Shape: (batch_size, topk, 2)
                best_passage_span = get_best_span(passage_span_start_logits,
                                                  passage_span_end_logits)

            if "question_span_extraction" in self.answering_abilities:
                question_span_start_log_probs = util.masked_log_softmax(
                    sequence_span_start_logits, question_mask)
                question_span_end_log_probs = util.masked_log_softmax(
                    sequence_span_end_logits, question_mask)

                # Info about the best question span prediction
                question_span_start_logits = util.replace_masked_values(
                    sequence_span_start_logits, question_mask, -1e7)
                question_span_end_logits = util.replace_masked_values(
                    sequence_span_end_logits, question_mask, -1e7)
                # Shape: (batch_size, topk, 2)
                best_question_span = get_best_span(question_span_start_logits,
                                                   question_span_end_logits)

        if "addition_subtraction" in self.answering_abilities:
            alg_encoded_numbers = torch.cat([
                encoded_numbers,
                question_h2.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1),
                passage_h2.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1),
                sequence_output[:, 0].unsqueeze(1).repeat(
                    1, encoded_numbers.size(1), 1)
            ], 2)

            # Shape: (batch_size, # of numbers in the passage, 3)
            number_sign_logits = self._number_sign_predictor(
                alg_encoded_numbers)
            number_sign_log_probs = torch.nn.functional.log_softmax(
                number_sign_logits, -1)

            # Shape: (batch_size, # of numbers in passage).
            best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
            # For padded numbers, the best sign is masked to 0 (not included).
            best_signs_for_numbers = util.replace_masked_values(
                best_signs_for_numbers, number_mask, 0)
            # Shape: (batch_size, # of numbers in passage)
            best_signs_log_probs = torch.gather(
                number_sign_log_probs, 2,
                best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
            # the probs of the masked positions should be 1 so that it will not affect the joint probability
            # TODO: this is not quite right, since if there are many numbers in the passage,
            # TODO: the joint probability would be very small.
            best_signs_log_probs = util.replace_masked_values(
                best_signs_log_probs, number_mask, 0)
            # Shape: (batch_size,)
            best_combination_log_prob = best_signs_log_probs.sum(-1)
            if len(self.answering_abilities) > 1:
                best_combination_log_prob += answer_ability_log_probs[
                    :, self._addition_subtraction_index]

        output_dict = {}

        # If answer is given, compute the loss.
        if answer_as_passage_spans is not None or answer_as_question_spans is not None or answer_as_add_sub_expressions is not None or answer_as_counts is not None:

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                    gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_passage_span_mask = (gold_passage_span_starts !=
                                              -1).long()
                    clamped_gold_passage_span_starts = util.replace_masked_values(
                        gold_passage_span_starts, gold_passage_span_mask, 0)
                    clamped_gold_passage_span_ends = util.replace_masked_values(
                        gold_passage_span_ends, gold_passage_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_span_starts = torch.gather(
                        passage_span_start_log_probs, 1,
                        clamped_gold_passage_span_starts)
                    log_likelihood_for_passage_span_ends = torch.gather(
                        passage_span_end_log_probs, 1,
                        clamped_gold_passage_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_spans = log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
                    # For padded spans, set their log probabilities to a very small negative value.
                    log_likelihood_for_passage_spans = util.replace_masked_values(
                        log_likelihood_for_passage_spans,
                        gold_passage_span_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_passage_span = util.logsumexp(
                        log_likelihood_for_passage_spans)

                    # span log probabilities
                    log_likelihood_for_passage_span_nums = torch.gather(
                        span_num_log_probs, 1, span_num)
                    log_likelihood_for_passage_span_nums = util.replace_masked_values(
                        log_likelihood_for_passage_span_nums,
                        gold_passage_span_mask[:, :1], -1e7)
                    log_marginal_likelihood_for_passage_span_nums = util.logsumexp(
                        log_likelihood_for_passage_span_nums)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_passage_span +
                        log_marginal_likelihood_for_passage_span_nums)

                elif answering_ability == "question_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_question_span_starts = answer_as_question_spans[:, :, 0]
                    gold_question_span_ends = answer_as_question_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_question_span_mask = (gold_question_span_starts !=
                                               -1).long()
                    clamped_gold_question_span_starts = util.replace_masked_values(
                        gold_question_span_starts, gold_question_span_mask, 0)
                    clamped_gold_question_span_ends = util.replace_masked_values(
                        gold_question_span_ends, gold_question_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_span_starts = torch.gather(
                        question_span_start_log_probs, 1,
                        clamped_gold_question_span_starts)
                    log_likelihood_for_question_span_ends = torch.gather(
                        question_span_end_log_probs, 1,
                        clamped_gold_question_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_spans = log_likelihood_for_question_span_starts + log_likelihood_for_question_span_ends
                    # For padded spans, set their log probabilities to a very small negative value.
                    log_likelihood_for_question_spans = util.replace_masked_values(
                        log_likelihood_for_question_spans,
                        gold_question_span_mask, -1e7)
                    # Shape: (batch_size, )
                    # pylint: disable=invalid-name
                    log_marginal_likelihood_for_question_span = util.logsumexp(
                        log_likelihood_for_question_spans)

                    # question multi span prediction
                    log_likelihood_for_question_span_nums = torch.gather(
                        span_num_log_probs, 1, span_num)
                    log_marginal_likelihood_for_question_span_nums = util.logsumexp(
                        log_likelihood_for_question_span_nums)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_question_span +
                        log_marginal_likelihood_for_question_span_nums)

                elif answering_ability == "addition_subtraction":
                    # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                    # Shape: (batch_size, # of combinations)
                    gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1)
                                         > 0).float()
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    gold_add_sub_signs = answer_as_add_sub_expressions.transpose(
                        1, 2)
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    log_likelihood_for_number_signs = torch.gather(
                        number_sign_log_probs, 2, gold_add_sub_signs)
                    # the log likelihood of the masked positions should be 0
                    # so that it will not affect the joint probability
                    log_likelihood_for_number_signs = util.replace_masked_values(
                        log_likelihood_for_number_signs,
                        number_mask.unsqueeze(-1), 0)
                    # Shape: (batch_size, # of combinations)
                    log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(
                        1)
                    # For padded combinations, set their log probabilities to a very small negative value.
                    log_likelihood_for_add_subs = util.replace_masked_values(
                        log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_add_sub = util.logsumexp(
                        log_likelihood_for_add_subs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_add_sub)

                elif answering_ability == "counting":
                    # Count answers are padded with label -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    # Shape: (batch_size, # of count answers)
                    gold_count_mask = (answer_as_counts != -1).long()
                    # Shape: (batch_size, # of count answers)
                    clamped_gold_counts = util.replace_masked_values(
                        answer_as_counts, gold_count_mask, 0)
                    log_likelihood_for_counts = torch.gather(
                        count_number_log_probs, 1, clamped_gold_counts)
                    # For padded count answers, set their log probabilities to a very small negative value.
                    log_likelihood_for_counts = util.replace_masked_values(
                        log_likelihood_for_counts, gold_count_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_count = util.logsumexp(
                        log_likelihood_for_counts)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_count)

                else:
                    raise ValueError(
                        f"Unsupported answering ability: {answering_ability}")
            # print(log_marginal_likelihood_list)
            if len(self.answering_abilities) > 1:
                # Add the ability probabilities if there are more than one abilities
                all_log_marginal_likelihoods = torch.stack(
                    log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
                marginal_log_likelihood = util.logsumexp(
                    all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]
            output_dict["loss"] = -marginal_log_likelihood.mean()

        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            for i in range(batch_size):
                if len(self.answering_abilities) > 1:
                    predicted_ability_str = self.answering_abilities[
                        best_answer_ability[i].detach().cpu().numpy()]
                else:
                    predicted_ability_str = self.answering_abilities[0]

                answer_json: Dict[str, Any] = {}

                question_start = 1
                passage_start = len(metadata[i]["question_tokens"]) + 2
                # We did not consider multi-mention answers here
                if predicted_ability_str == "passage_span_extraction":
                    answer_json["answer_type"] = "passage_span"
                    passage_str = metadata[i]['original_passage']
                    offsets = metadata[i]['passage_token_offsets']
                    predicted_answer, predicted_spans = best_answers_extraction(
                        best_passage_span[i], best_span_number[i], passage_str,
                        offsets, passage_start)
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = predicted_spans
                elif predicted_ability_str == "question_span_extraction":
                    answer_json["answer_type"] = "question_span"
                    question_str = metadata[i]['original_question']
                    offsets = metadata[i]['question_token_offsets']
                    predicted_answer, predicted_spans = best_answers_extraction(
                        best_question_span[i], best_span_number[i],
                        question_str, offsets, question_start)
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = predicted_spans
                elif predicted_ability_str == "addition_subtraction":
                    answer_json["answer_type"] = "arithmetic"
                    original_numbers = metadata[i]['original_numbers']
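                    # Sign classes: 0 = number not used, 1 = add, 2 = subtract.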
                    sign_remap = {0: 0, 1: 1, 2: -1}
                    predicted_signs = [
                        sign_remap[it] for it in
                        best_signs_for_numbers[i].detach().cpu().numpy()
                    ]
                    result = sum([
                        sign * number for sign, number in zip(
                            predicted_signs, original_numbers)
                    ])
                    predicted_answer = convert_number_to_str(result)
                    offsets = metadata[i]['passage_token_offsets']
                    number_indices = metadata[i]['number_indices']
                    number_positions = [
                        offsets[index - 1] for index in number_indices
                    ]
                    answer_json['numbers'] = []
                    for offset, value, sign in zip(number_positions,
                                                   original_numbers,
                                                   predicted_signs):
                        answer_json['numbers'].append({
                            'span': offset,
                            'value': value,
                            'sign': sign
                        })
                    if number_indices[-1] == -1:
                        # There is a dummy 0 number at position -1 added in some cases; we are
                        # removing that here.
                        answer_json["numbers"].pop()
                    answer_json["value"] = result
                    answer_json['number_sign_log_probs'] = (
                        number_sign_log_probs[i, :, :].detach().cpu().numpy())

                elif predicted_ability_str == "counting":
                    answer_json["answer_type"] = "count"
                    predicted_count = best_count_number[i].detach().cpu().numpy()
                    predicted_answer = str(predicted_count)
                    answer_json["count"] = predicted_count
                else:
                    raise ValueError(
                        f"Unsupported answer ability: {predicted_ability_str}")

                answer_json["predicted_answer"] = predicted_answer
                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(answer_json)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(predicted_answer, answer_annotations)

            if self.use_gcn:
                output_dict['clamped_number_indices'] = clamped_number_indices
                output_dict['node_weight'] = d_node_weight
        return output_dict
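A minimal, self-contained sketch (toy tensors, plain PyTorch in place of the AllenNLP-style util helpers; all names are illustrative) of the span-marginalization pattern used repeatedly above: clamp the -1-padded gold indices, gather start/end log-probabilities, push padded spans to -1e7, then logsumexp over the gold spans.

import torch

batch_size, seq_len, max_spans = 2, 6, 3
span_start_log_probs = torch.log_softmax(torch.randn(batch_size, seq_len), dim=-1)
span_end_log_probs = torch.log_softmax(torch.randn(batch_size, seq_len), dim=-1)
# Gold spans padded with -1 (toy values).
gold_spans = torch.tensor([[[1, 2], [4, 4], [-1, -1]],
                           [[0, 3], [-1, -1], [-1, -1]]])
gold_starts, gold_ends = gold_spans[..., 0], gold_spans[..., 1]
span_mask = gold_starts != -1
log_likelihood = (torch.gather(span_start_log_probs, 1, gold_starts.clamp(min=0))
                  + torch.gather(span_end_log_probs, 1, gold_ends.clamp(min=0)))
# Padded spans get a very small log probability so they vanish in the logsumexp.
log_likelihood = log_likelihood.masked_fill(~span_mask, -1e7)
marginal_log_likelihood = torch.logsumexp(log_likelihood, dim=-1)  # (batch_size,)
print(marginal_log_likelihood)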
Code example #2
File: dagn.py  Project: yezhuoyang/LogicalReasoning
    def forward(self,
                input_ids: torch.LongTensor,
                attention_mask: torch.LongTensor,

                passage_mask: torch.LongTensor,
                question_mask: torch.LongTensor,

                argument_bpe_ids: torch.LongTensor,
                domain_bpe_ids: torch.LongTensor,
                punct_bpe_ids: torch.LongTensor,

                labels: torch.LongTensor,
                token_type_ids: torch.LongTensor = None,
                ) -> Tuple:

        num_choices = input_ids.shape[1]
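        # Each example carries num_choices candidate answers; flatten
        # (bsz, num_choices, seq_len) -> (bsz * num_choices, seq_len) so the
        # encoder scores one sequence per choice.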

        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None

        flat_passage_mask = passage_mask.view(-1, passage_mask.size(-1)) if passage_mask is not None else None
        flat_question_mask = question_mask.view(-1, question_mask.size(-1)) if question_mask is not None else None

        flat_argument_bpe_ids = argument_bpe_ids.view(-1, argument_bpe_ids.size(-1)) if argument_bpe_ids is not None else None
        flat_domain_bpe_ids = domain_bpe_ids.view(-1, domain_bpe_ids.size(-1)) if domain_bpe_ids is not None else None  
        flat_punct_bpe_ids = punct_bpe_ids.view(-1, punct_bpe_ids.size(-1)) if punct_bpe_ids is not None else None

        bert_outputs = self.roberta(flat_input_ids, attention_mask=flat_attention_mask, token_type_ids=None)
        sequence_output = bert_outputs[0]
        pooled_output = bert_outputs[1]  


        if self.use_gcn:
            ''' The GCN branch. Removing it should fall back to the baseline. '''
            new_punct_id = self.max_rel_id + 1
            new_punct_bpe_ids = new_punct_id * flat_punct_bpe_ids  # punct_id: 1 -> 4, so it can be merged with argument_bpe_ids.
            _flat_all_bpe_ids = flat_argument_bpe_ids + new_punct_bpe_ids  # -1:padding, 0:non, 1-3: arg, 4:punct.
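            # Where a token carries both an argument id and the punctuation id,
            # keep the argument id.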
            overlapped_punct_argument_mask = (_flat_all_bpe_ids > new_punct_id).long()
            flat_all_bpe_ids = _flat_all_bpe_ids * (1 - overlapped_punct_argument_mask) + flat_argument_bpe_ids * overlapped_punct_argument_mask
            assert flat_argument_bpe_ids.max().item() <= new_punct_id

            # encoded_spans: (bsz x n_choices, n_nodes, embed_size)
            # span_mask: (bsz x n_choices, n_nodes)
            # edges: list[list[int]]
            # node_in_seq_indices: list[list[list[int]]]
            encoded_spans, span_mask, edges, node_in_seq_indices = self.split_into_spans_9(sequence_output,
                                                                                           flat_attention_mask,
                                                                                           flat_all_bpe_ids)

            argument_graph, punctuation_graph = self.get_adjacency_matrices_2(edges, n_nodes=encoded_spans.size(1), device=encoded_spans.device)

            node, node_weight = self._gcn(node=encoded_spans, node_mask=span_mask,
                                          argument_graph=argument_graph,
                                          punctuation_graph=punctuation_graph)  

            gcn_info_vec = self.get_gcn_info_vector(node_in_seq_indices, node,
                                                    size=sequence_output.size(), device=sequence_output.device)  

            gcn_updated_sequence_output = self._gcn_enc(self._gcn_prj_ln(sequence_output + gcn_info_vec))  

            # passage hidden and question hidden
            sequence_h2_weight = self._proj_sequence_h(gcn_updated_sequence_output).squeeze(-1)  
            passage_h2_weight = util.masked_softmax(sequence_h2_weight.float(), flat_passage_mask.float())  
            passage_h2 = util.weighted_sum(gcn_updated_sequence_output, passage_h2_weight)  
            question_h2_weight = util.masked_softmax(sequence_h2_weight.float(), flat_question_mask.float())
            question_h2 = util.weighted_sum(gcn_updated_sequence_output, question_h2_weight)  

            gcn_output_feats = torch.cat([passage_h2, question_h2, gcn_updated_sequence_output[:, 0]], dim=1)  
            gcn_logits = self._proj_span_num(gcn_output_feats)  


        if self.use_pool:
            ''' The baseline branch: classify from the pooled output. '''
            pooled_output = self.dropout(pooled_output)  
            baseline_logits = self.classifier(pooled_output)  


        if self.use_gcn and self.use_pool:
            ''' Merge gcn_logits & baseline_logits. TODO: different way of merging. '''

            if self.merge_type == 1:
                logits = gcn_logits + baseline_logits

            elif self.merge_type == 2:
                pooled_output = self.dropout(pooled_output)
                merged_feats = torch.cat([gcn_updated_sequence_output[:, 0], pooled_output], dim=1)  
                logits = self._proj_gcn_pool_3(merged_feats)  

            elif self.merge_type == 3:
                pooled_output = self.dropout(pooled_output)
                merged_feats = torch.cat([gcn_updated_sequence_output[:, 0], pooled_output,
                                          gcn_updated_sequence_output[:, 0], pooled_output], dim=1)  
                logits = self._proj_gcn_pool_4(merged_feats)  

            elif self.merge_type == 4:
                pooled_output = self.dropout(pooled_output)
                merged_feats = torch.cat([passage_h2, question_h2, pooled_output], dim=1)  
                logits = self._proj_gcn_pool(merged_feats)  

            elif self.merge_type == 5:
                pooled_output = self.dropout(pooled_output)
                merged_feats = torch.cat([passage_h2, question_h2, gcn_updated_sequence_output[:, 0], pooled_output],
                                         dim=1)  
                logits = self._proj_gcn_pool_4(merged_feats)  


        elif self.use_gcn:
            logits = gcn_logits
        elif self.use_pool:
            logits = baseline_logits
        else:
            raise ValueError("At least one of use_gcn or use_pool must be enabled.")


        reshaped_logits = logits.squeeze(-1).view(-1, num_choices)  
        outputs = (reshaped_logits, )

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs
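A minimal sketch (toy tensors, illustrative values) of how the (loss, reshaped_logits) tuple returned above is typically consumed for a multiple-choice batch, where reshaped_logits has shape (batch_size, num_choices).

import torch
from torch.nn import CrossEntropyLoss

batch_size, num_choices = 4, 4
reshaped_logits = torch.randn(batch_size, num_choices)  # stand-in for the model output
labels = torch.randint(0, num_choices, (batch_size,))   # index of the correct choice
loss = CrossEntropyLoss()(reshaped_logits, labels)
predictions = reshaped_logits.argmax(dim=-1)
accuracy = (predictions == labels).float().mean()
print(loss.item(), accuracy.item())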