Example #1
    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:

        # get the relevant scores for the time step
        class_log_probabilities = state['log_probs'][:, state['step_num'][0], :]
        is_wordpiece = (1 - state['wordpiece_mask'][:, state['step_num'][0]]).byte()

        # mask illegal BIO transitions
        transitions_mask = torch.cat(
            (torch.ones_like(class_log_probabilities[:, :3]),
             torch.zeros_like(class_log_probabilities[:, -2:])),
            dim=-1).byte()
        transitions_mask[:, 2] &= ((last_predictions == 1) |
                                   (last_predictions == 2))
        transitions_mask[:, 1:3] &= (
            (class_log_probabilities[:, :3] == 0.0).sum(-1) != 3
        ).unsqueeze(-1).repeat(1, 2)

        # assuming the wordpiece mask doesn't intersect with the other masks (pad, cls/sep)
        transitions_mask[:, 2] |= is_wordpiece & ((last_predictions == 1) |
                                                  (last_predictions == 2))

        class_log_probabilities = replace_masked_values(
            class_log_probabilities, transitions_mask, -1e7)

        state['step_num'] = state['step_num'].clone() + 1
        return class_log_probabilities, state
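
All of these examples lean on replace_masked_values, which keeps a tensor's entries where the mask is 1 and overwrites them where it is 0. A minimal stand-in consistent with how the AllenNLP-style (tensor, mask, replace_with) helper is used above; the function name here is illustrative:

    import torch

    def replace_masked_values_sketch(tensor, mask, replace_with):
        # Keep entries where mask == 1; overwrite entries where mask == 0.
        # The mask must be broadcastable to the tensor's shape.
        return tensor.masked_fill(~mask.bool(), replace_with)

    log_probs = torch.tensor([[-0.1, -2.3, -4.0]])
    mask = torch.tensor([[1, 1, 0]])
    print(replace_masked_values_sketch(log_probs, mask, -1e7))
    # tensor([[-1.0000e-01, -2.3000e+00, -1.0000e+07]])
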
Example #2
    def _get_combined_likelihood(self, answer_as_list_of_bios, log_probs):
        # answer_as_list_of_bios - Shape: (batch_size, # of correct sequences, seq_length)
        # log_probs - Shape: (batch_size, seq_length, 3)

        # Shape: (batch_size, # of correct sequences, seq_length, 3)
        # duplicate log_probs for each gold bios sequence
        expanded_log_probs = log_probs.unsqueeze(1).expand(
            -1, answer_as_list_of_bios.size(1), -1, -1)

        # get the log-likelihood per each sequence index
        # Shape: (batch_size, # of correct sequences, seq_length)
        log_likelihoods = \
            torch.gather(expanded_log_probs, dim=-1, index=answer_as_list_of_bios.unsqueeze(-1)).squeeze(-1)

        # Shape: (batch_size, # of correct sequences)
        correct_sequences_pad_mask = (answer_as_list_of_bios.sum(-1) > 0).long()

        # Sum the log-likelihoods for each index to get the log-likelihood of the sequence
        # Shape: (batch_size, # of correct sequences)
        sequences_log_likelihoods = log_likelihoods.sum(dim=-1)
        sequences_log_likelihoods = replace_masked_values(
            sequences_log_likelihoods, correct_sequences_pad_mask, -1e7)

        # Sum the log-likelihoods for each sequence to get the marginalized log-likelihood over the correct answers
        log_marginal_likelihood = logsumexp(sequences_log_likelihoods, dim=-1)

        return log_marginal_likelihood
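
The gather → sum → logsumexp chain is easiest to see on a toy batch. A self-contained sketch with plain torch standing in for the AllenNLP helpers (shapes follow the comments above):

    import torch

    # Batch of 1, two candidate gold BIO taggings over 4 tokens, 3 tags.
    log_probs = torch.log_softmax(torch.randn(1, 4, 3), dim=-1)
    gold = torch.tensor([[[0, 1, 2, 0],     # one correct tagging
                          [0, 0, 0, 0]]])   # all-zero row acts as padding

    expanded = log_probs.unsqueeze(1).expand(-1, gold.size(1), -1, -1)
    per_token = torch.gather(expanded, dim=-1,
                             index=gold.unsqueeze(-1)).squeeze(-1)
    per_seq = per_token.sum(dim=-1)                  # log P(sequence)
    pad_mask = gold.sum(-1) > 0
    per_seq = per_seq.masked_fill(~pad_mask, -1e7)   # drop padded rows
    marginal = torch.logsumexp(per_seq, dim=-1)      # Shape: (batch_size,)
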
Example #3
    def module(self, bert_out, seq_mask=None):
        logits = self.predictor(bert_out)

        if self._use_crf:
            # The mask should not be applied here when using CRF, but should be passed to the CRF
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        else:
            if seq_mask is not None:
                log_probs = replace_masked_values(
                    torch.nn.functional.log_softmax(logits, dim=-1),
                    seq_mask.unsqueeze(-1), 0.0)
                logits = replace_masked_values(logits, seq_mask.unsqueeze(-1),
                                               -1e7)
            else:
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        return log_probs, logits
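
Note the two different fill values: masked log-probabilities are set to 0.0 so they add nothing when summed into a sequence score, while masked logits are pushed to -1e7 so a downstream argmax or beam search never prefers a padded position. A small sketch of the same pattern, assuming a 0/1 sequence mask:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(1, 5, 3)               # (batch, seq_length, tags)
    seq_mask = torch.tensor([[1, 1, 1, 0, 0]])  # last two tokens are padding
    pad = seq_mask.unsqueeze(-1) == 0

    log_probs = F.log_softmax(logits, dim=-1).masked_fill(pad, 0.0)
    masked_logits = logits.masked_fill(pad, -1e7)
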
Example #4
    def log_likelihood(self, gold_labels, log_probs, seq_mask, is_bio_mask,
                       **kwargs):
        # we only want the log probabilities of the gold labels
        # what we get is:
        # log_likelihoods_for_multispan[i,j] == log_probs[i,j, gold_labels[i,j]]
        log_likelihoods_for_multispan = \
            torch.gather(log_probs, dim=-1, index=gold_labels.unsqueeze(-1)).squeeze(-1)

        # Our marginal likelihood is the sum of all the gold label likelihoods, ignoring the
        # padding tokens.
        log_likelihoods_for_multispan = \
            replace_masked_values(log_likelihoods_for_multispan, seq_mask, 0.0)

        log_marginal_likelihood_for_multispan = log_likelihoods_for_multispan.sum(
            dim=-1)

        # For questions without spans, we set their log probabilities to a very small negative value
        log_marginal_likelihood_for_multispan = \
            replace_masked_values(log_marginal_likelihood_for_multispan, is_bio_mask, -1e7)

        return log_marginal_likelihood_for_multispan
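
The gather call realizes exactly the indexing identity in the comment above. A quick check on random data (illustrative names only):

    import torch

    log_probs = torch.log_softmax(torch.randn(2, 4, 3), dim=-1)  # (batch, seq, tags)
    gold_labels = torch.tensor([[0, 1, 2, 0],
                                [1, 1, 0, 0]])                   # (batch, seq)

    picked = torch.gather(log_probs, dim=-1,
                          index=gold_labels.unsqueeze(-1)).squeeze(-1)
    assert torch.equal(picked[0, 2], log_probs[0, 2, gold_labels[0, 2]])
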
Example #5
    def prediction(self, log_probs, logits, qp_tokens, p_text, q_text,
                   seq_mask, wordpiece_mask, use_beam_search):

        if use_beam_search:
            top_k_predictions = self._get_top_k_sequences(
                log_probs.unsqueeze(0), wordpiece_mask.unsqueeze(0),
                self._prediction_beam_size)
            predicted_tags = top_k_predictions[0, 0, :]
        else:
            predicted_tags = torch.argmax(logits, dim=-1)
        predicted_tags = replace_masked_values(predicted_tags, seq_mask, 0)

        return MultiSpanHead.decode_spans_from_tags(predicted_tags, qp_tokens,
                                                    p_text, q_text)
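
decode_spans_from_tags itself is not shown in this listing; a hypothetical helper illustrating the underlying idea — contiguous B/I runs become (start, end) token spans, with the tag ids O=0, B=1, I=2 used throughout these examples:

    def bio_to_spans(tags):
        spans, start = [], None
        for i, t in enumerate(tags):
            if t == 1:                       # B starts a new span
                if start is not None:
                    spans.append((start, i - 1))
                start = i
            elif t == 2 and start is None:   # stray I: treat as a span start
                start = i
            elif t == 0 and start is not None:
                spans.append((start, i - 1))
                start = None
        if start is not None:
            spans.append((start, len(tags) - 1))
        return spans

    print(bio_to_spans([0, 1, 2, 0, 1, 0]))  # [(1, 2), (4, 4)]
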
Example #6
    def log_likelihood(self, gold_labels, log_probs, seq_mask, is_bio_mask,
                       **kwargs):
        logits = kwargs['logits']

        if gold_labels is not None:
            log_denominator = self.crf._input_likelihood(logits, seq_mask)
            log_numerator = self.crf._joint_likelihood(logits, gold_labels,
                                                       seq_mask)

            log_likelihood = log_numerator - log_denominator

            log_likelihood = replace_masked_values(log_likelihood, is_bio_mask,
                                                   -1e7)

            return log_likelihood
Example #7
    def forward(
            self,  # type: ignore
            input_ids: torch.LongTensor,
            input_mask: torch.LongTensor,
            input_segments: torch.LongTensor,
            passage_mask: torch.LongTensor,
            question_mask: torch.LongTensor,
            number_indices: torch.LongTensor,
            passage_number_order: torch.LongTensor,
            question_number_order: torch.LongTensor,
            question_number_indices: torch.LongTensor,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            answer_as_add_sub_expressions: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            span_num: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        # sequence_output, _, other_sequence_output = self.bert(input_ids, input_segments, input_mask)
        outputs = self.bert(input_ids,
                            attention_mask=input_mask,
                            token_type_ids=input_segments)
        sequence_output = outputs[0]
        sequence_output_list = [item for item in outputs[2][-4:]]

        batch_size = input_ids.size(0)
        if ("passage_span_extraction" in self.answering_abilities or
                "question_span" in self.answering_abilities) and self.use_gcn:
            # M2, M3
            sequence_alg = self._gcn_input_proj(
                torch.cat([sequence_output_list[2], sequence_output_list[3]],
                          dim=2))
            encoded_passage_for_numbers = sequence_alg
            encoded_question_for_numbers = sequence_alg
            # passage number extraction
            real_number_indices = number_indices - 1
            number_mask = (real_number_indices > -1).long()  # 1 for real numbers, 0 for padding
            clamped_number_indices = util.replace_masked_values(
                real_number_indices, number_mask, 0)
            encoded_numbers = torch.gather(
                encoded_passage_for_numbers, 1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, encoded_passage_for_numbers.size(-1)))

            # question number extraction
            question_number_mask = (question_number_indices > -1).long()
            clamped_question_number_indices = util.replace_masked_values(
                question_number_indices, question_number_mask, 0)
            question_encoded_number = torch.gather(
                encoded_question_for_numbers, 1,
                clamped_question_number_indices.unsqueeze(-1).expand(
                    -1, -1, encoded_question_for_numbers.size(-1)))

            # graph mask
            number_order = torch.cat(
                (passage_number_order, question_number_order), -1)
            new_graph_mask = number_order.unsqueeze(1).expand(
                batch_size, number_order.size(-1),
                -1) > number_order.unsqueeze(-1).expand(
                    batch_size, -1, number_order.size(-1))
            new_graph_mask = new_graph_mask.long()
            all_number_mask = torch.cat((number_mask, question_number_mask),
                                        dim=-1)
            new_graph_mask = all_number_mask.unsqueeze(
                1) * all_number_mask.unsqueeze(-1) * new_graph_mask

            # iteration
            d_node, q_node, d_node_weight, _ = self._gcn(
                d_node=encoded_numbers,
                q_node=question_encoded_number,
                d_node_mask=number_mask,
                q_node_mask=question_number_mask,
                graph=new_graph_mask)
            gcn_info_vec = torch.zeros((batch_size, sequence_alg.size(1) + 1,
                                        sequence_output_list[-1].size(-1)),
                                       dtype=torch.float,
                                       device=d_node.device)
            clamped_number_indices = util.replace_masked_values(
                real_number_indices, number_mask,
                gcn_info_vec.size(1) - 1)
            gcn_info_vec.scatter_(
                1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, d_node.size(-1)), d_node)
            gcn_info_vec = gcn_info_vec[:, :-1, :]

            sequence_output_list[2] = self._gcn_enc(
                self._proj_ln(sequence_output_list[2] + gcn_info_vec))
            sequence_output_list[0] = self._gcn_enc(
                self._proj_ln0(sequence_output_list[0] + gcn_info_vec))
            sequence_output_list[1] = self._gcn_enc(
                self._proj_ln1(sequence_output_list[1] + gcn_info_vec))
            sequence_output_list[3] = self._gcn_enc(
                self._proj_ln3(sequence_output_list[3] + gcn_info_vec))

        # passage hidden and question hidden
        sequence_h2_weight = self._proj_sequence_h(
            sequence_output_list[2]).squeeze(-1)
        passage_h2_weight = util.masked_softmax(sequence_h2_weight,
                                                passage_mask)
        passage_h2 = util.weighted_sum(sequence_output_list[2],
                                       passage_h2_weight)
        question_h2_weight = util.masked_softmax(sequence_h2_weight,
                                                 question_mask)
        question_h2 = util.weighted_sum(sequence_output_list[2],
                                        question_h2_weight)

        # passage g0, g1, g2
        question_g0_weight = self._proj_sequence_g0(
            sequence_output_list[0]).squeeze(-1)
        question_g0_weight = util.masked_softmax(question_g0_weight,
                                                 question_mask)
        question_g0 = util.weighted_sum(sequence_output_list[0],
                                        question_g0_weight)

        question_g1_weight = self._proj_sequence_g1(
            sequence_output_list[1]).squeeze(-1)
        question_g1_weight = util.masked_softmax(question_g1_weight,
                                                 question_mask)
        question_g1 = util.weighted_sum(sequence_output_list[1],
                                        question_g1_weight)

        question_g2_weight = self._proj_sequence_g2(
            sequence_output_list[2]).squeeze(-1)
        question_g2_weight = util.masked_softmax(question_g2_weight,
                                                 question_mask)
        question_g2 = util.weighted_sum(sequence_output_list[2],
                                        question_g2_weight)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = self._answer_ability_predictor(
                torch.cat([passage_h2, question_h2, sequence_output[:, 0]], 1))
            answer_ability_log_probs = F.log_softmax(answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        real_number_indices = number_indices.squeeze(-1) - 1
        number_mask = (real_number_indices > -1).long()
        clamped_number_indices = util.replace_masked_values(
            real_number_indices, number_mask, 0)
        encoded_passage_for_numbers = torch.cat(
            [sequence_output_list[2], sequence_output_list[3]], dim=-1)
        encoded_numbers = torch.gather(
            encoded_passage_for_numbers, 1,
            clamped_number_indices.unsqueeze(-1).expand(
                -1, -1, encoded_passage_for_numbers.size(-1)))
        number_weight = self._proj_number(encoded_numbers).squeeze(-1)
        number_mask = (number_indices > -1).long()
        number_weight = util.masked_softmax(number_weight, number_mask)
        number_vector = util.weighted_sum(encoded_numbers, number_weight)

        if "counting" in self.answering_abilities:
            # Shape: (batch_size, 10)
            count_number_logits = self._count_number_predictor(
                torch.cat(
                    [number_vector, passage_h2, question_h2, sequence_output[:, 0]],
                    dim=1))
            count_number_log_probs = torch.nn.functional.log_softmax(
                count_number_logits, -1)
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1)
            best_count_log_prob = torch.gather(
                count_number_log_probs, 1,
                best_count_number.unsqueeze(-1)).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[
                    :, self._counting_index]

        if "passage_span_extraction" in self.answering_abilities or "question_span_extraction" in self.answering_abilities:
            # start 0, 2
            sequence_for_span_start = torch.cat([
                sequence_output_list[2], sequence_output_list[0],
                sequence_output_list[2] * question_g2.unsqueeze(1),
                sequence_output_list[0] * question_g0.unsqueeze(1)
            ], dim=2)
            sequence_span_start_logits = self._span_start_predictor(
                sequence_for_span_start).squeeze(-1)
            # Shape: (batch_size, passage_length, modeling_dim * 4)
            sequence_for_span_end = torch.cat([
                sequence_output_list[2], sequence_output_list[1],
                sequence_output_list[2] * question_g2.unsqueeze(1),
                sequence_output_list[1] * question_g1.unsqueeze(1)
            ], dim=2)
            # Shape: (batch_size, passage_length)
            sequence_span_end_logits = self._span_end_predictor(
                sequence_for_span_end).squeeze(-1)

            # span number prediction
            span_num_logits = self._proj_span_num(
                torch.cat([passage_h2, question_h2, sequence_output[:, 0]],
                          dim=1))
            span_num_log_probs = torch.nn.functional.log_softmax(
                span_num_logits, -1)

            best_span_number = torch.argmax(span_num_log_probs, dim=-1)

            if "passage_span_extraction" in self.answering_abilities:
                passage_span_start_log_probs = util.masked_log_softmax(
                    sequence_span_start_logits, passage_mask)
                passage_span_end_log_probs = util.masked_log_softmax(
                    sequence_span_end_logits, passage_mask)

                # Info about the best passage span prediction
                passage_span_start_logits = util.replace_masked_values(
                    sequence_span_start_logits, passage_mask, -1e7)
                passage_span_end_logits = util.replace_masked_values(
                    sequence_span_end_logits, passage_mask, -1e7)
                # Shape: (batch_size, topk, 2)
                best_passage_span = get_best_span(passage_span_start_logits,
                                                  passage_span_end_logits)

            if "question_span_extraction" in self.answering_abilities:
                question_span_start_log_probs = util.masked_log_softmax(
                    sequence_span_start_logits, question_mask)
                question_span_end_log_probs = util.masked_log_softmax(
                    sequence_span_end_logits, question_mask)

                # Info about the best question span prediction
                question_span_start_logits = util.replace_masked_values(
                    sequence_span_start_logits, question_mask, -1e7)
                question_span_end_logits = util.replace_masked_values(
                    sequence_span_end_logits, question_mask, -1e7)
                # Shape: (batch_size, topk, 2)
                best_question_span = get_best_span(question_span_start_logits,
                                                   question_span_end_logits)

        if "addition_subtraction" in self.answering_abilities:
            alg_encoded_numbers = torch.cat([
                encoded_numbers,
                question_h2.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1),
                passage_h2.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1),
                sequence_output[:, 0].unsqueeze(1).repeat(
                    1, encoded_numbers.size(1), 1)
            ], 2)

            # Shape: (batch_size, # of numbers in the passage, 3)
            number_sign_logits = self._number_sign_predictor(
                alg_encoded_numbers)
            number_sign_log_probs = torch.nn.functional.log_softmax(
                number_sign_logits, -1)

            # Shape: (batch_size, # of numbers in passage).
            best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
            # For padding numbers, the best sign is masked as 0 (not included).
            best_signs_for_numbers = util.replace_masked_values(
                best_signs_for_numbers, number_mask, 0)
            # Shape: (batch_size, # of numbers in passage)
            best_signs_log_probs = torch.gather(
                number_sign_log_probs, 2,
                best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
            # the probs of the masked positions should be 1 so that they do not affect the joint probability
            # TODO: this is not quite right, since if there are many numbers in the passage,
            # TODO: the joint probability would be very small.
            best_signs_log_probs = util.replace_masked_values(
                best_signs_log_probs, number_mask, 0)
            # Shape: (batch_size,)
            best_combination_log_prob = best_signs_log_probs.sum(-1)
            if len(self.answering_abilities) > 1:
                best_combination_log_prob += answer_ability_log_probs[
                    :, self._addition_subtraction_index]

        output_dict = {}

        # If answer is given, compute the loss.
        if answer_as_passage_spans is not None or answer_as_question_spans is not None \
                or answer_as_add_sub_expressions is not None or answer_as_counts is not None:

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                    gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_passage_span_mask = (gold_passage_span_starts != -1).long()
                    clamped_gold_passage_span_starts = util.replace_masked_values(
                        gold_passage_span_starts, gold_passage_span_mask, 0)
                    clamped_gold_passage_span_ends = util.replace_masked_values(
                        gold_passage_span_ends, gold_passage_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_span_starts = torch.gather(
                        passage_span_start_log_probs, 1,
                        clamped_gold_passage_span_starts)
                    log_likelihood_for_passage_span_ends = torch.gather(
                        passage_span_end_log_probs, 1,
                        clamped_gold_passage_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_spans = log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
                    # For those padded spans, we set their log probabilities to a very small negative value
                    log_likelihood_for_passage_spans = util.replace_masked_values(
                        log_likelihood_for_passage_spans,
                        gold_passage_span_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_passage_span = util.logsumexp(
                        log_likelihood_for_passage_spans)

                    # span log probabilities
                    log_likelihood_for_passage_span_nums = torch.gather(
                        span_num_log_probs, 1, span_num)
                    log_likelihood_for_passage_span_nums = util.replace_masked_values(
                        log_likelihood_for_passage_span_nums,
                        gold_passage_span_mask[:, :1], -1e7)
                    log_marginal_likelihood_for_passage_span_nums = util.logsumexp(
                        log_likelihood_for_passage_span_nums)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_passage_span +
                        log_marginal_likelihood_for_passage_span_nums)

                elif answering_ability == "question_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_question_span_starts = answer_as_question_spans[:, :, 0]
                    gold_question_span_ends = answer_as_question_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_question_span_mask = (gold_question_span_starts != -1).long()
                    clamped_gold_question_span_starts = util.replace_masked_values(
                        gold_question_span_starts, gold_question_span_mask, 0)
                    clamped_gold_question_span_ends = util.replace_masked_values(
                        gold_question_span_ends, gold_question_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_span_starts = torch.gather(
                        question_span_start_log_probs, 1,
                        clamped_gold_question_span_starts)
                    log_likelihood_for_question_span_ends = torch.gather(
                        question_span_end_log_probs, 1,
                        clamped_gold_question_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_spans = log_likelihood_for_question_span_starts + log_likelihood_for_question_span_ends
                    # For those padded spans, we set their log probabilities to a very small negative value
                    log_likelihood_for_question_spans = util.replace_masked_values(
                        log_likelihood_for_question_spans,
                        gold_question_span_mask, -1e7)
                    # Shape: (batch_size, )
                    # pylint: disable=invalid-name
                    log_marginal_likelihood_for_question_span = util.logsumexp(
                        log_likelihood_for_question_spans)

                    # question multi span prediction
                    log_likelihood_for_question_span_nums = torch.gather(
                        span_num_log_probs, 1, span_num)
                    log_marginal_likelihood_for_question_span_nums = util.logsumexp(
                        log_likelihood_for_question_span_nums)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_question_span +
                        log_marginal_likelihood_for_question_span_nums)

                elif answering_ability == "addition_subtraction":
                    # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                    # Shape: (batch_size, # of combinations)
                    gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1) > 0).float()
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    gold_add_sub_signs = answer_as_add_sub_expressions.transpose(
                        1, 2)
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    log_likelihood_for_number_signs = torch.gather(
                        number_sign_log_probs, 2, gold_add_sub_signs)
                    # the log likelihood of the masked positions should be 0
                    # so that it will not affect the joint probability
                    log_likelihood_for_number_signs = util.replace_masked_values(
                        log_likelihood_for_number_signs,
                        number_mask.unsqueeze(-1), 0)
                    # Shape: (batch_size, # of combinations)
                    log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(
                        1)
                    # For those padded combinations, we set their log probabilities to a very small negative value
                    log_likelihood_for_add_subs = util.replace_masked_values(
                        log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_add_sub = util.logsumexp(
                        log_likelihood_for_add_subs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_add_sub)

                elif answering_ability == "counting":
                    # Count answers are padded with label -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    # Shape: (batch_size, # of count answers)
                    gold_count_mask = (answer_as_counts != -1).long()
                    # Shape: (batch_size, # of count answers)
                    clamped_gold_counts = util.replace_masked_values(
                        answer_as_counts, gold_count_mask, 0)
                    log_likelihood_for_counts = torch.gather(
                        count_number_log_probs, 1, clamped_gold_counts)
                    # For those padded spans, we set their log probabilities to a very small negative value
                    log_likelihood_for_counts = util.replace_masked_values(
                        log_likelihood_for_counts, gold_count_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_count = util.logsumexp(
                        log_likelihood_for_counts)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_count)

                else:
                    raise ValueError(
                        f"Unsupported answering ability: {answering_ability}")
            # print(log_marginal_likelihood_list)
            if len(self.answering_abilities) > 1:
                # Add the ability probabilities if there is more than one ability
                all_log_marginal_likelihoods = torch.stack(
                    log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
                marginal_log_likelihood = util.logsumexp(
                    all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]
            output_dict["loss"] = -marginal_log_likelihood.mean()

        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            for i in range(batch_size):
                if len(self.answering_abilities) > 1:
                    predicted_ability_str = self.answering_abilities[
                        best_answer_ability[i].detach().cpu().numpy()]
                else:
                    predicted_ability_str = self.answering_abilities[0]

                answer_json: Dict[str, Any] = {}

                question_start = 1
                passage_start = len(metadata[i]["question_tokens"]) + 2
                # We did not consider multi-mention answers here
                if predicted_ability_str == "passage_span_extraction":
                    answer_json["answer_type"] = "passage_span"
                    passage_str = metadata[i]['original_passage']
                    offsets = metadata[i]['passage_token_offsets']
                    predicted_answer, predicted_spans = best_answers_extraction(
                        best_passage_span[i], best_span_number[i], passage_str,
                        offsets, passage_start)
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = predicted_spans
                elif predicted_ability_str == "question_span_extraction":
                    answer_json["answer_type"] = "question_span"
                    question_str = metadata[i]['original_question']
                    offsets = metadata[i]['question_token_offsets']
                    predicted_answer, predicted_spans = best_answers_extraction(
                        best_question_span[i], best_span_number[i],
                        question_str, offsets, question_start)
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = predicted_spans
                elif predicted_ability_str == "addition_subtraction":
                    answer_json["answer_type"] = "arithmetic"
                    original_numbers = metadata[i]['original_numbers']
                    sign_remap = {0: 0, 1: 1, 2: -1}
                    predicted_signs = [
                        sign_remap[it] for it in
                        best_signs_for_numbers[i].detach().cpu().numpy()
                    ]
                    result = sum([
                        sign * number for sign, number in zip(
                            predicted_signs, original_numbers)
                    ])
                    predicted_answer = convert_number_to_str(result)
                    offsets = metadata[i]['passage_token_offsets']
                    number_indices = metadata[i]['number_indices']
                    number_positions = [
                        offsets[index - 1] for index in number_indices
                    ]
                    answer_json['numbers'] = []
                    for offset, value, sign in zip(number_positions,
                                                   original_numbers,
                                                   predicted_signs):
                        answer_json['numbers'].append({
                            'span': offset,
                            'value': value,
                            'sign': sign
                        })
                    if number_indices[-1] == -1:
                        # There is a dummy 0 number at position -1 added in some cases; we are
                        # removing that here.
                        answer_json["numbers"].pop()
                    answer_json["value"] = result
                    answer_json['number_sign_log_probs'] = \
                        number_sign_log_probs[i, :, :].detach().cpu().numpy()

                elif predicted_ability_str == "counting":
                    answer_json["answer_type"] = "count"
                    predicted_count = best_count_number[i].detach().cpu().numpy()
                    predicted_answer = str(predicted_count)
                    answer_json["count"] = predicted_count
                else:
                    raise ValueError(
                        f"Unsupported answer ability: {predicted_ability_str}")

                answer_json["predicted_answer"] = predicted_answer
                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(answer_json)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(predicted_answer, answer_annotations)

            if self.use_gcn:
                output_dict['clamped_number_indices'] = clamped_number_indices
                output_dict['node_weight'] = d_node_weight
        return output_dict
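
The loss above combines the per-ability marginals as a mixture: log P(answer) = logsumexp over abilities of [log P(ability) + log P(answer | ability)]. The same computation in miniature, with random numbers standing in for the real marginals:

    import torch

    batch_size, n_abilities = 2, 4
    answer_ability_log_probs = torch.log_softmax(
        torch.randn(batch_size, n_abilities), dim=-1)
    log_marginal_likelihood_list = [torch.randn(batch_size)
                                    for _ in range(n_abilities)]

    all_lml = torch.stack(log_marginal_likelihood_list, dim=-1)
    marginal_log_likelihood = torch.logsumexp(
        all_lml + answer_ability_log_probs, dim=-1)
    loss = -marginal_log_likelihood.mean()
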
Example #8
    def forward(self,
                d_node,
                q_node,
                d_node_mask,
                q_node_mask,
                graph,
                extra_factor=None):

        d_node_len = d_node.size(1)
        q_node_len = q_node.size(1)

        diagmat = torch.diagflat(
            torch.ones(d_node.size(1), dtype=torch.long, device=d_node.device))
        diagmat = diagmat.unsqueeze(0).expand(d_node.size(0), -1, -1)
        dd_graph = d_node_mask.unsqueeze(1) * d_node_mask.unsqueeze(-1) * (
            1 - diagmat)
        dd_graph_left = dd_graph * graph[:, :d_node_len, :d_node_len]
        dd_graph_right = dd_graph * (1 - graph[:, :d_node_len, :d_node_len])

        diagmat = torch.diagflat(
            torch.ones(q_node.size(1), dtype=torch.long, device=q_node.device))
        diagmat = diagmat.unsqueeze(0).expand(q_node.size(0), -1, -1)
        qq_graph = q_node_mask.unsqueeze(1) * q_node_mask.unsqueeze(-1) * (
            1 - diagmat)
        qq_graph_left = qq_graph * graph[:, d_node_len:, d_node_len:]
        qq_graph_right = qq_graph * (1 - graph[:, d_node_len:, d_node_len:])

        dq_graph = d_node_mask.unsqueeze(-1) * q_node_mask.unsqueeze(1)
        dq_graph_left = dq_graph * graph[:, :d_node_len, d_node_len:]
        dq_graph_right = dq_graph * (1 - graph[:, :d_node_len, d_node_len:])

        qd_graph = q_node_mask.unsqueeze(-1) * d_node_mask.unsqueeze(1)
        qd_graph_left = qd_graph * graph[:, d_node_len:, :d_node_len]
        qd_graph_right = qd_graph * (1 - graph[:, d_node_len:, :d_node_len])

        d_node_neighbor_num = dd_graph_left.sum(-1) + dd_graph_right.sum(
            -1) + dq_graph_left.sum(-1) + dq_graph_right.sum(-1)
        d_node_neighbor_num_mask = (d_node_neighbor_num >= 1).long()
        d_node_neighbor_num = util.replace_masked_values(
            d_node_neighbor_num.float(), d_node_neighbor_num_mask, 1)

        q_node_neighbor_num = qq_graph_left.sum(-1) + qq_graph_right.sum(
            -1) + qd_graph_left.sum(-1) + qd_graph_right.sum(-1)
        q_node_neighbor_num_mask = (q_node_neighbor_num >= 1).long()
        q_node_neighbor_num = util.replace_masked_values(
            q_node_neighbor_num.float(), q_node_neighbor_num_mask, 1)

        all_d_weight, all_q_weight = [], []
        for step in range(self.iteration_steps):
            if extra_factor is None:
                d_node_weight = torch.sigmoid(
                    self._node_weight_fc(d_node)).squeeze(-1)
                q_node_weight = torch.sigmoid(
                    self._node_weight_fc(q_node)).squeeze(-1)
            else:
                d_node_weight = torch.sigmoid(
                    self._node_weight_fc(
                        torch.cat((d_node, extra_factor), dim=-1))).squeeze(-1)
                q_node_weight = torch.sigmoid(
                    self._node_weight_fc(
                        torch.cat((q_node, extra_factor), dim=-1))).squeeze(-1)

            all_d_weight.append(d_node_weight)
            all_q_weight.append(q_node_weight)

            self_d_node_info = self._self_node_fc(d_node)
            self_q_node_info = self._self_node_fc(q_node)

            dd_node_info_left = self._dd_node_fc_left(d_node)
            qd_node_info_left = self._qd_node_fc_left(d_node)
            qq_node_info_left = self._qq_node_fc_left(q_node)
            dq_node_info_left = self._dq_node_fc_left(q_node)

            dd_node_weight = util.replace_masked_values(
                d_node_weight.unsqueeze(1).expand(-1, d_node_len, -1),
                dd_graph_left, 0)

            qd_node_weight = util.replace_masked_values(
                d_node_weight.unsqueeze(1).expand(-1, q_node_len, -1),
                qd_graph_left, 0)

            qq_node_weight = util.replace_masked_values(
                q_node_weight.unsqueeze(1).expand(-1, q_node_len, -1),
                qq_graph_left, 0)

            dq_node_weight = util.replace_masked_values(
                q_node_weight.unsqueeze(1).expand(-1, d_node_len, -1),
                dq_graph_left, 0)

            dd_node_info_left = torch.matmul(dd_node_weight, dd_node_info_left)
            qd_node_info_left = torch.matmul(qd_node_weight, qd_node_info_left)
            qq_node_info_left = torch.matmul(qq_node_weight, qq_node_info_left)
            dq_node_info_left = torch.matmul(dq_node_weight, dq_node_info_left)

            dd_node_info_right = self._dd_node_fc_right(d_node)
            qd_node_info_right = self._qd_node_fc_right(d_node)
            qq_node_info_right = self._qq_node_fc_right(q_node)
            dq_node_info_right = self._dq_node_fc_right(q_node)

            dd_node_weight = util.replace_masked_values(
                d_node_weight.unsqueeze(1).expand(-1, d_node_len, -1),
                dd_graph_right, 0)

            qd_node_weight = util.replace_masked_values(
                d_node_weight.unsqueeze(1).expand(-1, q_node_len, -1),
                qd_graph_right, 0)

            qq_node_weight = util.replace_masked_values(
                q_node_weight.unsqueeze(1).expand(-1, q_node_len, -1),
                qq_graph_right, 0)

            dq_node_weight = util.replace_masked_values(
                q_node_weight.unsqueeze(1).expand(-1, d_node_len, -1),
                dq_graph_right, 0)

            dd_node_info_right = torch.matmul(dd_node_weight,
                                              dd_node_info_right)
            qd_node_info_right = torch.matmul(qd_node_weight,
                                              qd_node_info_right)
            qq_node_info_right = torch.matmul(qq_node_weight,
                                              qq_node_info_right)
            dq_node_info_right = torch.matmul(dq_node_weight,
                                              dq_node_info_right)

            agg_d_node_info = (
                dd_node_info_left + dd_node_info_right + dq_node_info_left +
                dq_node_info_right) / d_node_neighbor_num.unsqueeze(-1)
            agg_q_node_info = (
                qq_node_info_left + qq_node_info_right + qd_node_info_left +
                qd_node_info_right) / q_node_neighbor_num.unsqueeze(-1)

            d_node = F.relu(self_d_node_info + agg_d_node_info)
            q_node = F.relu(self_q_node_info + agg_q_node_info)

        all_d_weight = [weight.unsqueeze(1) for weight in all_d_weight]
        all_q_weight = [weight.unsqueeze(1) for weight in all_q_weight]

        all_d_weight = torch.cat(all_d_weight, dim=1)
        all_q_weight = torch.cat(all_q_weight, dim=1)

        return d_node, q_node, all_d_weight, all_q_weight  # d_node_weight, q_node_weight
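
The graph passed in here is the pairwise number-order comparison built in Example #7; graph and 1 - graph split each relation into a ">" direction and a "<=" direction, and the diagonal term removes self-loops. A toy reconstruction of that edge construction (illustrative values):

    import torch

    values = torch.tensor([[3.0, 1.0, 2.0]])   # one batch of 3 numbers
    mask = torch.ones(1, 3, dtype=torch.long)

    # graph[b, j, k] == 1 iff values[b, k] > values[b, j]
    graph = (values.unsqueeze(1) > values.unsqueeze(-1)).long()
    diag = torch.eye(3, dtype=torch.long).unsqueeze(0)
    dd_graph = mask.unsqueeze(1) * mask.unsqueeze(-1) * (1 - diag)
    dd_left = dd_graph * graph          # ">" edges
    dd_right = dd_graph * (1 - graph)   # "<=" edges, self-loops removed
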
Example #9
    def forward(self,
                node,  
                node_mask,  
                argument_graph,  
                punctuation_graph,  
                extra_factor=None):
        '''
        Current: 2 relation patterns.
            - argument edges (most of them are causal relations)
            - punctuation edges (including periods and commas)
        '''

        node_len = node.size(1)

        diagmat = torch.diagflat(torch.ones(node.size(1), dtype=torch.long,
                                            device=node.device))  
        diagmat = diagmat.unsqueeze(0).expand(node.size(0), -1, -1)  
        dd_graph = node_mask.unsqueeze(1) * node_mask.unsqueeze(-1) * (1 - diagmat)  

        graph_argument = dd_graph * argument_graph
        graph_punctuation = dd_graph * punctuation_graph

        node_neighbor_num = graph_argument.sum(-1) + graph_punctuation.sum(-1)
        node_neighbor_num_mask = (node_neighbor_num >= 1).long()
        node_neighbor_num = util.replace_masked_values(node_neighbor_num.float(), node_neighbor_num_mask, 1)

        all_weight = []
        for step in range(self.iteration_steps):

            ''' (1) Node Relatedness Measure '''
            if extra_factor is None:
                d_node_weight = torch.sigmoid(
                    self._node_weight_fc(node)).squeeze(-1)
            else:
                d_node_weight = torch.sigmoid(
                    self._node_weight_fc(
                        torch.cat((node, extra_factor), dim=-1))).squeeze(-1)

            all_weight.append(d_node_weight)

            self_node_info = self._self_node_fc(node)

            ''' (2) Message Propagation (each relation type) '''
            
            node_info_argument = self._node_fc_argument(node)
            node_weight = util.replace_masked_values(
                d_node_weight.unsqueeze(1).expand(-1, node_len, -1),
                graph_argument,
                0)  
            node_info_argument = torch.matmul(node_weight, node_info_argument)

            
            node_info_punctuation = self._node_fc_punctuation(node)
            node_weight = util.replace_masked_values(
                d_node_weight.unsqueeze(1).expand(-1, node_len, -1),
                graph_punctuation,
                0)  
            node_info_punctuation = torch.matmul(node_weight, node_info_punctuation)

            agg_node_info = (node_info_argument + node_info_punctuation) / node_neighbor_num.unsqueeze(-1)

            ''' (3) Node Representation Update '''
            node = F.relu(self_node_info + agg_node_info)

        all_weight = [weight.unsqueeze(1) for weight in all_weight]
        all_weight = torch.cat(all_weight, dim=1)

        return node, all_weight
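
Each iteration gates neighbor features by a scalar relatedness score and averages by neighbor count. One propagation step in miniature (illustrative shapes and names; the projections stand in for the module's linear layers):

    import torch

    n_nodes, dim = 4, 8
    node_info = torch.randn(1, n_nodes, dim)      # projected node features
    node_weight = torch.rand(1, n_nodes)          # sigmoid relatedness scores
    adj = (torch.rand(1, n_nodes, n_nodes) > 0.5).float()  # toy relation graph

    # Zero the weights of non-neighbors, then aggregate and normalize.
    w = node_weight.unsqueeze(1).expand(-1, n_nodes, -1) * adj
    messages = torch.matmul(w, node_info)                   # (1, n_nodes, dim)
    neighbor_num = adj.sum(-1).clamp(min=1).unsqueeze(-1)
    agg_node_info = messages / neighbor_num
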
Example #10
    def log_likelihood(self, answer_as_text_to_disjoint_bios,
                       answer_as_list_of_bios, span_bio_labels, log_probs,
                       logits, seq_mask, wordpiece_mask, is_bio_mask,
                       **kwargs):
        # answer_as_text_to_disjoint_bios - Shape: (batch_size, # of text answers, # of spans a for text answer, seq_length)
        # answer_as_list_of_bios - Shape: (batch_size, # of correct sequences, seq_length)
        # log_probs - Shape: (batch_size, seq_length, 3)
        # seq_mask - Shape: (batch_size, seq_length)

        # Generate most likely correct predictions
        if self._use_crf:
            raise NotImplementedError
        else:
            with torch.no_grad():
                answer_as_list_of_bios = answer_as_list_of_bios * seq_mask.unsqueeze(1)
                if answer_as_text_to_disjoint_bios.sum() > 0:
                    full_bio = span_bio_labels

                    if self._generation_top_k > 0:
                        most_likely_predictions = self._get_top_k_sequences(
                            log_probs, wordpiece_mask, self._generation_top_k)

                        most_likely_predictions = (
                            most_likely_predictions * seq_mask.unsqueeze(1))

                        generated_list_of_bios = self._filter_correct_predictions(
                            most_likely_predictions,
                            answer_as_text_to_disjoint_bios, full_bio)

                        is_pregenerated_answer_format_mask = (
                            answer_as_list_of_bios.sum((1, 2)) >
                            0).unsqueeze(-1).unsqueeze(-1).long()
                        list_of_bios = torch.cat(
                            (answer_as_list_of_bios,
                             (generated_list_of_bios *
                              (1 - is_pregenerated_answer_format_mask))),
                            dim=1)

                        list_of_bios = self._add_full_bio(
                            list_of_bios, full_bio)
                    else:
                        is_pregenerated_answer_format_mask = (
                            answer_as_list_of_bios.sum((1, 2)) > 0).long()
                        list_of_bios = torch.cat(
                            (answer_as_list_of_bios,
                             (full_bio *
                              (1 - is_pregenerated_answer_format_mask
                               ).unsqueeze(-1)).unsqueeze(1)),
                            dim=1)
                else:
                    list_of_bios = answer_as_list_of_bios

        ### Calculate log-likelihood from list_of_bios
        if self._use_crf:
            raise NotImplementedError
        else:
            log_marginal_likelihood_for_multispan = self._get_combined_likelihood(
                list_of_bios, log_probs)

        # For questions without spans, we set their log probabilities to a very small negative value
        log_marginal_likelihood_for_multispan = \
            replace_masked_values(log_marginal_likelihood_for_multispan, is_bio_mask, -1e7)

        return log_marginal_likelihood_for_multispan
Example #11
    def prediction(self, log_probs, logits, qp_tokens, p_text, q_text, mask):
        predicted_tags = torch.argmax(logits, dim=-1)
        predicted_tags = replace_masked_values(predicted_tags, mask, 0)

        return MultiSpanHead.decode_spans_from_tags(predicted_tags, qp_tokens,
                                                    p_text, q_text)