Example #1
    def test_get_best_span(self):
        span_begin_probs = torch.FloatTensor([[0.1, 0.3, 0.05, 0.3,
                                               0.25]]).log()
        span_end_probs = torch.FloatTensor([[0.65, 0.05, 0.2, 0.05,
                                             0.05]]).log()
        begin_end_idxs = get_best_span(span_begin_probs, span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [[0, 0]])

        # When we were using exclusive span ends, this was an edge case of the dynamic program.
        # We're keeping the test to make sure we still get it right after the switch to inclusive
        # span ends.  The best answer is (1, 1).
        span_begin_probs = torch.FloatTensor([[0.4, 0.5, 0.1]]).log()
        span_end_probs = torch.FloatTensor([[0.3, 0.6, 0.1]]).log()
        begin_end_idxs = get_best_span(span_begin_probs, span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [[1, 1]])

        # Another instance that used to be an edge case.
        span_begin_probs = torch.FloatTensor([[0.8, 0.1, 0.1]]).log()
        span_end_probs = torch.FloatTensor([[0.8, 0.1, 0.1]]).log()
        begin_end_idxs = get_best_span(span_begin_probs, span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [[0, 0]])

        span_begin_probs = torch.FloatTensor([[0.1, 0.2, 0.05, 0.3,
                                               0.25]]).log()
        span_end_probs = torch.FloatTensor([[0.1, 0.2, 0.5, 0.05, 0.15]]).log()
        begin_end_idxs = get_best_span(span_begin_probs, span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [[1, 2]])
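The test above exercises the rule behind get_best_span: with inclusive span ends, the best span is the (start, end) pair with start <= end that maximizes the sum of the start and end scores. Below is a minimal, hedged re-implementation of that rule for reference only; it is not the library's actual code, just an O(passage_length) sketch (the helper name get_best_span_sketch is made up here) that reproduces the spans asserted above.

import torch

def get_best_span_sketch(span_start_logits: torch.Tensor,
                         span_end_logits: torch.Tensor) -> torch.Tensor:
    """Illustrative only: for each batch entry, return the (start, end) pair
    with start <= end that maximizes span_start_logits[start] + span_end_logits[end]."""
    batch_size, length = span_start_logits.size()
    best_spans = span_start_logits.new_zeros(batch_size, 2, dtype=torch.long)
    for b in range(batch_size):
        best_score = float("-inf")
        running_best_start = 0  # best start index seen up to the current end
        for end in range(length):
            if span_start_logits[b, end] > span_start_logits[b, running_best_start]:
                running_best_start = end
            score = span_start_logits[b, running_best_start] + span_end_logits[b, end]
            if score > best_score:
                best_score = score
                best_spans[b, 0] = running_best_start
                best_spans[b, 1] = end
    return best_spans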
Example #2
    def _passage_span_module(self, passage_out, passage_mask):
        # Shape: (batch_size, passage_length)
        passage_span_start_logits = self._passage_span_start_predictor(
            passage_out).squeeze(-1)

        # Shape: (batch_size, passage_length)
        passage_span_end_logits = self._passage_span_end_predictor(
            passage_out).squeeze(-1)

        # Shape: (batch_size, passage_length)
        passage_span_start_log_probs = util.masked_log_softmax(
            passage_span_start_logits, passage_mask)
        passage_span_end_log_probs = util.masked_log_softmax(
            passage_span_end_logits, passage_mask)

        # Info about the best passage span prediction
        passage_span_start_logits = util.replace_masked_values(
            passage_span_start_logits, passage_mask, -1e7)
        passage_span_end_logits = util.replace_masked_values(
            passage_span_end_logits, passage_mask, -1e7)

        # Shape: (batch_size, 2)
        best_passage_span = get_best_span(passage_span_start_logits,
                                          passage_span_end_logits)
        return passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span
Example #3
    def _question_span_module(self, passage_vector, question_out,
                              question_mask):
        # Shape: (batch_size, question_length)
        encoded_question_for_span_prediction = \
            torch.cat([question_out,
                       passage_vector.unsqueeze(1).repeat(1, question_out.size(1), 1)], -1)
        question_span_start_logits = \
            self._question_span_start_predictor(encoded_question_for_span_prediction).squeeze(-1)
        # Shape: (batch_size, question_length)
        question_span_end_logits = \
            self._question_span_end_predictor(encoded_question_for_span_prediction).squeeze(-1)
        question_span_start_log_probs = util.masked_log_softmax(
            question_span_start_logits, question_mask)
        question_span_end_log_probs = util.masked_log_softmax(
            question_span_end_logits, question_mask)

        # Info about the best question span prediction
        question_span_start_logits = \
            util.replace_masked_values(question_span_start_logits, question_mask, -1e7)
        question_span_end_logits = \
            util.replace_masked_values(question_span_end_logits, question_mask, -1e7)

        # Shape: (batch_size, 2)
        best_question_span = get_best_span(question_span_start_logits,
                                           question_span_end_logits)
        return question_span_start_log_probs, question_span_end_log_probs, best_question_span
Example #4
def ensemble(subresults: List[Dict[str, torch.Tensor]]) -> torch.Tensor:
    """
    Identifies the best prediction given the results from the submodels.

    Parameters
    ----------
    subresults : List[Dict[str, torch.Tensor]]
        Results of each submodel.

    Returns
    -------
    The best span for each instance, as a ``(batch_size, 2)`` tensor of token indices.
    """

    # Choose the highest average confidence span.

    span_start_probs = sum(subresult['span_start_probs']
                           for subresult in subresults) / len(subresults)
    span_end_probs = sum(subresult['span_end_probs']
                         for subresult in subresults) / len(subresults)
    return get_best_span(span_start_probs.log(),
                         span_end_probs.log())  # type: ignore
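A hedged usage sketch for the ensemble function above, assuming each subresult dictionary carries (batch_size, passage_length) probability tensors under the span_start_probs / span_end_probs keys; the numbers here are made up purely for illustration.

import torch

# Two hypothetical submodel outputs over a 4-token passage (batch_size = 1).
subresults = [
    {"span_start_probs": torch.tensor([[0.7, 0.1, 0.1, 0.1]]),
     "span_end_probs": torch.tensor([[0.1, 0.6, 0.2, 0.1]])},
    {"span_start_probs": torch.tensor([[0.6, 0.2, 0.1, 0.1]]),
     "span_end_probs": torch.tensor([[0.2, 0.5, 0.2, 0.1]])},
]

# The averaged probabilities are moved back to log space before the
# constrained argmax, mirroring how ensemble() calls get_best_span.
best_span = ensemble(subresults)  # expected: tensor([[0, 1]])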
Example #5
    def forward(
        self,
        passage_attention: torch.Tensor,
        passage_lengths: List[int],
        answer_as_passage_spans: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        batch_size, max_passage_length = passage_attention.size()
        passage_mask = passage_attention.new_zeros(batch_size, max_passage_length)
        for i, passage_length in enumerate(passage_lengths):
            passage_mask[i, 0:passage_length] = 1.0

        answer_as_passage_spans = answer_as_passage_spans.long()

        passage_attention = passage_attention * passage_mask

        if self._scaling:
            scaled_attentions = [passage_attention * sf for sf in self.scaling_vals]
            passage_attention_input = torch.stack(scaled_attentions, dim=2)
        else:
            passage_attention_input = passage_attention.unsqueeze(2)

        # Shape: (batch_size, passage_length, span_rnn_hsize)
        passage_span_logits_repr = self.passage_attention_to_span(passage_attention_input, passage_mask)

        # Shape: (batch_size, passage_length, 2)
        passage_span_logits = self.passage_startend_predictor(passage_span_logits_repr)

        # Shape: (batch_size, passage_length)
        span_start_logits = passage_span_logits[:, :, 0]
        span_end_logits = passage_span_logits[:, :, 1]

        span_start_logits = allenutil.replace_masked_values(span_start_logits, passage_mask, -1e32)
        span_end_logits = allenutil.replace_masked_values(span_end_logits, passage_mask, -1e32)

        span_start_log_probs = allenutil.masked_log_softmax(span_start_logits, passage_mask)
        span_end_log_probs = allenutil.masked_log_softmax(span_end_logits, passage_mask)

        span_start_log_probs = allenutil.replace_masked_values(span_start_log_probs, passage_mask, -1e32)
        span_end_log_probs = allenutil.replace_masked_values(span_end_log_probs, passage_mask, -1e32)

        # Loss computation
        batch_likelihood = 0
        output_dict = {}
        for i in range(batch_size):
            log_likelihood = self._get_span_answer_log_prob(
                answer_as_spans=answer_as_passage_spans[i],
                span_log_probs=(span_start_log_probs[i], span_end_log_probs[i]),
            )

            best_span = get_best_span(
                span_start_logits=span_start_log_probs[i].unsqueeze(0),
                span_end_logits=span_end_log_probs[i].unsqueeze(0),
            ).squeeze(0)

            correct_start, correct_end = False, False

            if best_span[0] == answer_as_passage_spans[i][0][0]:
                self.start_acc_metric(1)
                correct_start = True
            else:
                self.start_acc_metric(0)

            if best_span[1] == answer_as_passage_spans[i][0][1]:
                self.end_acc_metric(1)
                correct_end = True
            else:
                self.end_acc_metric(0)

            if correct_start and correct_end:
                self.span_acc_metric(1)
            else:
                self.span_acc_metric(0)

            batch_likelihood += log_likelihood

        loss = -1.0 * batch_likelihood

        batch_loss = loss / batch_size
        output_dict["loss"] = batch_loss

        return output_dict
Example #6
    def _get_best_spans(
        batch_denotations,
        batch_denotation_types,
        question_char_offsets,
        question_strs,
        passage_char_offsets,
        passage_strs,
        *args,
    ):
        """ For all SpanType denotations, get the best span

        Parameters:
        ----------
        batch_denotations: List[List[Any]]
        batch_denotation_types: List[List[str]]
        """

        (question_num_tokens, passage_num_tokens, question_mask_aslist, passage_mask_aslist) = args

        batch_best_spans = []
        batch_predicted_answers = []

        for instance_idx in range(len(batch_denotations)):
            instance_prog_denotations = batch_denotations[instance_idx]
            instance_prog_types = batch_denotation_types[instance_idx]

            instance_best_spans = []
            instance_predicted_ans = []

            for denotation, progtype in zip(instance_prog_denotations, instance_prog_types):
                # if progtype == "QuestionSpanAnswwer":
                # Distinction between QuestionSpanAnswer and PassageSpanAnswer is not needed currently,
                # since both classes store the start/end logits as a tuple
                # Shape: (2, )
                best_span = get_best_span(
                    span_start_logits=denotation._value[0].unsqueeze(0),
                    span_end_logits=denotation._value[1].unsqueeze(0),
                ).squeeze(0)
                instance_best_spans.append(best_span)

                predicted_span = tuple(best_span.detach().cpu().numpy())
                if progtype == "QuestionSpanAnswer":
                    try:
                        start_offset = question_char_offsets[instance_idx][predicted_span[0]][0]
                        end_offset = question_char_offsets[instance_idx][predicted_span[1]][1]
                        predicted_answer = question_strs[instance_idx][start_offset:end_offset]
                    except:
                        print()
                        print(f"PredictedSpan: {predicted_span}")
                        print(f"Question numtoksn: {question_num_tokens[instance_idx]}")
                        print(f"QuesMaskLen: {question_mask_aslist[instance_idx].size()}")
                        print(f"StartLogProbs:{denotation._value[0]}")
                        print(f"EndLogProbs:{denotation._value[1]}")
                        print(f"LenofOffsets: {len(question_char_offsets[instance_idx])}")
                        print(f"QuesStrLen: {len(question_strs[instance_idx])}")

                elif progtype == "PassageSpanAnswer":
                    try:
                        start_offset = passage_char_offsets[instance_idx][predicted_span[0]][0]
                        end_offset = passage_char_offsets[instance_idx][predicted_span[1]][1]
                        predicted_answer = passage_strs[instance_idx][start_offset:end_offset]
                    except:
                        print()
                        print(f"PredictedSpan: {predicted_span}")
                        print(f"Passagenumtoksn: {passage_num_tokens[instance_idx]}")
                        print(f"PassageMaskLen: {passage_mask_aslist[instance_idx].size()}")
                        print(f"LenofOffsets: {len(passage_char_offsets[instance_idx])}")
                        print(f"PassageStrLen: {len(passage_strs[instance_idx])}")
                else:
                    raise NotImplementedError

                instance_predicted_ans.append(predicted_answer)

            batch_best_spans.append(instance_best_spans)
            batch_predicted_answers.append(instance_predicted_ans)

        return batch_best_spans, batch_predicted_answers
Example #7
    def forward(self,  # type: ignore
                passage_question: Dict[str, torch.LongTensor],
#                passage: Dict[str, torch.LongTensor],
                number_indices: torch.LongTensor,
                answer_type = None,
#                answer_as_passage_spans: torch.LongTensor = None,
                answer_as_spans: torch.LongTensor = None,
                answer_as_add_sub_expressions: torch.LongTensor = None,
                answer_as_counts: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:


        passage_question_mask = passage_question["mask"].float()
        # Convert token ids to vectors and apply dropout.
        embedded_passage_question = self._dropout(self._text_field_embedder(passage_question))

        batch_size = embedded_passage_question.size(0)

        # Added by bzw: use the BERT output directly as the encoded representation.
        encoded_passage_question = embedded_passage_question

        # Use the [CLS] token representation in place of the passage vector.
        passage_question_vector = encoded_passage_question[:, 0]

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._answer_ability_predictor(passage_question_vector)
            answer_ability_log_probs = torch.nn.functional.log_softmax(answer_ability_logits, -1)
            #best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            # Shape: (batch_size, 10)
            count_number_logits = self._count_number_predictor(passage_question_vector)
            count_number_log_probs = torch.nn.functional.log_softmax(count_number_logits, -1)
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1)
            best_count_log_prob = \
                torch.gather(count_number_log_probs, 1, best_count_number.unsqueeze(-1)).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[:, self._counting_index]

        if "span_extraction" in self.answering_abilities:
            # Shape: (batch_size, passage_length)
            span_start_logits = self._span_start_predictor(encoded_passage_question).squeeze(-1)
            # Shape: (batch_size, passage_length)
            span_end_logits = self._span_end_predictor(encoded_passage_question).squeeze(-1)
            # Shape: (batch_size, passage_length)
            span_start_log_probs = util.masked_log_softmax(span_start_logits, passage_question_mask)
            span_end_log_probs = util.masked_log_softmax(span_end_logits, passage_question_mask)

            # Info about the best passage span prediction
            span_start_logits = util.replace_masked_values(span_start_logits, passage_question_mask, -1e7)  # replace masked positions with -1e7
            span_end_logits = util.replace_masked_values(span_end_logits, passage_question_mask, -1e7)
            # Shape: (batch_size, 2)
            best_span = get_best_span(span_start_logits, span_end_logits)
            # Shape: (batch_size, 2)
            best_start_log_probs = \
                torch.gather(span_start_log_probs, 1, best_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_end_log_probs = \
                torch.gather(span_end_log_probs, 1, best_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_span_log_prob = best_start_log_probs + best_end_log_probs
            if len(self.answering_abilities) > 1:
                best_span_log_prob += answer_ability_log_probs[:, self._span_extraction_index]


        if "addition_subtraction" in self.answering_abilities:
            # Shape: (batch_size, # of numbers in the passage)
            number_indices = number_indices.squeeze(-1)
            number_mask = (number_indices != -1).long()

            clamped_number_indices = util.replace_masked_values(number_indices, number_mask, 0)
            # Shape: (batch_size, # of numbers in the passage, encoding_dim)
            encoded_numbers = torch.gather(
                    encoded_passage_question,
                    1,
                    clamped_number_indices.unsqueeze(-1).expand(-1, -1, encoded_passage_question.size(-1)))

            #encoded_numbers = self.self_attention(encoded_numbers, number_mask)
            encoded_numbers = self.Concat_attention(encoded_numbers, passage_question_vector, number_mask)
            # Shape: (batch_size, # of numbers in the passage)
            #encoded_numbers = torch.cat(
            #        [encoded_numbers, passage_question_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)

            # Shape: (batch_size, # of numbers in the passage, 3)
            number_sign_logits = self._number_sign_predictor(encoded_numbers)
            number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)

            # Shape: (batch_size, # of numbers in passage).
            best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
            # For padding numbers, the best sign is masked to 0 (not included).
            best_signs_for_numbers = util.replace_masked_values(best_signs_for_numbers, number_mask, 0)
            # Shape: (batch_size, # of numbers in passage)
            best_signs_log_probs = torch.gather(
                    number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
            # the probs of the masked positions should be 1 so that it will not affect the joint probability
            # TODO: this is not quite right, since if there are many numbers in the passage,
            # TODO: the joint probability would be very small.
            best_signs_log_probs = util.replace_masked_values(best_signs_log_probs, number_mask, 0)
            # Shape: (batch_size,)
            if len(self.answering_abilities) > 1:
                # batch_size
                best_combination_log_prob = best_signs_log_probs.sum(-1)
                best_combination_log_prob += answer_ability_log_probs[:, self._addition_subtraction_index]

        best_answer_ability = torch.argmax(
            torch.stack([best_span_log_prob, best_combination_log_prob, best_count_log_prob], -1), 1)


        output_dict = {}

        # If answer is given, compute the loss.
        if answer_as_spans is not None or answer_as_add_sub_expressions is not None or answer_as_counts is not None:

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_span_starts = answer_as_spans[:, :, 0]
                    gold_span_ends = answer_as_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_span_mask = (gold_span_starts != -1).long()
                    clamped_gold_span_starts = \
                        util.replace_masked_values(gold_span_starts, gold_span_mask, 0)
                    clamped_gold_span_ends = \
                        util.replace_masked_values(gold_span_ends, gold_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_span_starts = \
                        torch.gather(span_start_log_probs, 1, clamped_gold_span_starts)
                    log_likelihood_for_span_ends = \
                        torch.gather(span_end_log_probs, 1, clamped_gold_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_spans = \
                        log_likelihood_for_span_starts + log_likelihood_for_span_ends
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_spans = \
                        util.replace_masked_values(log_likelihood_for_spans, gold_span_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_span = util.logsumexp(log_likelihood_for_spans)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_span)

                elif answering_ability == "addition_subtraction":
                    # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                    # Shape: (batch_size, # of combinations)
                    gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1) > 0).float()
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    gold_add_sub_signs = answer_as_add_sub_expressions.transpose(1, 2)
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    log_likelihood_for_number_signs = torch.gather(number_sign_log_probs, 2, gold_add_sub_signs)
                    # the log likelihood of the masked positions should be 0
                    # so that it will not affect the joint probability
                    log_likelihood_for_number_signs = \
                        util.replace_masked_values(log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0)
                    # Shape: (batch_size, # of combinations)
                    log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(1)
                    # For those padded combinations, we set their log probabilities to be very small negative value
                    log_likelihood_for_add_subs = \
                        util.replace_masked_values(log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_add_sub = util.logsumexp(log_likelihood_for_add_subs)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_add_sub)

                elif answering_ability == "counting":
                    # Count answers are padded with label -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    # Shape: (batch_size, # of count answers)
                    gold_count_mask = (answer_as_counts != -1).long()
                    # Shape: (batch_size, # of count answers)
                    clamped_gold_counts = util.replace_masked_values(answer_as_counts, gold_count_mask, 0)
                    log_likelihood_for_counts = torch.gather(count_number_log_probs, 1, clamped_gold_counts)
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_counts = \
                        util.replace_masked_values(log_likelihood_for_counts, gold_count_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_count = util.logsumexp(log_likelihood_for_counts)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_count)

                else:
                    raise ValueError(f"Unsupported answering ability: {answering_ability}")

            if len(self.answering_abilities) > 1:
                # Add the ability probabilities if there is more than one ability
                all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
                loss_for_type = -(torch.sum(answer_ability_log_probs*answer_type,-1)).mean()
                loss_for_answer = -(torch.sum(all_log_marginal_likelihoods,-1)).mean()
                loss = loss_for_type+loss_for_answer
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]
                loss =  - marginal_log_likelihood.mean()
            output_dict["loss"] = loss

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            passage_question_tokens = []
            for i in range(batch_size):
                passage_question_tokens.append(metadata[i]['passage_question_tokens'])

                if len(self.answering_abilities) > 1:
                    predicted_ability_str = self.answering_abilities[best_answer_ability[i].detach().cpu().numpy()]
                else:
                    predicted_ability_str = self.answering_abilities[0]

                answer_json: Dict[str, Any] = {}

                # We did not consider multi-mention answers here
                if predicted_ability_str == "span_extraction":
                    answer_json["answer_type"] = "span"
                    passage_question_token = metadata[i]['passage_question_tokens']
                    #offsets = metadata[i]['passage_token_offsets']
                    predicted_span = tuple(best_span[i].detach().cpu().numpy())
                    start_offset = predicted_span[0]
                    end_offset = predicted_span[1]
                    predicted_answer = " ".join([token for token in passage_question_token[start_offset:end_offset+1] if token!="[SEP]"]).replace(" ##","")
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = [(start_offset, end_offset)]
                elif predicted_ability_str == "counting":
                    answer_json["answer_type"] = "count"
                    predicted_count = best_count_number[i].detach().cpu().numpy()
                    predicted_answer = str(predicted_count)
                    answer_json["count"] = predicted_count
                elif predicted_ability_str == "addition_subtraction":
                    answer_json["answer_type"] = "arithmetic"
                    original_numbers = metadata[i]['original_numbers']
                    sign_remap = {0: 0, 1: 1, 2: -1}
                    predicted_signs = [sign_remap[it] for it in best_signs_for_numbers[i].detach().cpu().numpy()]
                    result = 0
                    for j, number in enumerate(original_numbers):
                        sign = predicted_signs[j]
                        if sign != 0:
                            result += sign * number

                    predicted_answer = str(result)
                    #offsets = metadata[i]['passage_token_offsets']
                    number_indices = metadata[i]['number_indices']
                    #number_positions = [offsets[index] for index in number_indices]
                    answer_json['numbers'] = []
                    for indice, value, sign in zip(number_indices, original_numbers, predicted_signs):
                        answer_json['numbers'].append({'span': indice, 'value': str(value), 'sign': sign})
                    if number_indices[-1] == -1:
                        # There is a dummy 0 number at position -1 added in some cases; we are
                        # removing that here.
                        answer_json["numbers"].pop()
                    answer_json["value"] = str(result)
                else:
                    raise ValueError(f"Unsupported answer ability: {predicted_ability_str}")

                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(answer_json)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(predicted_answer, answer_annotations)
            # This is used for the demo.
            #output_dict["passage_question_attention"] = passage_question_attention
            output_dict["passage_question_tokens"] = passage_question_tokens
            #output_dict["passage_tokens"] = passage_tokens
        return output_dict
Example #8
    def forward(self,
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                passageidx2numberidx: torch.LongTensor,
                passage_number_values: List[int],
                passageidx2dateidx: torch.LongTensor,
                passage_date_values: List[List[Date]],
                actions: List[List[ProductionRule]],
                datecomp_ques_event_date_groundings: List[Tuple[List[int], List[int]]] = None,
                numcomp_qspan_num_groundings: List[Tuple[List[int], List[int]]] = None,
                strongly_supervised: List[bool] = None,
                qtypes: List[str] = None,
                qattn_supervision: torch.FloatTensor = None,
                answer_types: List[str] = None,
                answer_as_passage_spans: torch.LongTensor = None,
                answer_as_question_spans: torch.LongTensor = None,
                epoch_num: List[int] = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        batch_size = len(actions)

        if epoch_num is not None:
            # epoch_num in allennlp starts from 0
            epoch = epoch_num[0] + 1
        else:
            epoch = None

        question_mask = allenutil.get_text_field_mask(question).float()
        passage_mask = allenutil.get_text_field_mask(passage).float()
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

        projected_embedded_question = self._encoding_proj_layer(embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

        encoded_question = self._dropout(self._phrase_layer(projected_embedded_question, question_mask))
        encoded_passage = self._dropout(self._phrase_layer(projected_embedded_passage, passage_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = allenutil.masked_softmax(
            passage_question_similarity,
            question_mask,
            memory_efficient=True)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = allenutil.weighted_sum(encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = allenutil.masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = allenutil.weighted_sum(encoded_passage, attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat([encoded_passage, passage_question_vectors,
                       encoded_passage * passage_question_vectors,
                       encoded_passage * passage_passage_vectors],
                      dim=-1)
        )

        modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]

        for _ in range(3):
            modeled_passage = self._dropout(self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)

        # Shape: (batch_size, passage_length, modeling_dim * 2))
        span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = allenutil.replace_masked_values(span_start_logits, passage_mask, -1e32)
        span_end_logits = allenutil.replace_masked_values(span_end_logits, passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        span_start_logprob = allenutil.masked_log_softmax(span_start_logits, mask=passage_mask, dim=-1)
        span_end_logprob = allenutil.masked_log_softmax(span_end_logits, mask=passage_mask, dim=-1)
        span_start_logprob = allenutil.replace_masked_values(span_start_logprob, passage_mask, -1e32)
        span_end_logprob = allenutil.replace_masked_values(span_end_logprob, passage_mask, -1e32)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }


        if answer_types is not None:
            loss = 0
            for i in range(batch_size):
                loss += self._get_span_answer_log_prob(answer_as_spans=answer_as_passage_spans[i],
                                                       span_log_probs=(span_start_logprob[i], span_end_logprob[i]))

            loss = (-1.0 * loss) / batch_size

            self.modelloss_metric(myutils.tocpuNPList(loss)[0])
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['passage_token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_annotations = metadata[i].get('answer_annotation')
                self._drop_metrics(best_span_string, [answer_annotations])

        output_dict.update({'metadata': metadata})

        return output_dict
Example #9
    def forward(
            self,  # type: ignore
            bert_input: Dict[str, torch.LongTensor],
            sim_bert_input: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """

        if self.use_scenario_encoding:
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_bert_input_token_labels_wp = sim_bert_input[
                'scenario_gold_encoding']
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_bert_output_wp = self._sim_text_field_embedder(sim_bert_input)
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_input_mask_wp = (sim_bert_input['bert'] != 0).float()
            # Shape: (batch_size, sim_bert_input_len_wp)
            sim_passage_mask_wp = sim_input_mask_wp - sim_bert_input[
                'bert-type-ids'].float()  # works only with one [SEP]
            # Shape: (batch_size, sim_bert_input_len_wp, embedding_dim)
            sim_passage_representation_wp = sim_bert_output_wp * sim_passage_mask_wp.unsqueeze(
                2)
            # Shape: (batch_size, passage_len_wp, embedding_dim)
            sim_passage_representation_wp = \
                sim_passage_representation_wp[:, sim_passage_mask_wp.sum(dim=0) > 0, :]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_token_labels_wp = \
                sim_bert_input_token_labels_wp[:, sim_passage_mask_wp.sum(dim=0) > 0]
            # Shape: (batch_size, passage_len_wp)
            sim_passage_mask_wp = sim_passage_mask_wp[:, sim_passage_mask_wp.sum(dim=0) > 0]

            # Shape: (batch_size, passage_len_wp, 4)
            sim_token_logits_wp = self._sim_token_label_predictor(
                sim_passage_representation_wp)

            if span_start is not None:  # during training and validation
                class_weights = torch.tensor(self.sim_class_weights,
                                             device=sim_token_logits_wp.device,
                                             dtype=torch.float)
                sim_loss = cross_entropy(sim_token_logits_wp.view(-1, 4),
                                         sim_passage_token_labels_wp.view(-1),
                                         ignore_index=0,
                                         weight=class_weights)
                self._sim_loss_metric(sim_loss.item())
                self._sim_yes_f1(sim_token_logits_wp,
                                 sim_passage_token_labels_wp,
                                 sim_passage_mask_wp)
                self._sim_no_f1(sim_token_logits_wp,
                                sim_passage_token_labels_wp,
                                sim_passage_mask_wp)
                if self.sim_pretraining:
                    return {'loss': sim_loss}

            if not self.sim_pretraining:
                # Shape: (batch_size, passage_len_wp)
                bert_input['scenario_encoding'] = (sim_token_logits_wp.argmax(
                    dim=2)) * sim_passage_mask_wp.long()
                # Shape: (batch_size, bert_input_len_wp)
                bert_input_wp_len = bert_input['history_encoding'].size(1)
                if bert_input['scenario_encoding'].size(1) > bert_input_wp_len:
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = bert_input[
                        'scenario_encoding'][:, :bert_input_wp_len]
                else:
                    batch_size = bert_input['scenario_encoding'].size(0)
                    difference = bert_input_wp_len - bert_input[
                        'scenario_encoding'].size(1)
                    zeros = torch.zeros(
                        batch_size,
                        difference,
                        dtype=bert_input['scenario_encoding'].dtype,
                        device=bert_input['scenario_encoding'].device)
                    # Shape: (batch_size, bert_input_len_wp)
                    bert_input['scenario_encoding'] = torch.cat(
                        [bert_input['scenario_encoding'], zeros], dim=1)

        # Shape: (batch_size, bert_input_len + 1, embedding_dim)
        bert_output = self._text_field_embedder(bert_input)
        # Shape: (batch_size, embedding_dim)
        pooled_output = bert_output[:, 0]
        # Shape: (batch_size, bert_input_len, embedding_dim)
        bert_output = bert_output[:, 1:, :]
        # Shape: (batch_size, passage_len, embedding_dim), (batch_size, passage_len)
        passage_representation, passage_mask = self.get_passage_representation(
            bert_output, bert_input)

        # Shape: (batch_size, 4)
        action_logits = self._action_predictor(pooled_output)
        # Shape: (batch_size, passage_len, 2)
        span_logits = self._span_predictor(passage_representation)
        # Shape: (batch_size, passage_len, 1), (batch_size, passage_len, 1)
        span_start_logits, span_end_logits = span_logits.split(1, dim=2)
        # Shape: (batch_size, passage_len)
        span_start_logits = span_start_logits.squeeze(2)
        # Shape: (batch_size, passage_len)
        span_end_logits = span_end_logits.squeeze(2)

        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "pooled_output": pooled_output,
            "passage_representation": passage_representation,
            "action_logits": action_logits,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        if self.use_scenario_encoding:
            output_dict["sim_token_logits"] = sim_token_logits_wp

        # Compute the loss for training (and for validation)
        if span_start is not None:
            # Shape: (batch_size,)
            span_loss = nll_loss(util.masked_log_softmax(
                span_start_logits, passage_mask),
                                 span_start.squeeze(1),
                                 reduction='none')
            # Shape: (batch_size,)
            span_loss += nll_loss(util.masked_log_softmax(
                span_end_logits, passage_mask),
                                  span_end.squeeze(1),
                                  reduction='none')
            # Shape: (batch_size,)
            more_mask = (label == self.vocab.get_token_index(
                'More', namespace="labels")).float()
            # Shape: (batch_size,)
            span_loss = (span_loss * more_mask).sum() / (more_mask.sum() +
                                                         1e-6)
            if more_mask.sum() > 1e-7:
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(1), more_mask)
                self._span_end_accuracy(span_end_logits, span_end.squeeze(1),
                                        more_mask)
                # Shape: (batch_size, 2)
                span_acc_mask = more_mask.unsqueeze(1).expand(-1, 2).long()
                self._span_accuracy(best_span,
                                    torch.cat([span_start, span_end], dim=1),
                                    span_acc_mask)

            action_loss = cross_entropy(action_logits, label)
            self._action_accuracy(action_logits, label)

            self._span_loss_metric(span_loss.item())
            self._action_loss_metric(action_loss.item())
            output_dict['loss'] = self.loss_weights[
                'span_loss'] * span_loss + self.loss_weights[
                    'action_loss'] * action_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if not self.training:  # true during validation and test
            output_dict['best_span_str'] = []
            batch_size = len(metadata)
            for i in range(batch_size):
                passage_text = metadata[i]['passage_text']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_str = passage_text[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_str)
                if 'gold_span' in metadata[i]:
                    if metadata[i]['action'] == 'More':
                        gold_span = metadata[i]['gold_span']
                        self._squad_metrics(best_span_str, [gold_span])
        return output_dict
Example #10
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                numbers_in_passage: Dict[str, torch.LongTensor],
                number_indices: torch.LongTensor,
                answer_as_passage_spans: torch.LongTensor = None,
                answer_as_question_spans: torch.LongTensor = None,
                answer_as_add_sub_expressions: torch.LongTensor = None,
                answer_as_counts: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ, unused-argument

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

        encoded_question = self._dropout(self._phrase_layer(projected_embedded_question, question_mask))
        encoded_passage = self._dropout(self._phrase_layer(projected_embedded_passage, passage_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(passage_question_similarity,
                                                    question_mask,
                                                    memory_efficient=True)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(passage_question_similarity.transpose(1, 2),
                                                    passage_mask,
                                                    memory_efficient=True)

        # Shape: (batch_size, passage_length, passage_length)
        passage_attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage, passage_attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
                torch.cat([encoded_passage, passage_question_vectors,
                           encoded_passage * passage_question_vectors,
                           encoded_passage * passage_passage_vectors],
                          dim=-1))

        # The recurrent modeling layers. Since these layers share the same parameters,
        # we don't construct them conditioned on answering abilities.
        modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]
        for _ in range(4):
            modeled_passage = self._dropout(self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)
        # Pop the first one, which is input
        modeled_passage_list.pop(0)

        # The first modeling layer is used to calculate the vector representation of passage
        passage_weights = self._passage_weights_predictor(modeled_passage_list[0]).squeeze(-1)
        passage_weights = masked_softmax(passage_weights, passage_mask)
        passage_vector = util.weighted_sum(modeled_passage_list[0], passage_weights)
        # The vector representation of question is calculated based on the unmatched encoding,
        # because we may want to infer the answer ability only based on the question words.
        question_weights = self._question_weights_predictor(encoded_question).squeeze(-1)
        question_weights = masked_softmax(question_weights, question_mask)
        question_vector = util.weighted_sum(encoded_question, question_weights)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1))
            answer_ability_log_probs = torch.nn.functional.log_softmax(answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            # Shape: (batch_size, 10)
            count_number_logits = self._count_number_predictor(passage_vector)
            count_number_log_probs = torch.nn.functional.log_softmax(count_number_logits, -1)
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1)
            best_count_log_prob = \
                torch.gather(count_number_log_probs, 1, best_count_number.unsqueeze(-1)).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[:, self._counting_index]

        if "passage_span_extraction" in self.answering_abilities:
            # Shape: (batch_size, passage_length, modeling_dim * 2))
            passage_for_span_start = torch.cat([modeled_passage_list[0], modeled_passage_list[1]], dim=-1)
            # Shape: (batch_size, passage_length)
            passage_span_start_logits = self._passage_span_start_predictor(passage_for_span_start).squeeze(-1)
            # Shape: (batch_size, passage_length, modeling_dim * 2)
            passage_for_span_end = torch.cat([modeled_passage_list[0], modeled_passage_list[2]], dim=-1)
            # Shape: (batch_size, passage_length)
            passage_span_end_logits = self._passage_span_end_predictor(passage_for_span_end).squeeze(-1)
            # Shape: (batch_size, passage_length)
            passage_span_start_log_probs = util.masked_log_softmax(passage_span_start_logits, passage_mask)
            passage_span_end_log_probs = util.masked_log_softmax(passage_span_end_logits, passage_mask)

            # Info about the best passage span prediction
            passage_span_start_logits = util.replace_masked_values(passage_span_start_logits, passage_mask, -1e7)
            passage_span_end_logits = util.replace_masked_values(passage_span_end_logits, passage_mask, -1e7)
            # Shape: (batch_size, 2)
            best_passage_span = get_best_span(passage_span_start_logits, passage_span_end_logits)
            # Shape: (batch_size, 2)
            best_passage_start_log_probs = \
                torch.gather(passage_span_start_log_probs, 1, best_passage_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_passage_end_log_probs = \
                torch.gather(passage_span_end_log_probs, 1, best_passage_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_passage_span_log_prob = best_passage_start_log_probs + best_passage_end_log_probs
            if len(self.answering_abilities) > 1:
                best_passage_span_log_prob += answer_ability_log_probs[:, self._passage_span_extraction_index]

        if "question_span_extraction" in self.answering_abilities:
            encoded_question_for_span_prediction = \
                torch.cat([encoded_question,
                           passage_vector.unsqueeze(1).repeat(1, encoded_question.size(1), 1)], -1)
            # Shape: (batch_size, question_length)
            question_span_start_logits = \
                self._question_span_start_predictor(encoded_question_for_span_prediction).squeeze(-1)
            # Shape: (batch_size, question_length)
            question_span_end_logits = \
                self._question_span_end_predictor(encoded_question_for_span_prediction).squeeze(-1)
            question_span_start_log_probs = util.masked_log_softmax(question_span_start_logits, question_mask)
            question_span_end_log_probs = util.masked_log_softmax(question_span_end_logits, question_mask)

            # Info about the best question span prediction
            question_span_start_logits = \
                util.replace_masked_values(question_span_start_logits, question_mask, -1e7)
            question_span_end_logits = \
                util.replace_masked_values(question_span_end_logits, question_mask, -1e7)
            # Shape: (batch_size, 2)
            best_question_span = get_best_span(question_span_start_logits, question_span_end_logits)
            # Shape: (batch_size,)
            best_question_start_log_probs = \
                torch.gather(question_span_start_log_probs, 1, best_question_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_question_end_log_probs = \
                torch.gather(question_span_end_log_probs, 1, best_question_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_question_span_log_prob = best_question_start_log_probs + best_question_end_log_probs
            if len(self.answering_abilities) > 1:
                best_question_span_log_prob += answer_ability_log_probs[:, self._question_span_extraction_index]

        if "addition_subtraction" in self.answering_abilities:
            # Shape: (batch_size, # of numbers in the passage)
            number_indices = number_indices.squeeze(-1)
            number_mask = (number_indices != -1).long()
            clamped_number_indices = util.replace_masked_values(number_indices, number_mask, 0)
            encoded_passage_for_numbers = torch.cat([modeled_passage_list[0], modeled_passage_list[3]], dim=-1)
            # Shape: (batch_size, # of numbers in the passage, encoding_dim)
            encoded_numbers = torch.gather(
                    encoded_passage_for_numbers,
                    1,
                    clamped_number_indices.unsqueeze(-1).expand(-1, -1, encoded_passage_for_numbers.size(-1)))
            # Shape: (batch_size, # of numbers in the passage)
            encoded_numbers = torch.cat(
                    [encoded_numbers, passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)

            # Shape: (batch_size, # of numbers in the passage, 3)
            number_sign_logits = self._number_sign_predictor(encoded_numbers)
            number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)

            # Shape: (batch_size, # of numbers in passage).
            best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
            # For padding numbers, the best sign is masked to 0 (i.e. not included).
            best_signs_for_numbers = util.replace_masked_values(best_signs_for_numbers, number_mask, 0)
            # Shape: (batch_size, # of numbers in passage)
            best_signs_log_probs = torch.gather(
                    number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
            # The probs of the masked positions should be 1 so that they do not affect the joint probability.
            # TODO: this is not quite right, since if there are many numbers in the passage,
            # TODO: the joint probability would be very small.
            best_signs_log_probs = util.replace_masked_values(best_signs_log_probs, number_mask, 0)
            # Shape: (batch_size,)
            best_combination_log_prob = best_signs_log_probs.sum(-1)
            if len(self.answering_abilities) > 1:
                best_combination_log_prob += answer_ability_log_probs[:, self._addition_subtraction_index]

        output_dict = {}

        # If answer is given, compute the loss.
        if answer_as_passage_spans is not None or answer_as_question_spans is not None \
                or answer_as_add_sub_expressions is not None or answer_as_counts is not None:

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                    gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_passage_span_mask = (gold_passage_span_starts != -1).long()
                    clamped_gold_passage_span_starts = \
                        util.replace_masked_values(gold_passage_span_starts, gold_passage_span_mask, 0)
                    clamped_gold_passage_span_ends = \
                        util.replace_masked_values(gold_passage_span_ends, gold_passage_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_span_starts = \
                        torch.gather(passage_span_start_log_probs, 1, clamped_gold_passage_span_starts)
                    log_likelihood_for_passage_span_ends = \
                        torch.gather(passage_span_end_log_probs, 1, clamped_gold_passage_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_spans = \
                        log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
                    # For those padded spans, we set their log probabilities to a very small negative value
                    log_likelihood_for_passage_spans = \
                        util.replace_masked_values(log_likelihood_for_passage_spans, gold_passage_span_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_passage_span = util.logsumexp(log_likelihood_for_passage_spans)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_passage_span)

                elif answering_ability == "question_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_question_span_starts = answer_as_question_spans[:, :, 0]
                    gold_question_span_ends = answer_as_question_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_question_span_mask = (gold_question_span_starts != -1).long()
                    clamped_gold_question_span_starts = \
                        util.replace_masked_values(gold_question_span_starts, gold_question_span_mask, 0)
                    clamped_gold_question_span_ends = \
                        util.replace_masked_values(gold_question_span_ends, gold_question_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_span_starts = \
                        torch.gather(question_span_start_log_probs, 1, clamped_gold_question_span_starts)
                    log_likelihood_for_question_span_ends = \
                        torch.gather(question_span_end_log_probs, 1, clamped_gold_question_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_spans = \
                        log_likelihood_for_question_span_starts + log_likelihood_for_question_span_ends
                    # For those padded spans, we set their log probabilities to a very small negative value
                    log_likelihood_for_question_spans = \
                        util.replace_masked_values(log_likelihood_for_question_spans,
                                                   gold_question_span_mask,
                                                   -1e7)
                    # Shape: (batch_size, )
                    # pylint: disable=invalid-name
                    log_marginal_likelihood_for_question_span = \
                        util.logsumexp(log_likelihood_for_question_spans)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_question_span)

                elif answering_ability == "addition_subtraction":
                    # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                    # Shape: (batch_size, # of combinations)
                    gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1) > 0).float()
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    gold_add_sub_signs = answer_as_add_sub_expressions.transpose(1, 2)
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    log_likelihood_for_number_signs = torch.gather(number_sign_log_probs, 2, gold_add_sub_signs)
                    # The log likelihood of the masked positions should be 0
                    # so that they do not affect the joint probability.
                    log_likelihood_for_number_signs = \
                        util.replace_masked_values(log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0)
                    # Shape: (batch_size, # of combinations)
                    log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(1)
                    # For those padded combinations, we set their log probabilities to a very small negative value
                    log_likelihood_for_add_subs = \
                        util.replace_masked_values(log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_add_sub = util.logsumexp(log_likelihood_for_add_subs)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_add_sub)

                elif answering_ability == "counting":
                    # Count answers are padded with label -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    # Shape: (batch_size, # of count answers)
                    gold_count_mask = (answer_as_counts != -1).long()
                    # Shape: (batch_size, # of count answers)
                    clamped_gold_counts = util.replace_masked_values(answer_as_counts, gold_count_mask, 0)
                    log_likelihood_for_counts = torch.gather(count_number_log_probs, 1, clamped_gold_counts)
                    # For those padded counts, we set their log probabilities to a very small negative value
                    log_likelihood_for_counts = \
                        util.replace_masked_values(log_likelihood_for_counts, gold_count_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_count = util.logsumexp(log_likelihood_for_counts)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_count)

                else:
                    raise ValueError(f"Unsupported answering ability: {answering_ability}")

            if len(self.answering_abilities) > 1:
                # Add the ability log probabilities if there is more than one ability
                all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
                marginal_log_likelihood = util.logsumexp(all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]

            output_dict["loss"] = - marginal_log_likelihood.mean()

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])

                if len(self.answering_abilities) > 1:
                    predicted_ability_str = self.answering_abilities[best_answer_ability[i].detach().cpu().numpy()]
                else:
                    predicted_ability_str = self.answering_abilities[0]

                # We did not consider multi-mention answers here
                if predicted_ability_str == "passage_span_extraction":
                    passage_str = metadata[i]['original_passage']
                    offsets = metadata[i]['passage_token_offsets']
                    predicted_span = tuple(best_passage_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    predicted_answer = passage_str[start_offset:end_offset]
                elif predicted_ability_str == "question_span_extraction":
                    question_str = metadata[i]['original_question']
                    offsets = metadata[i]['question_token_offsets']
                    predicted_span = tuple(best_question_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    predicted_answer = question_str[start_offset:end_offset]
                elif predicted_ability_str == "addition_subtraction":  # plus_minus combination answer
                    original_numbers = metadata[i]['original_numbers']
                    sign_remap = {0: 0, 1: 1, 2: -1}
                    predicted_signs = [sign_remap[it] for it in best_signs_for_numbers[i].detach().cpu().numpy()]
                    result = sum([sign * number for sign, number in zip(predicted_signs, original_numbers)])
                    predicted_answer = str(result)
                elif predicted_ability_str == "counting":
                    predicted_count = best_count_number[i].detach().cpu().numpy()
                    predicted_answer = str(predicted_count)
                else:
                    raise ValueError(f"Unsupported answer ability: {predicted_ability_str}")

                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(predicted_answer)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(predicted_answer, answer_annotations)
            # This is used for the demo.
            output_dict["passage_question_attention"] = passage_question_attention
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
            # The demo takes `best_span_str` as a key to find the predicted answer
            output_dict["best_span_str"] = output_dict["answer"]
        return output_dict
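
The loss above marginalizes over the latent answering ability: each per-ability log marginal likelihood is shifted by the corresponding ability log probability and the results are combined with logsumexp, i.e. log p(answer) = log sum_a p(a) * p(answer | a). A minimal standalone sketch with toy numbers (plain PyTorch, not part of the model):

import torch

# Toy sketch of the loss computed above: per-ability log marginal likelihoods are
# shifted by the ability log-probs and combined with logsumexp.
answer_ability_logits = torch.tensor([[1.0, 0.2, -0.5]])
answer_ability_log_probs = torch.log_softmax(answer_ability_logits, dim=-1)
# One log marginal likelihood per answering ability (batch_size = 1 here).
log_marginal_likelihood_list = [torch.tensor([-2.3]), torch.tensor([-0.7]), torch.tensor([-4.1])]
all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
marginal_log_likelihood = torch.logsumexp(
    all_log_marginal_likelihoods + answer_ability_log_probs, dim=-1)
loss = -marginal_log_likelihood.mean()
print(loss)  # a positive scalar; smaller when a likely ability explains the answer well
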
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()

        batch_size, passage_len = passage_mask.shape

        span_start_logits = torch.FloatTensor(batch_size, passage_len)
        span_start_logits.zero_()
        span_start_logits.scatter_(1, span_start, 1)

        span_end_logits = torch.FloatTensor(batch_size, passage_len)
        span_end_logits.zero_()
        span_end_logits.scatter_(1, span_end, 1)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            # question_tokens = []
            # passage_tokens = []

            all_reference_answers_text = []
            all_best_spans = []
            for i in range(batch_size):
                # question_tokens.append(metadata[i]['question_tokens'])
                # passage_tokens.append(metadata[i]['passage_tokens'])

                predicted_span = tuple(best_span[i].detach().cpu().numpy())

                start_span = predicted_span[0]
                end_span = predicted_span[1]
                best_span_tokens = metadata[i]['passage_tokens'][
                    start_span:end_span + 1]
                best_span_string = " ".join(best_span_tokens)
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
                    all_best_spans.append(best_span_string)
                    all_reference_answers_text.append(answer_texts)

            if not self.training:
                self.calculate_rouge(all_best_spans,
                                     all_reference_answers_text)

            # output_dict['question_tokens'] = question_tokens
            # output_dict['passage_tokens'] = passage_tokens
        return output_dict
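
The forward above does not predict spans from text; it builds one-hot "oracle" span logits directly from the gold indices via scatter_, then masks and normalizes them. A standalone sketch of that construction with toy sizes and hypothetical gold indices:

import torch

# Toy sketch of the oracle-style span logits built above: scatter_ writes a 1 at
# each gold start index, giving a one-hot row, so an argmax (or get_best_span)
# over these logits recovers the gold position exactly.
batch_size, passage_len = 2, 6
span_start = torch.tensor([[1], [4]])  # hypothetical gold starts, shape (batch_size, 1)
span_start_logits = torch.zeros(batch_size, passage_len)
span_start_logits.scatter_(1, span_start, 1.0)
print(span_start_logits)
# tensor([[0., 1., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 1., 0.]])
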
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            passage_sem_views_q: torch.IntTensor = None,
            passage_sem_views_k: torch.IntTensor = None,
            question_sem_views_q: torch.IntTensor = None,
            question_sem_views_k: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        passage_sem_views_q : ``torch.IntTensor``, optional
            Paragraph semantic views features for multihead attention Query (Q)
        passage_sem_views_k : ``torch.IntTensor``, optional
            Paragraph semantic views features for multihead attention Key (K)
        question_sem_views_q : ``torch.IntTensor``, optional
            Question semantic views features for multihead attention Query (Q)
        question_sem_views_k : ``torch.IntTensor``, optional
            Question semantic views features for multihead attention Key (K)

        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """

        return_output_metadata = self.return_output_metadata

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()

        batch_size, passage_len = passage_mask.shape

        # convert to long
        if passage_sem_views_q is not None:
            passage_sem_views_q = passage_sem_views_q.long()

        if passage_sem_views_k is not None:
            passage_sem_views_k = passage_sem_views_k.long()

        if question_sem_views_q is not None:
            question_sem_views_q = question_sem_views_q.long()

        if question_sem_views_k is not None:
            question_sem_views_k = question_sem_views_k.long()

        span_start_logits = torch.FloatTensor(batch_size, passage_len)
        span_start_logits.zero_()
        span_start_logits.scatter_(1, span_start, 1)

        span_end_logits = torch.FloatTensor(batch_size, passage_len)
        span_end_logits.zero_()
        span_end_logits.scatter_(1, span_end, 1)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

            # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
            if metadata is not None:
                output_dict['best_span_str'] = []
                question_tokens = []
                passage_tokens = []
                metrics_per_item = None

                all_reference_answers_text = []
                all_best_spans = []

                return_metrics_per_item = True

                if not self.training:
                    metrics_per_item = [{} for x in range(batch_size)]

                for i in range(batch_size):
                    question_tokens.append(metadata[i]['question_tokens'])
                    passage_tokens.append(metadata[i]['passage_tokens'])
                    passage_str = metadata[i]['original_passage']
                    predicted_span = tuple(best_span[i].detach().cpu().numpy())

                    start_span = predicted_span[0]
                    end_span = predicted_span[1]
                    best_span_tokens = metadata[i]['passage_tokens'][
                        start_span:end_span + 1]
                    best_span_string = " ".join(best_span_tokens)
                    output_dict['best_span_str'].append(best_span_string)
                    output_dict['best_span_tokens'] = best_span_tokens
                    answer_texts = metadata[i].get('answer_texts', [])

                    if return_output_metadata:
                        best_span_semantic_features = []
                        curr_item_features = passage_sem_views_q[i]
                        for view_id in range(curr_item_features.shape[0]):
                            curr_view_feats = curr_item_features[view_id][
                                start_span:end_span + 1]
                            best_span_semantic_features.append(
                                curr_view_feats.tolist())

                        output_dict[
                            'best_span_semantic_features'] = best_span_semantic_features

                    all_best_spans.append(best_span_string)

                    if answer_texts:
                        curr_item_em, curr_item_f1 = self._squad_metrics(
                            best_span_string, answer_texts, return_score=True)
                        if not self.training and return_metrics_per_item:
                            metrics_per_item[i]["em"] = curr_item_em
                            metrics_per_item[i]["f1"] = curr_item_f1

                        all_reference_answers_text.append(answer_texts)

                # output metadata
                if return_output_metadata:
                    output_dict["output_metadata"] = {
                        "modeling_layer": {
                            "modeling_layer_iter_000": {
                                "encoder_block_001": {
                                    "semantic_views_q":
                                    passage_sem_views_q,
                                    "semantic_views_sent_mask":
                                    passage_sem_views_k,
                                },
                            }
                        }
                    }

                if not self.training and len(all_reference_answers_text) > 0:
                    metrics_per_item_rouge = self.calculate_rouge(
                        all_best_spans,
                        all_reference_answers_text,
                        return_metrics_per_item=return_metrics_per_item)

                    for i, curr_metrics in enumerate(metrics_per_item_rouge):
                        metrics_per_item[i].update(curr_metrics)

                if metrics_per_item is not None:
                    output_dict['metrics'] = metrics_per_item

                output_dict['question_tokens'] = question_tokens
                output_dict['passage_tokens'] = passage_tokens
        return output_dict
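
get_best_span is imported from elsewhere in the repository and used throughout these examples. As a point of reference, here is a minimal sketch of the constrained decoding it is assumed to perform: pick, per batch item, the inclusive (start, end) pair with start <= end that maximizes span_start_logits[start] + span_end_logits[end]. The function name and implementation below are illustrative, not the library's code.

import torch

def get_best_span_reference(span_start_logits: torch.Tensor,
                            span_end_logits: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour: argmax over all valid (start, end) pairs with end >= start.
    batch_size, passage_length = span_start_logits.size()
    # Shape: (batch_size, passage_length, passage_length); [b, i, j] = start[i] + end[j]
    span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Mask out invalid spans where end < start.
    valid_mask = torch.triu(torch.ones(passage_length, passage_length)).bool()
    span_scores = span_scores.masked_fill(~valid_mask, float("-inf"))
    best_flat = span_scores.view(batch_size, -1).argmax(dim=-1)
    best_starts = torch.div(best_flat, passage_length, rounding_mode="floor")
    best_ends = best_flat % passage_length
    # Shape: (batch_size, 2), inclusive token indices
    return torch.stack([best_starts, best_ends], dim=-1)

# Example: with these toy probabilities the best constrained span is (1, 1).
starts = torch.tensor([[0.4, 0.5, 0.1]]).log()
ends = torch.tensor([[0.3, 0.6, 0.1]]).log()
print(get_best_span_reference(starts, ends))  # tensor([[1, 1]])
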
Exemple #13
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ, unused-argument
        passage_mask = util.get_text_field_mask(passage).float()
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        batch_size = embedded_passage.size(0)
        embedded_passage = self._highway_layer(
            self._embedding_proj_layer(embedded_passage))

        encoded_passage_list = [embedded_passage]
        for _ in range(3):
            encoded_passage = self._dropout(
                self._encoding_layer(encoded_passage_list[-1], passage_mask))
            encoded_passage_list.append(encoded_passage)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_start_input = torch.cat(
            [encoded_passage_list[-3], encoded_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat(
            [encoded_passage_list[-3], encoded_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e32)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {}

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span,
                                torch.stack([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(best_span_string)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(best_span_string, answer_annotations)
        return output_dict
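
The metadata loop above turns an inclusive token span into an answer string by looking up character offsets: the start character of the first token and the end character of the last token bound the answer in the original passage. A toy sketch with hypothetical tokens and offsets:

# Toy sketch (hypothetical passage, tokens and offsets) of the offset lookup above.
passage_str = "The team scored 31 points in total."
passage_tokens = ["The", "team", "scored", "31", "points", "in", "total", "."]
token_offsets = [(0, 3), (4, 8), (9, 15), (16, 18), (19, 25), (26, 28), (29, 34), (34, 35)]
predicted_span = (3, 4)  # inclusive token indices for "31 points"
start_offset = token_offsets[predicted_span[0]][0]
end_offset = token_offsets[predicted_span[1]][1]
print(passage_str[start_offset:end_offset])  # character-based answer: "31 points"
print(" ".join(passage_tokens[predicted_span[0]:predicted_span[1] + 1]))  # token-based: "31 points"
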
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.IntTensor = None,
            span_end: torch.IntTensor = None,
            passage_sem_views_q: torch.IntTensor = None,
            passage_sem_views_k: torch.IntTensor = None,
            question_sem_views_q: torch.IntTensor = None,
            question_sem_views_k: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        passage_sem_views_q : ``torch.IntTensor``, optional
            Paragraph semantic views features for multihead attention Query (Q)
        passage_sem_views_k : ``torch.IntTensor``, optional
            Paragraph semantic views features for multihead attention Key (K)
        question_sem_views_q : ``torch.IntTensor``, optional
            Question semantic views features for multihead attention Query (Q)
        question_sem_views_k : ``torch.IntTensor``, optional
            Question semantic views features for multihead attention Key (K)

        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """

        return_output_metadata = self.return_output_metadata
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()

        if isinstance(self._phrase_layer, QaNetSemanticFlatEncoder) \
            or isinstance(self._phrase_layer, QaNetSemanticFlatConcatEncoder)\
            or isinstance(self._modeling_layer, QaNetSemanticFlatEncoder) \
            or isinstance(self._modeling_layer, QaNetSemanticFlatConcatEncoder):
            if passage_sem_views_q is not None:
                passage_sem_views_q = passage_sem_views_q.long()

            if passage_sem_views_k is not None:
                passage_sem_views_k = passage_sem_views_k.long()

            if question_sem_views_q is not None:
                question_sem_views_q = question_sem_views_q.long()

            if question_sem_views_k is not None:
                question_sem_views_k = question_sem_views_k.long()

        if torch.cuda.is_available():
            # indices
            question_mask = to_cuda(question_mask, move_to_cuda=True)
            passage_mask = to_cuda(passage_mask, move_to_cuda=True)

            question = {
                k: to_cuda(v, move_to_cuda=True)
                for k, v in question.items()
            }
            passage = {
                k: to_cuda(v, move_to_cuda=True)
                for k, v in passage.items()
            }

            # span
            if span_start is not None:
                span_start = to_cuda(span_start, move_to_cuda=True)

            if span_end is not None:
                span_end = to_cuda(span_end, move_to_cuda=True)

            # semantic views
            if passage_sem_views_q is not None:
                passage_sem_views_q = to_cuda(passage_sem_views_q,
                                              move_to_cuda=True)

            if passage_sem_views_k is not None:
                passage_sem_views_k = to_cuda(passage_sem_views_k,
                                              move_to_cuda=True)

            if question_sem_views_q is not None:
                question_sem_views_q = to_cuda(question_sem_views_q,
                                               move_to_cuda=True)

            if question_sem_views_k is not None:
                question_sem_views_k = to_cuda(question_sem_views_k,
                                               move_to_cuda=True)

        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(
            self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(
            self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(
            embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(
            embedded_passage)

        encoded_passage_output_metadata = None
        encoded_question_output_metadata = None
        if isinstance(self._phrase_layer, QaNetSemanticFlatEncoder) \
                or isinstance(self._phrase_layer, QaNetSemanticFlatConcatEncoder):
            if is_output_meta_supported(self._phrase_layer):
                encoded_passage, encoded_passage_output_metadata = self._phrase_layer(
                    projected_embedded_passage, passage_sem_views_q,
                    passage_sem_views_k, passage_mask, return_output_metadata)
                encoded_passage = self._dropout(encoded_passage)

                encoded_question, encoded_question_output_metadata = self._phrase_layer(
                    projected_embedded_question, question_sem_views_q,
                    question_sem_views_k, question_mask,
                    return_output_metadata)
                encoded_question = self._dropout(encoded_question)
            else:
                encoded_passage = self._dropout(
                    self._phrase_layer(projected_embedded_passage,
                                       passage_sem_views_q,
                                       passage_sem_views_k, passage_mask))
                encoded_question = self._dropout(
                    self._phrase_layer(projected_embedded_question,
                                       question_sem_views_q,
                                       question_sem_views_k, question_mask))
        else:
            encoded_passage = self._dropout(
                self._phrase_layer(projected_embedded_passage, passage_mask))
            encoded_question = self._dropout(
                self._phrase_layer(projected_embedded_question, question_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True)

        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)

        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention,
                                             question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage,
                                                    attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat([
                encoded_passage, passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * passage_passage_vectors
            ],
                      dim=-1))

        modeled_passage_list = [
            self._modeling_proj_layer(merged_passage_attention_vectors)
        ]
        modeled_passage_output_metadata_list = {}

        for modeling_layer_id in range(3):
            modeled_passage_output_metadata = None
            if isinstance(self._modeling_layer, QaNetSemanticFlatEncoder) \
                    or isinstance(self._modeling_layer, QaNetSemanticFlatConcatEncoder):
                if is_output_meta_supported(self._modeling_layer):
                    modeled_passage, modeled_passage_output_metadata = self._modeling_layer(
                        modeled_passage_list[-1], passage_sem_views_q,
                        passage_sem_views_k, passage_mask,
                        return_output_metadata)
                else:
                    modeled_passage = self._modeling_layer(
                        modeled_passage_list[-1], passage_sem_views_q,
                        passage_sem_views_k, passage_mask)

            else:
                modeled_passage = self._modeling_layer(
                    modeled_passage_list[-1], passage_mask)

            modeled_passage = self._dropout(modeled_passage)
            modeled_passage_list.append(modeled_passage)
            modeled_passage_output_metadata_list[
                "modeling_layer_iter_{0:03d}".format(
                    modeling_layer_id)] = modeled_passage_output_metadata

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_start_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            #"passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            try:
                loss = nll_loss(
                    util.masked_log_softmax(span_start_logits, passage_mask),
                    span_start.squeeze(-1))
                self._span_start_accuracy(span_start_logits,
                                          span_start.squeeze(-1))
                loss += nll_loss(
                    util.masked_log_softmax(span_end_logits, passage_mask),
                    span_end.squeeze(-1))
                self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
                self._span_accuracy(best_span,
                                    torch.stack([span_start, span_end], -1))
                output_dict["loss"] = loss
            except Exception as e:
                logging.exception(e)

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            metrics_per_item = None

            all_reference_answers_text = []
            all_best_spans = []

            return_metrics_per_item = True

            if not self.training:
                metrics_per_item = [{} for x in range(batch_size)]

            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())

                # offsets = metadata[i]['token_offsets']
                # start_offset = offsets[predicted_span[0]][0]
                # end_offset = offsets[predicted_span[1]][1]

                start_span = predicted_span[0]
                end_span = predicted_span[1]
                best_span_tokens = metadata[i]['passage_tokens'][
                    start_span:end_span + 1]
                best_span_string = " ".join(best_span_tokens)
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])

                all_best_spans.append(best_span_string)

                if answer_texts:
                    curr_item_em, curr_item_f1 = self._squad_metrics(
                        best_span_string, answer_texts, return_score=True)
                    if not self.training and return_metrics_per_item:
                        metrics_per_item[i]["em"] = curr_item_em
                        metrics_per_item[i]["f1"] = curr_item_f1

                    all_reference_answers_text.append(answer_texts)

            if return_output_metadata:
                output_dict["output_metadata"] = {
                    "encoded_passage": encoded_passage_output_metadata,
                    "encoded_question": encoded_question_output_metadata,
                    "modeling_layer": modeled_passage_output_metadata_list,
                }

            if not self.training and len(all_reference_answers_text) > 0:
                metrics_per_item_rouge = self.calculate_rouge(
                    all_best_spans,
                    all_reference_answers_text,
                    return_metrics_per_item=return_metrics_per_item)

                for i, curr_metrics in enumerate(metrics_per_item_rouge):
                    metrics_per_item[i].update(curr_metrics)

            if metrics_per_item is not None:
                output_dict['metrics'] = metrics_per_item

            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens

        return output_dict
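
The co-attention block above composes the passage-to-question and question-to-passage attention matrices into a passage-to-passage ("attention over attention") matrix, which is then used to gather self-aligned passage vectors. A standalone sketch with random toy tensors (masks omitted for brevity):

import torch

# Toy sketch of the attention-over-attention step above.
batch, p_len, q_len, dim = 1, 4, 3, 8
encoded_passage = torch.randn(batch, p_len, dim)
encoded_question = torch.randn(batch, q_len, dim)
# Shape: (batch, p_len, q_len)
similarity = torch.bmm(encoded_passage, encoded_question.transpose(1, 2))
p2q_attention = torch.softmax(similarity, dim=-1)                      # (batch, p_len, q_len)
q2p_attention = torch.softmax(similarity.transpose(1, 2), dim=-1)      # (batch, q_len, p_len)
attention_over_attention = torch.bmm(p2q_attention, q2p_attention)     # (batch, p_len, p_len)
passage_passage_vectors = torch.bmm(attention_over_attention, encoded_passage)
print(passage_passage_vectors.shape)  # torch.Size([1, 4, 8])
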
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            question_and_passage: Dict[str, torch.LongTensor],
            answer_as_passage_spans: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ, unused-argument

        # logger.info("="*10)
        # logger.info([len(metadata[i]["passage_tokens"]) for i in range(len(metadata))])
        # logger.info([len(metadata[i]["question_tokens"]) for i in range(len(metadata))])
        # logger.info(question_and_passage["bert"].shape)

        # The segment labels should be as following:
        # <CLS> + question_word_pieces + <SEP> + passage_word_pieces + <SEP>
        # 0                0               0              1              1
        # We get this in a tricky way here
        expanded_question_bert_tensor = torch.zeros_like(
            question_and_passage["bert"])
        expanded_question_bert_tensor[:, :question["bert"].shape[1]] = \
            question["bert"]
        segment_labels = (question_and_passage["bert"] -
                          expanded_question_bert_tensor > 0).long()
        question_and_passage["segment_labels"] = segment_labels
        embedded_question_and_passage = self._text_field_embedder(
            question_and_passage)

        # We also get the passage mask for the concatenated question and passage in a similar way
        expanded_question_mask = torch.zeros_like(question_and_passage["mask"])
        # We shift the 1s one column to the right here, to mask the [SEP] token in the middle
        expanded_question_mask[:, 1:question["mask"].shape[1] + 1] = \
            question["mask"]
        expanded_question_mask[:, 0] = 1
        passage_mask = question_and_passage["mask"] - expanded_question_mask

        batch_size = embedded_question_and_passage.size(0)

        span_start_logits = self._span_start_predictor(
            embedded_question_and_passage).squeeze(-1)
        span_end_logits = self._span_end_predictor(
            embedded_question_and_passage).squeeze(-1)

        # Shape: (batch_size, passage_length)
        passage_span_start_log_probs = util.masked_log_softmax(
            span_start_logits, passage_mask)
        passage_span_end_log_probs = util.masked_log_softmax(
            span_end_logits, passage_mask)

        passage_span_start_logits = util.replace_masked_values(
            span_start_logits, passage_mask, -1e32)
        passage_span_end_logits = util.replace_masked_values(
            span_end_logits, passage_mask, -1e32)
        best_passage_span = get_best_span(passage_span_start_logits,
                                          passage_span_end_logits)

        output_dict = {
            "passage_span_start_probs": passage_span_start_log_probs.exp(),
            "passage_span_end_probs": passage_span_end_log_probs.exp()
        }

        # If answer is given, compute the loss for training.
        if answer_as_passage_spans is not None:
            # Shape: (batch_size, # of answer spans)
            gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
            gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
            # Some spans are padded with index -1,
            # so we clamp those paddings to 0 and then mask after `torch.gather()`.
            gold_passage_span_mask = (gold_passage_span_starts != -1).long()
            clamped_gold_passage_span_starts = util.replace_masked_values(
                gold_passage_span_starts, gold_passage_span_mask, 0)
            clamped_gold_passage_span_ends = util.replace_masked_values(
                gold_passage_span_ends, gold_passage_span_mask, 0)
            # Shape: (batch_size, # of answer spans)
            log_likelihood_for_passage_span_starts = \
                torch.gather(passage_span_start_log_probs, 1, clamped_gold_passage_span_starts)
            log_likelihood_for_passage_span_ends = \
                torch.gather(passage_span_end_log_probs, 1, clamped_gold_passage_span_ends)
            # Shape: (batch_size, # of answer spans)
            log_likelihood_for_passage_spans = \
                log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
            # For the padded spans, we set their log probabilities to a very small
            # negative value so they do not contribute to the logsumexp below.
            log_likelihood_for_passage_spans = \
                util.replace_masked_values(log_likelihood_for_passage_spans, gold_passage_span_mask, -1e32)
            # Shape: (batch_size, )
            log_marginal_likelihood_for_passage_span = util.logsumexp(
                log_likelihood_for_passage_spans)
            output_dict[
                "loss"] = -log_marginal_likelihood_for_passage_span.mean()

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                # We do not handle multi-mention answers here
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['passage_token_offsets']
                predicted_span = tuple(
                    best_passage_span[i].detach().cpu().numpy())
                # Remove the offsets of question tokens and the [SEP] token
                predicted_span = (predicted_span[0] -
                                  len(metadata[i]['question_tokens']) - 1,
                                  predicted_span[1] -
                                  len(metadata[i]['question_tokens']) - 1)
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_answer_str = passage_str[start_offset:end_offset]
                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(best_answer_str)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(best_answer_str, answer_annotations)
        return output_dict
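
The loss in the example above marginalises over all gold answer spans with
``util.logsumexp``.  The following is a minimal, self-contained sketch of that
marginalisation in plain PyTorch; the log-probabilities and gold spans are
made-up toy values, not outputs of the model.

import torch

# One instance, passage of length 3, two annotated gold spans; the second
# gold span is padding (-1, -1), mirroring the padding convention above.
passage_span_start_log_probs = torch.log(torch.tensor([[0.7, 0.2, 0.1]]))
passage_span_end_log_probs = torch.log(torch.tensor([[0.1, 0.6, 0.3]]))
gold_spans = torch.tensor([[[0, 1], [-1, -1]]])            # (batch, num_spans, 2)

gold_starts, gold_ends = gold_spans[:, :, 0], gold_spans[:, :, 1]
span_mask = gold_starts != -1
clamped_starts = gold_starts.clamp(min=0)                   # padding -1 -> 0
clamped_ends = gold_ends.clamp(min=0)

# log P(span) = log P(start) + log P(end), gathered at the gold indices.
log_likelihoods = (passage_span_start_log_probs.gather(1, clamped_starts) +
                   passage_span_end_log_probs.gather(1, clamped_ends))
# Padded spans get a very negative value so they vanish under logsumexp.
log_likelihoods = log_likelihoods.masked_fill(~span_mask, -1e32)
# Marginal log-likelihood over the gold spans; the loss is its negated mean.
marginal = torch.logsumexp(log_likelihoods, dim=-1)
loss = -marginal.mean()
print(marginal.exp())   # ~0.42 = 0.7 * 0.6, from the single real gold span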
Exemple #16
0
    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat(
            [
                encoded_passage,
                passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * tiled_question_passage_vector,
            ],
            dim=-1,
        )

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat(
            [
                final_merged_passage,
                modeled_passage,
                tiled_start_representation,
                modeled_passage * tiled_start_representation,
            ],
            dim=-1,
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end],
                                                     -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            token_offsets = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                token_offsets.append(metadata[i]["token_offsets"])
                passage_str = metadata[i]["original_passage"]
                offsets = metadata[i]["token_offsets"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)
                answer_texts = metadata[i].get("answer_texts", [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
            output_dict["token_offsets"] = token_offsets
        return output_dict
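
Every example in this listing calls ``get_best_span`` to turn the start and end
scores into a single (start, end) prediction.  The function below is a minimal
sketch of that constrained inference (pick the pair with the highest combined
score subject to start <= end), written directly against PyTorch; it is an
illustration of the idea, not the AllenNLP implementation.

import torch

def get_best_span_sketch(span_start_logits: torch.Tensor,
                         span_end_logits: torch.Tensor) -> torch.Tensor:
    # Score every (start, end) pair as start_score + end_score.
    # Shape: (batch_size, passage_length, passage_length)
    batch_size, passage_length = span_start_logits.size()
    span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Only spans with start <= end are valid (span ends are inclusive).
    valid = torch.triu(torch.ones(passage_length, passage_length,
                                  device=span_scores.device)).bool()
    span_scores = span_scores.masked_fill(~valid, float("-inf"))
    # Argmax over the flattened (start, end) grid, then unravel the index.
    best_flat = span_scores.view(batch_size, -1).argmax(dim=-1)
    return torch.stack([best_flat // passage_length,
                        best_flat % passage_length], dim=-1)

# Toy usage: the best pair here is start=1, end=2.
print(get_best_span_sketch(torch.tensor([[0.2, 0.6, 0.2]]).log(),
                           torch.tensor([[0.1, 0.1, 0.8]]).log()))   # tensor([[1, 2]])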
Exemple #17
0
    def forward(self,  # type: ignore
                metadata: Dict,
                tokens: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None
                ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField`` (that has a bert-pretrained token indexer)
        span_start : torch.IntTensor, optional (default = None)
            A tensor of shape (batch_size, 1) which contains the start_position of the answer
            in the passage, or 0 if impossible. This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : torch.IntTensor, optional (default = None)
            A tensor of shape (batch_size, 1) which contains the end_position of the answer
            in the passage, or 0 if impossible. This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        Returns
        -------
        An output dictionary consisting of:
        start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, sequence_length)`` representing
            unnormalized scores for the span start position.
        end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, sequence_length)`` representing
            unnormalized scores for the span end position.
        start_probs : torch.FloatTensor
            The result of ``masked_softmax(start_logits, input_mask)``.
        end_probs : torch.FloatTensor
            The result of ``masked_softmax(end_logits, input_mask)``.
        best_span : torch.IntTensor
            The result of a constrained inference over the start and end scores to
            find the most probable span.  Shape is ``(batch_size, 2)`` and each
            offset is an inclusive token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        input_ids = tokens[self._index]
        token_type_ids = tokens[f"{self._index}-type-ids"]
        input_mask = (input_ids != 0).long()

        # 1. Run the BERT encoder and project its output to start/end logits.
        bert_output, _ = self.bert_model(input_ids, token_type_ids, attention_mask=input_mask)
        linear_output = self.linear(bert_output)
        linear_dropped = self.drop(linear_output)
        start_logits, end_logits = linear_dropped.split(1, dim=-1)
        start_logits, end_logits = start_logits.squeeze(-1), end_logits.squeeze(-1)

        # 2. Compute start/end probabilities and then get the best span using
        # allennlp.models.reading_comprehension.util.get_best_span().
        # get_best_span() expects log-probabilities (or masked logits), so we
        # pass the log of the masked softmax rather than the raw probabilities.
        masked_soft_start = masked_softmax(start_logits, mask=input_mask)
        masked_soft_end = masked_softmax(end_logits, mask=input_mask)
        best_span = get_best_span(masked_soft_start.log(), masked_soft_end.log())
        output_dict = {
            "start_logits": start_logits,
            "end_logits": end_logits,
            "start_probs": masked_soft_start,
            "end_probs": masked_soft_end,
            "best_span": best_span
        }

        # 3. Compute the loss and accuracies when gold spans are given:
        # span_start accuracy, span_end accuracy and full span accuracy.
        if span_start is not None and span_end is not None:
            self._span_start_accuracy(start_logits, span_start.squeeze(-1))
            self._span_end_accuracy(end_logits, span_end.squeeze(-1))
            self._span_accuracy(
                best_span,
                torch.stack([span_start.squeeze(-1), span_end.squeeze(-1)], dim=-1))
            # Clamp gold positions that fall outside the model inputs.
            ignored_index = start_logits.size(1)
            span_start.clamp_(0, ignored_index)
            span_end.clamp_(0, ignored_index)
            start_loss = self.loss(start_logits, span_start.squeeze(-1))
            end_loss = self.loss(end_logits, span_end.squeeze(-1))
            combined_loss = (start_loss + end_loss) / 2
            output_dict["loss"] = combined_loss

        # 4. Optionally compute the official SQuAD metrics (exact match, F1).
        # Instantiate the metric object in __init__ using allennlp.training.metrics.SquadEmAndF1()
        # and call it with the decoded string of the predicted span (implement and call decode()
        # below) and the gold answers found in metadata[i]['answer_texts']; a hedged sketch of
        # such a step follows this example.

        return output_dict
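
Step 4 above is left unimplemented in this example.  The helper below is a
hedged sketch of what the decode-and-score step could look like, modelled on
the metric loops in the other examples in this listing; the ``metadata`` keys
(``tokens``, ``answer_texts``) and the naive whitespace join over wordpieces
are assumptions, not part of the original snippet.

from typing import Any, Dict, List

import torch
from allennlp.training.metrics import SquadEmAndF1

def score_predicted_spans(best_span: torch.Tensor,
                          metadata: List[Dict[str, Any]],
                          squad_metrics: SquadEmAndF1) -> List[str]:
    # Decode each predicted (start, end) pair back to a token string and score
    # it against the gold answer strings with the SQuAD EM/F1 metric.
    best_span_strings = []
    for i, span in enumerate(best_span):
        tokens = metadata[i]["tokens"]                      # assumed metadata key
        start, end = span.detach().cpu().tolist()
        best_span_string = " ".join(tokens[start:end + 1])  # naive detokenization
        best_span_strings.append(best_span_string)
        answer_texts = metadata[i].get("answer_texts", [])  # assumed metadata key
        if answer_texts:
            squad_metrics(best_span_string, answer_texts)
    return best_span_strings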
Exemple #18
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_start: torch.IntTensor = None,
                span_end: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()

        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

        encoded_question = self._dropout(self._phrase_layer(projected_embedded_question, question_mask))
        encoded_passage = self._dropout(self._phrase_layer(projected_embedded_passage, passage_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
                passage_question_similarity,
                question_mask,
                memory_efficient=True)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
                passage_question_similarity.transpose(1, 2),
                passage_mask,
                memory_efficient=True)
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
                torch.cat([encoded_passage, passage_question_vectors,
                           encoded_passage * passage_question_vectors,
                           encoded_passage * passage_passage_vectors],
                          dim=-1)
                )

        modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]

        for _ in range(3):
            modeled_passage = self._dropout(self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
                "passage_question_attention": passage_question_attention,
                "span_start_logits": span_start_logits,
                "span_start_probs": span_start_probs,
                "span_end_logits": span_end_logits,
                "span_end_probs": span_end_probs,
                "best_span": best_span,
                }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits, span_start.squeeze(-1))
            loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end], -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
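
The ``attention_over_attention`` trick in the example above composes the two
attention maps with ``torch.bmm`` so that each passage position attends to other
passage positions through the question.  A small shape-only sketch with random
tensors (all sizes below are made up):

import torch

batch_size, passage_length, question_length, encoding_dim = 2, 7, 4, 8

encoded_passage = torch.randn(batch_size, passage_length, encoding_dim)
# Attention from each passage token over question tokens (rows sum to one) ...
passage_question_attention = torch.softmax(
    torch.randn(batch_size, passage_length, question_length), dim=-1)
# ... and from each question token over passage tokens.
question_passage_attention = torch.softmax(
    torch.randn(batch_size, question_length, passage_length), dim=-1)

# (p x q) @ (q x p) -> (p x p); the product of row-stochastic matrices is
# still row-stochastic, so each row is again a distribution over the passage.
attention_over_attention = torch.bmm(passage_question_attention,
                                     question_passage_attention)
print(attention_over_attention.shape)    # torch.Size([2, 7, 7])
print(attention_over_attention.sum(-1))  # all ones, up to floating point error

# Equivalent to util.weighted_sum(encoded_passage, attention_over_attention):
# Shape: (batch_size, passage_length, encoding_dim)
passage_passage_vectors = torch.bmm(attention_over_attention, encoded_passage)
print(passage_passage_vectors.shape)     # torch.Size([2, 7, 8])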
Exemple #19
0
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                segment_ids: torch.LongTensor = None,
                start_positions: torch.LongTensor = None,
                end_positions: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:

        self._debug -= 1
        input_ids = tokens['tokens']

        batch_size = input_ids.size(0)
        num_choices = input_ids.size(1)

        tokens_mask = (input_ids != self._padding_value).long()

        if self._debug > 0:
            print(f"batch_size = {batch_size}")
            print(f"num_choices = {num_choices}")
            print(f"tokens_mask = {tokens_mask}")
            print(f"input_ids.size() = {input_ids.size()}")
            print(f"input_ids = {input_ids}")
            print(f"segment_ids = {segment_ids}")
            print(f"start_positions = {start_positions}")
            print(f"end_positions = {end_positions}")

        # Segment ids are not used by RoBERTa

        transformer_outputs = self._transformer_model(
            input_ids=input_ids,
            # token_type_ids=segment_ids,
            attention_mask=tokens_mask)
        sequence_output = transformer_outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        span_start_logits = util.replace_masked_values(start_logits,
                                                       tokens_mask, -1e7)
        span_end_logits = util.replace_masked_values(end_logits, tokens_mask,
                                                     -1e7)
        best_span = get_best_span(span_start_logits, span_end_logits)
        span_start_probs = util.masked_softmax(span_start_logits, tokens_mask)
        span_end_probs = util.masked_softmax(span_end_logits, tokens_mask)
        output_dict = {
            "start_logits": start_logits,
            "end_logits": end_logits,
            "best_span": best_span
        }
        output_dict["start_probs"] = span_start_probs
        output_dict["end_probs"] = span_end_probs

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the gold positions may carry an extra dimension; squeeze it.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions fall outside our model inputs; we ignore these terms.
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            self._span_start_accuracy(span_start_logits, start_positions)
            self._span_end_accuracy(span_end_logits, end_positions)
            self._span_accuracy(
                best_span,
                torch.cat([
                    start_positions.unsqueeze(-1),
                    end_positions.unsqueeze(-1)
                ], -1))

            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
            # Should we mask out invalid positions here?
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            output_dict["loss"] = total_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict['best_span_str'] = []
            output_dict['exact_match'] = []
            output_dict['f1_score'] = []
            tokens_texts = []
            for i in range(batch_size):
                tokens_text = metadata[i]['tokens']
                tokens_texts.append(tokens_text)
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                predicted_start = predicted_span[0]
                predicted_end = predicted_span[1]
                predicted_tokens = tokens_text[predicted_start:(predicted_end +
                                                                1)]
                best_span_string = self.convert_tokens_to_string(
                    predicted_tokens)
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                exact_match = 0
                f1_score = 0
                if answer_texts:
                    exact_match, f1_score = self._squad_metrics(
                        best_span_string, answer_texts)
                output_dict['exact_match'].append(exact_match)
                output_dict['f1_score'].append(f1_score)
            output_dict['tokens_texts'] = tokens_texts

        if self._debug > 0:
            print(f"output_dict = {output_dict}")

        return output_dict
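
The clamping above works together with ``CrossEntropyLoss(ignore_index=ignored_index)``:
gold positions that fall outside the model input are clamped to exactly
``ignored_index``, which the loss then skips.  A small self-contained
demonstration with made-up sizes and positions:

import torch

seq_len = 5
start_logits = torch.randn(3, seq_len)
# Instance 2's gold start lies outside the (truncated) model input.
start_positions = torch.tensor([1, 4, 9])

ignored_index = start_logits.size(1)                          # = 5
start_positions = start_positions.clamp(0, ignored_index)     # -> [1, 4, 5]

loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index)
# Only instances 0 and 1 contribute; the clamped target 5 is ignored.
loss = loss_fct(start_logits, start_positions)
print(loss)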