Example #1
0
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                labels: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None,
                **kwargs) -> Dict[str, torch.Tensor]:
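        # Embed the token field and build a padding mask over the input sequence.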
        embedded_text_input = self.text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)

        if self.dropout is not None:
            embedded_text_input = self.dropout(embedded_text_input)

        encoded_text = self.encoder(embedded_text_input, mask)

        if self.dropout is not None:
            encoded_text = self.dropout(encoded_text)

        if self.feedforward is not None:
            encoded_text = self.feedforward(encoded_text)

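        # Project each encoded token to per-tag logits.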
        logits = self.tag_projection_layer(encoded_text)

        output = {'logits': logits, 'mask': mask}

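        # If gold labels are provided, compute the masked tagging loss and update metrics.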
        if labels is not None:
            flipped_mask = (mask == 0)
            masked_labels = labels.masked_fill(flipped_mask, -1)
            output['loss'] = self.loss(logits.transpose(1, 2), masked_labels)
            for name, metric in self.metrics.items():
                metric(logits, labels, mask.float())

        return output
Example #2
0
def replace_token(target: torch.LongTensor, old: int, new: int):
    """Replace old tokens with new.

    Arguments:
        target: the tensor in which tokens are replaced.
        old: the token to be replaced by new.
        new: the token used to replace old.

    """
    return target.masked_fill(target == old, new)
Example #3
0
    @staticmethod
    def _update_seq_length_for_generation(
        sequence_lengths: torch.LongTensor,
        unfinished_sequences: torch.LongTensor,
        cur_len: int,
        is_eos_in_next_token: torch.BoolTensor,
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        # mark sequences that were still unfinished and whose next token is EOS
        is_sent_unfinished = unfinished_sequences.mul(
            is_eos_in_next_token.long()).bool()

        # update sentence length
        sequence_lengths = sequence_lengths.masked_fill(
            is_sent_unfinished, cur_len)
        unfinished_sequences = unfinished_sequences.mul(
            (~is_eos_in_next_token).long())
        return sequence_lengths, unfinished_sequences
Example #4
0
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                token_lengths: torch.Tensor,
                target_tokens: torch.LongTensor = None,
                punct_labels: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None,
                **kwargs) -> Dict[str, torch.Tensor]:
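        # Build the padding mask and embed the input tokens.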
        mask = get_text_field_mask(tokens)
        embedded_text_input = self.text_field_embedder(tokens)
        if self.embedding_dropout is not None:
            embedded_text_input = self.embedding_dropout(embedded_text_input)

        encoded_text = self.encoder(embedded_text_input, mask.bool())
        if self.encoded_dropout is not None:
            encoded_text = self.encoded_dropout(encoded_text)
        if self.feedforward is not None:
            encoded_text = self.feedforward(encoded_text)

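        # Per-token punctuation logits; token lengths are concatenated to the encoding as an extra feature.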
        punct_logits = self.punct_projection(encoded_text)
        token_lengths = token_lengths.unsqueeze(-1)
        encoded_text = torch.cat((encoded_text, token_lengths), dim=-1)

        output = {
            'mask': mask,
            'punct_logits': punct_logits,
            'embeddings': encoded_text
        }

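        # Optional spell-checker loss against the target tokens.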
        if target_tokens is not None:
            output['loss'] = self.__compute_spellchecker_loss(
                encoded_text, target_tokens)

        if punct_labels is not None:
            flipped_mask = (mask == 0)
            masked_punct_labels = punct_labels.masked_fill(flipped_mask, -1)
            punct_loss = self.losses['punct'](punct_logits.transpose(1, 2),
                                              masked_punct_labels)

            if 'loss' in output:
                output['loss'] += punct_loss
            else:
                output['loss'] = punct_loss

            for name, metric in self.metrics.items():
                metric(punct_logits, punct_labels, mask.float())

        return output
    def forward(self, cw_idxs, cc_idxs, qw_idxs, qc_idxs, ids,
                answer_start_as_passage_spans: torch.LongTensor = None,
                answer_end_as_passage_spans: torch.LongTensor = None,
                answer_as_counts: torch.LongTensor = None,
                number_indices=None):

        batch_size = cw_idxs.size(0)

        # Forward pass is the same as QANet up until the last layer
        spans_start, spans_end = super().forward(cw_idxs, cc_idxs, qw_idxs, qc_idxs)

        # The modeling layer is used to calculate the vector representation of the passage
        passage_weights = masked_softmax(
            self.passage_weights_layer(self.passage_aware_rep).squeeze(-1),
            self.c_mask_c2q, log_softmax=False)
        passage_vector_rep = passage_weights.unsqueeze(1).bmm(self.passage_aware_rep).squeeze(1)
        # The modeling layer is used to calculate the vector representation of the question
        question_weights = masked_softmax(
            self.question_weights_layer(self.qb).squeeze(-1),
            self.q_mask_c2q, log_softmax=False)
        question_vector_rep = question_weights.unsqueeze(1).bmm(self.qb).squeeze(1)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = self.answer_ability_predictor(
                torch.cat([passage_vector_rep, question_vector_rep], -1)
            )
            answer_ability_log_probs = torch.nn.functional.log_softmax(answer_ability_logits, -1)
            # Shape: (batch_size,)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            # Shape: (batch_size, self.max_count)
            count_number_logits = self.count_number_predictor(passage_vector_rep)
            count_number_log_probs = torch.nn.functional.log_softmax(count_number_logits, -1) # softmax over possible numbers
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1) # most probable numeric value
            best_count_log_prob = torch.gather(
                count_number_log_probs, 1, best_count_number.unsqueeze(-1)
            ).squeeze(-1)
            
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[:, self.counting_index]

        # TODO: test or remove
        if "addition_subtraction" in self.answering_abilities:
            
            # M3 (see NAQANet paper)
            modeled_passage = self.modeled_passage_list[-1]
            for block in self.modeling_encoder_blocks:
                modeled_passage = self.dropout_layer(
                    block(modeled_passage, self.c_mask_enc)
                )
            
            self.modeled_passage_list.append(modeled_passage)
            encoded_passage_for_numbers = torch.cat(
                [self.modeled_passage_list[0], self.modeled_passage_list[3]], dim=-1
            )
            
            # create mask on indices. Padding value = -1
            number_mask = number_indices != -1
            clamped_number_indices = number_indices.masked_fill(~number_mask, 0).type(torch.int64).to(self.device)
            number_mask = number_mask.to(self.device)

            if number_mask.size(1) > 0:
                # Shape: (batch_size, max_len_context, 3*hidden_size)
                encoded_numbers = torch.cat(
                    [
                        encoded_passage_for_numbers,
                        passage_vector_rep.unsqueeze(1).repeat(1, encoded_passage_for_numbers.size(1), 1),
                    ],
                    -1,
                )

                # Shape: (batch_size, max # number in passages, 3*hidden_size)
                encoded_numbers = torch.gather(encoded_numbers,
                    1,
                    clamped_number_indices.unsqueeze(-1).expand(
                        -1, -1, encoded_numbers.size(-1)
                    ))

                number_sign_logits = self.number_sign_predictor(encoded_numbers)
                number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)

                # Shape: (batch_size, # of numbers in passage).
                best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
                # For padding numbers, the best sign is masked to 0 (not included).
                best_signs_for_numbers = best_signs_for_numbers.masked_fill(~number_mask, 0)

                # Shape: (batch_size, # of numbers in passage)
                best_signs_log_probs = torch.gather(
                    number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)
                ).squeeze(-1)

                # the probs of the masked positions should be 1 so that they do not affect the joint probability
                # TODO: this is not quite right, since if there are many numbers in the passage,
                # TODO: the joint probability would be very small.
                best_signs_log_probs = best_signs_log_probs.masked_fill(~number_mask, 0)
                # print(f"best_signs_log_probs 3: {best_signs_log_probs}")

                # Shape: (batch_size,)
                best_combination_log_prob = best_signs_log_probs.sum(-1)
                
                if len(self.answering_abilities) > 1:
                    best_combination_log_prob += answer_ability_log_probs[
                        :, self.addition_subtraction_index
                    ]

            else:
                print("No numbers in the batch")

        
        if "passage_span_extraction" in self.answering_abilities:
            # Shape: (batch_size, passage_length, modeling_dim * 2))
            passage_for_span_start = torch.cat(
                [self.modeled_passage_list[0], self.modeled_passage_list[1]], dim=-1
            )
            # Shape: (batch_size, passage_length)
            passage_span_start_logits = self.passage_span_start_predictor(
                passage_for_span_start
            ).squeeze(-1)
            # Shape: (batch_size, passage_length, modeling_dim * 2)
            passage_for_span_end = torch.cat(
                [self.modeled_passage_list[0], self.modeled_passage_list[2]], dim=-1
            )
            # Shape: (batch_size, passage_length)
            passage_span_end_logits = self.passage_span_end_predictor(
                passage_for_span_end
            ).squeeze(-1)
            # Shape: (batch_size, passage_length). Probabilities on a log scale, from -infinity to 0
            passage_span_start_log_probs = util.masked_log_softmax(
                passage_span_start_logits, self.c_mask_c2q
            )
            passage_span_end_log_probs = util.masked_log_softmax(
                passage_span_end_logits, self.c_mask_c2q
            )

            # Info about the best passage span prediction
            passage_span_start_logits = replace_masked_values_with_big_negative_number(
                passage_span_start_logits, self.c_mask_c2q
            )
            passage_span_end_logits = replace_masked_values_with_big_negative_number(
                passage_span_end_logits, self.c_mask_c2q
            )
            # Shape: (batch_size, 2)
            best_passage_span = get_best_span(passage_span_start_logits, passage_span_end_logits)
                
            # Shape: (batch_size, 2)
            best_passage_start_log_probs = torch.gather(
                passage_span_start_log_probs, 1, best_passage_span[:, 0].unsqueeze(-1)
            ).squeeze(-1)
            best_passage_end_log_probs = torch.gather(
                passage_span_end_log_probs, 1, best_passage_span[:, 1].unsqueeze(-1)
            ).squeeze(-1)
            # Shape: (batch_size,)
            best_passage_span_log_prob = best_passage_start_log_probs + best_passage_end_log_probs
            if len(self.answering_abilities) > 1:
                best_passage_span_log_prob += answer_ability_log_probs[
                    :, self.passage_span_extraction_index
                ]

        output_dict = dict()

        # If answer is given, compute the loss.
        if (
            answer_start_as_passage_spans is not None
            or answer_as_counts is not None
        ):

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_passage_span_starts = answer_start_as_passage_spans
                    gold_passage_span_ends = answer_end_as_passage_spans

                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_passage_span_mask = gold_passage_span_starts != -1 # start and end should share same mask
                    clamped_gold_passage_span_starts = gold_passage_span_starts.masked_fill(
                        ~gold_passage_span_mask, 0)
                    clamped_gold_passage_span_ends = gold_passage_span_ends.masked_fill(
                        ~gold_passage_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_span_starts = torch.gather(
                        passage_span_start_log_probs, 1, clamped_gold_passage_span_starts
                    )
                    log_likelihood_for_passage_span_ends = torch.gather(
                        passage_span_end_log_probs, 1, clamped_gold_passage_span_ends
                    )
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_spans = (
                        log_likelihood_for_passage_span_starts
                        + log_likelihood_for_passage_span_ends
                    )
                    # For padded spans, set their log probabilities to a very small negative value
                    log_likelihood_for_passage_spans = (
                        replace_masked_values_with_big_negative_number(
                            log_likelihood_for_passage_spans,
                            gold_passage_span_mask,
                        )
                    )
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_passage_span = util.logsumexp(
                        log_likelihood_for_passage_spans
                    )

                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_passage_span)
                
                elif answering_ability == "counting":
                    # Count answers are padded with label -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    # Shape: (batch_size, # of count answers)
                    gold_count_mask = answer_as_counts != -1
                    # Shape: (batch_size, # of count answers)
                    clamped_gold_counts = answer_as_counts.masked_fill(~gold_count_mask, 0)
                    log_likelihood_for_counts = torch.gather(
                        count_number_log_probs, 1, clamped_gold_counts
                    )
                    # For padded count labels, set their log probabilities to a very small negative value
                    log_likelihood_for_counts = replace_masked_values_with_big_negative_number(
                        log_likelihood_for_counts, gold_count_mask
                    )
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_count = util.logsumexp(log_likelihood_for_counts)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_count)

                else:
                    raise ValueError(f"Unsupported answering ability: {answering_ability}")
            
            if len(self.answering_abilities) > 1:
                # Add the ability probabilities if there are more than one abilities
                all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = (
                    all_log_marginal_likelihoods + answer_ability_log_probs
                )
                marginal_log_likelihood = util.logsumexp(all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]

            output_dict["loss"] = -marginal_log_likelihood.mean()

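        # At evaluation time, decode a prediction for each example according to its predicted answering ability.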
        if self.eval_data:
            output_dict["predictions"] = dict()
            for i in range(batch_size):

                id = ids[i].item()
                if len(self.answering_abilities) > 1:
                    predicted_ability_str = self.answering_abilities[
                        best_answer_ability[i].detach().cpu().numpy()
                    ]
                    # print(f"Predicted ability: {predicted_ability_str}")
                else:
                    predicted_ability_str = self.answering_abilities[0]

                if predicted_ability_str == "passage_span_extraction":
                    start = best_passage_span[i, 0]
                    end = best_passage_span[i, 1]
                    preds = convert_tokens(self.eval_data,
                                           id,
                                           start.item(),
                                           end.item())
                    output_dict["predictions"][str(id)] = preds

                elif predicted_ability_str == "counting":
                    predicted_count = str(best_count_number[i].detach().cpu().numpy())
                    output_dict["predictions"][str(id)] = predicted_count


        return output_dict