Example #1
    def test_logsumexp(self):
        # First a simple example where we add probabilities in log space.
        tensor = torch.FloatTensor([[.4, .1, .2]])
        log_tensor = tensor.log()
        log_summed = util.logsumexp(log_tensor, dim=-1, keepdim=False)
        assert_almost_equal(log_summed.exp().data.numpy(), [.7])
        log_summed = util.logsumexp(log_tensor, dim=-1, keepdim=True)
        assert_almost_equal(log_summed.exp().data.numpy(), [[.7]])

        # Then some more atypical examples, and making sure this will work with how we handle
        # log masks.
        tensor = torch.FloatTensor([[float('-inf'), 20.0]])
        assert_almost_equal(util.logsumexp(tensor).data.numpy(), [20.0])
        tensor = torch.FloatTensor([[-200.0, 20.0]])
        assert_almost_equal(util.logsumexp(tensor).data.numpy(), [20.0])
        tensor = torch.FloatTensor([[20.0, 20.0], [-200.0, 200.0]])
        assert_almost_equal(util.logsumexp(tensor, dim=0).data.numpy(), [20.0, 200.0])
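For context, `util.logsumexp` adds probabilities in log space without overflow by shifting by the per-row maximum before exponentiating. A minimal standalone sketch of that trick (my own illustration, not the AllenNLP implementation) reproduces the first assertion of the test:

import torch

def stable_logsumexp(tensor: torch.Tensor, dim: int = -1, keepdim: bool = False) -> torch.Tensor:
    # Shift by the per-dimension max so exp() cannot overflow, then add it back.
    max_score, _ = tensor.max(dim, keepdim=True)
    stable = (tensor - max_score).exp().sum(dim, keepdim=True).log() + max_score
    return stable if keepdim else stable.squeeze(dim)

# log(0.4 + 0.1 + 0.2) recovered from the individual log-probabilities.
log_tensor = torch.tensor([[0.4, 0.1, 0.2]]).log()
print(stable_logsumexp(log_tensor, dim=-1).exp())  # tensor([0.7000])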
    def _input_likelihood(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Computes the (batch_size,) denominator term for the log-likelihood, which is the
        sum of the likelihoods across all possible state sequences.
        """
        batch_size, sequence_length, num_tags = logits.size()

        # Transpose batch size and sequence dimensions
        mask = mask.float().transpose(0, 1).contiguous()
        logits = logits.transpose(0, 1).contiguous()

        # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the
        # transitions to the initial states and the logits for the first timestep.
        if self.include_start_end_transitions:
            alpha = self.start_transitions.view(1, num_tags) + logits[0]
        else:
            alpha = logits[0]

        # For each i we compute logits for the transitions from timestep i-1 to timestep i.
        # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are
        # (instance, current_tag, next_tag)
        for i in range(1, sequence_length):
            # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis.
            emit_scores = logits[i].view(batch_size, 1, num_tags)
            # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis.
            transition_scores = self.transitions.view(1, num_tags, num_tags)
            # Alpha is for the current_tag, so we broadcast along the next_tag axis.
            broadcast_alpha = alpha.view(batch_size, num_tags, 1)

            # Add all the scores together and logsumexp over the current_tag axis
            inner = broadcast_alpha + emit_scores + transition_scores

            # In valid positions (mask == 1) we want to take the logsumexp over the current_tag dimension
            # of ``inner``. Otherwise (mask == 0) we want to retain the previous alpha.
            alpha = (util.logsumexp(inner, 1) * mask[i].view(batch_size, 1) +
                     alpha * (1 - mask[i]).view(batch_size, 1))

        # Every sequence needs to end with a transition to the stop_tag.
        if self.include_start_end_transitions:
            stops = alpha + self.end_transitions.view(1, num_tags)
        else:
            stops = alpha

        # Finally we log_sum_exp along the num_tags dim, result is (batch_size,)
        return util.logsumexp(stops)
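The loop above is one step of the standard CRF forward algorithm. A small self-contained sketch with hypothetical random tensors (no mask, no start/end transitions) shows the broadcasting that builds `inner` and the logsumexp over the current-tag axis:

import torch

batch_size, num_tags = 2, 3
# Hypothetical tensors standing in for one step of the recursion above.
alpha = torch.randn(batch_size, num_tags)       # scores carried from the previous timestep
emit = torch.randn(batch_size, num_tags)        # logits[i]
transitions = torch.randn(num_tags, num_tags)   # (current_tag, next_tag)

# inner[b, cur, nxt] = alpha[b, cur] + transitions[cur, nxt] + emit[b, nxt]
inner = (alpha.view(batch_size, num_tags, 1)
         + transitions.view(1, num_tags, num_tags)
         + emit.view(batch_size, 1, num_tags))

# Marginalize out the current tag in log space; the new alpha is (batch_size, num_tags).
new_alpha = torch.logsumexp(inner, dim=1)
print(new_alpha.shape)  # torch.Size([2, 3])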
    def decode(self,
               initial_state: State,
               transition_function: TransitionFunction,
               supervision: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, torch.Tensor]:
        targets, target_mask = supervision
        beam_search = ConstrainedBeamSearch(self._beam_size, targets, target_mask)
        finished_states: Dict[int, List[State]] = beam_search.search(initial_state, transition_function)

        loss = 0
        for instance_states in finished_states.values():
            scores = [state.score[0].view(-1) for state in instance_states]
            loss += -util.logsumexp(torch.cat(scores))
        return {'loss': loss / len(finished_states)}
    def decode(self,
               initial_state: DecoderState,
               decode_step: DecoderStep,
               supervision: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, torch.Tensor]:
        targets, target_mask = supervision

        # If self._beam_size is not set, we use a beam size that ensures we keep all of the
        # sequences.
        beam_size = self._beam_size or targets.size(1)
        beam_search = ConstrainedBeamSearch(beam_size, targets, target_mask)
        finished_states: Dict[int, List[DecoderState]] = beam_search.search(initial_state, decode_step)

        loss = 0
        for instance_states in finished_states.values():
            scores = [state.score[0].view(-1) for state in instance_states]
            loss += -util.logsumexp(torch.cat(scores))
        return {'loss': loss / len(finished_states)}
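Both `decode` methods compute, per instance, the negative logsumexp of the scores of all finished beam states, i.e. the negative marginal log-likelihood over the candidates the constrained beam search found, and then average over instances. A hedged toy sketch with made-up scores:

import torch

# Hypothetical log scores of the finished beam states for two instances.
finished_state_scores = {
    0: [torch.tensor([-1.2]), torch.tensor([-0.7])],
    1: [torch.tensor([-2.5])],
}

loss = 0.0
for scores in finished_state_scores.values():
    # Negative marginal log-likelihood over the candidates found by the beam.
    loss = loss + -torch.logsumexp(torch.cat(scores), dim=0)
loss = loss / len(finished_state_scores)
print(loss)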
Example #5
    def _merge_output_dicts(
        self, candidate_output_dicts: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        # TODO (pradeep): Also take the worlds and actions and reconstruct logical forms. The issue is that the
        # actions are bottom-up, and DomainLanguage instances can only handle top-down sequences.
        # TODO (pradeep): These losses are batch averaged. Is that a problem?
        # (max_num_inputs,)
        losses = torch.stack(
            [output["loss"] for output in candidate_output_dicts])
        # Losses are negative log-likelihoods. The final loss we need is the negative
        # log of the sum of all the likelihoods.
        output_dict = {"loss": -util.logsumexp(-losses)}
        if "class_log_probabilities" in candidate_output_dicts[0]:
            # This means we have a k-best list of sequences.
            # List of (batch_size, k) tensors.
            candidate_log_probabilities = [
                output["class_log_probabilities"]
                for output in candidate_output_dicts
            ]
            # (batch_size, max_num_inputs * k)
            log_probabilities = torch.cat(candidate_log_probabilities, dim=-1)
            # We now merge the predictions. One thing to worry about here is that the sequence lengths may not
            # necessarily be equal to the ``max_decoding_steps``. Given a batch of instances, the BeamSearch stops
            # searching if all instances have reached end states. So we need to do some padding here before
            # concatenating output candidates from different input sequences.
            padded_predictions: List[torch.Tensor] = []
            for candidate_output_dict in candidate_output_dicts:
                candidate_predictions = candidate_output_dict["predictions"]
                batch_size, num_sequences, sequence_length = candidate_predictions.size(
                )
                if sequence_length < self._max_decoding_steps:
                    padding_length = self._max_decoding_steps - sequence_length
                    padding = candidate_predictions.new_full(
                        (batch_size, num_sequences, padding_length),
                        self._end_index)
                    candidate_predictions = torch.cat(
                        [candidate_predictions, padding], dim=2)
                padded_predictions.append(candidate_predictions)
            # (batch_size, max_num_inputs * k, max_decoding_steps)
            predictions = torch.cat(padded_predictions, dim=1)
            sorted_log_probabilities, indices = torch.sort(log_probabilities,
                                                           descending=True)
            # (batch_size, max_num_inputs * k, max_decoding_steps)
            indices_for_selection = indices.unsqueeze(-1).repeat_interleave(
                self._max_decoding_steps, dim=2)
            sorted_predictions = predictions.gather(1, indices_for_selection)
            output_dict["class_log_probabilities"] = sorted_log_probabilities
            output_dict["predictions"] = sorted_predictions

            # Now we rank the action sequences according to the log probabilities of their best decoded sequences.
            # (batch_size, max_num_inputs)
            best_log_probabilities = torch.stack([
                log_probs[:, 0] for log_probs in candidate_log_probabilities
            ]).transpose(0, 1)
            # (batch_size, max_num_inputs)
            _, ranked_input_indices = torch.sort(best_log_probabilities,
                                                 1,
                                                 descending=True)
            int_ranked_input_indices = [[
                int(x) for x in instance_indices.data.cpu()
            ] for instance_indices in ranked_input_indices]
            output_dict[
                "sorted_logical_form_indices"] = int_ranked_input_indices
        return output_dict
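The key step is `-util.logsumexp(-losses)`: because each candidate loss is a negative log-likelihood, negating, applying logsumexp, and negating again gives the negative log of the sum of the candidate likelihoods. A quick numeric check (toy values, plain PyTorch):

import torch

# Hypothetical per-candidate negative log-likelihoods (one per input candidate).
losses = torch.tensor([1.0, 2.0, 3.0])

merged_loss = -torch.logsumexp(-losses, dim=0)   # -log(sum of the candidate likelihoods)
assert torch.allclose(merged_loss, -(-losses).exp().sum().log())
print(merged_loss)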
Example #6
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            numbers_in_passage: Dict[str, torch.LongTensor],
            number_indices: torch.LongTensor,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            answer_as_add_sub_expressions: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ, unused-argument

        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        embedded_passage = self._highway_layer(
            self._text_field_embedder(passage))
        batch_size = embedded_question.size(0)
        passage_length = embedded_passage.size(1)
        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = util.masked_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(
            dim=-1)[0].squeeze(-1)
        # Shape: (batch_size, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size, passage_length, encoding_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ],
                                         dim=-1)

        modeled_passage = self._dropout(
            self._modeling_layer(final_merged_passage, passage_lstm_mask))
        modeling_dim = modeled_passage.size(-1)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        span_start_input = self._dropout(
            torch.cat([final_merged_passage, modeled_passage], dim=-1))
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)
        # Shape: (batch_size, passage_length)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size, modeling_dim)
        span_start_representation = util.weighted_sum(modeled_passage,
                                                      span_start_probs)
        # Shape: (batch_size, passage_length, modeling_dim)
        tiled_start_representation = span_start_representation.unsqueeze(
            1).expand(batch_size, passage_length, modeling_dim)

        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim * 3)
        span_end_representation = torch.cat([
            final_merged_passage, modeled_passage, tiled_start_representation,
            modeled_passage * tiled_start_representation
        ],
                                            dim=-1)
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(span_end_representation, passage_lstm_mask))
        # Shape: (batch_size, passage_length, encoding_dim * 4 + span_end_encoding_dim)
        span_end_input = self._dropout(
            torch.cat([final_merged_passage, encoded_span_end], dim=-1))
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)

        # Shape: (batch_size, passage_length)
        passage_span_start_log_probs = util.masked_log_softmax(
            span_start_logits, passage_mask)
        passage_span_end_log_probs = util.masked_log_softmax(
            span_end_logits, passage_mask)

        passage_span_start_logits = util.replace_masked_values(
            span_start_logits, passage_mask, -1e7)
        passage_span_end_logits = util.replace_masked_values(
            span_end_logits, passage_mask, -1e7)
        # Shape: (batch_size, 2)
        best_passage_span = \
            BidirectionalAttentionFlow.get_best_span(passage_span_start_logits, passage_span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "passage_span_start_probs": passage_span_start_log_probs.exp(),
            "passage_span_end_probs": passage_span_end_log_probs.exp()
        }

        # If answer is given, compute the loss for training.
        if answer_as_passage_spans is not None:
            # Shape: (batch_size, # of answer spans)
            gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
            gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
            # Some spans are padded with index -1,
            # so we clamp those paddings to 0 and then mask after `torch.gather()`.
            gold_passage_span_mask = (gold_passage_span_starts != -1).long()
            clamped_gold_passage_span_starts = util.replace_masked_values(
                gold_passage_span_starts, gold_passage_span_mask, 0)
            clamped_gold_passage_span_ends = util.replace_masked_values(
                gold_passage_span_ends, gold_passage_span_mask, 0)
            # Shape: (batch_size, # of answer spans)
            log_likelihood_for_passage_span_starts = \
                torch.gather(passage_span_start_log_probs, 1, clamped_gold_passage_span_starts)
            log_likelihood_for_passage_span_ends = \
                torch.gather(passage_span_end_log_probs, 1, clamped_gold_passage_span_ends)
            # Shape: (batch_size, # of answer spans)
            log_likelihood_for_passage_spans = \
                log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
            # For those padded spans, we set their log probabilities to a very small negative value
            log_likelihood_for_passage_spans = \
                util.replace_masked_values(log_likelihood_for_passage_spans, gold_passage_span_mask, -1e32)
            # Shape: (batch_size, )
            log_marginal_likelihood_for_passage_span = util.logsumexp(
                log_likelihood_for_passage_spans)
            output_dict[
                "loss"] = -log_marginal_likelihood_for_passage_span.mean()

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                # We did not consider multi-mention answers here
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['passage_token_offsets']
                predicted_span = tuple(
                    best_passage_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_answer_str = passage_str[start_offset:end_offset]
                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(best_answer_str)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(best_answer_str, answer_annotations)
        return output_dict
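The span loss marginalizes over every annotated answer span: gather the start/end log-probabilities at the clamped gold indices, push padded spans to a very negative value, then logsumexp. A compact sketch with hypothetical shapes and gold spans:

import torch

batch_size, passage_length = 2, 6
span_start_log_probs = torch.log_softmax(torch.randn(batch_size, passage_length), dim=-1)
span_end_log_probs = torch.log_softmax(torch.randn(batch_size, passage_length), dim=-1)

# Hypothetical gold spans; -1 marks padding.
gold_starts = torch.tensor([[0, 2, -1], [1, -1, -1]])
gold_ends = torch.tensor([[1, 3, -1], [4, -1, -1]])

span_mask = (gold_starts != -1)
clamped_starts = gold_starts.clamp(min=0)
clamped_ends = gold_ends.clamp(min=0)

# (batch_size, num_gold_spans): joint log-likelihood of each gold span.
span_log_likelihood = (span_start_log_probs.gather(1, clamped_starts)
                       + span_end_log_probs.gather(1, clamped_ends))
# Padded spans get a very negative value so they contribute ~0 probability mass.
span_log_likelihood = span_log_likelihood.masked_fill(~span_mask, -1e32)

# (batch_size,): log of the total probability assigned to all gold spans.
marginal_log_likelihood = torch.logsumexp(span_log_likelihood, dim=-1)
loss = -marginal_log_likelihood.mean()
print(loss)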
Example #7
    def forward(
        self,  # type: ignore
        text: TextFieldTensors,
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        text : `TextFieldTensors`, required.
            The output of a `TextField` representing the text of
            the document.
        spans : `torch.IntTensor`, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
            indices into the text of the document.
        span_labels : `torch.IntTensor`, optional (default = None).
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : `List[Dict[str, Any]]`, optional (default = None).
            A metadata dictionary for each instance in the batch. We use the "original_text" and "clusters" keys
            from this dictionary, which respectively have the original text and the annotated gold coreference
            clusters for that instance.

        # Returns

        An output dictionary consisting of:
        top_spans : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : `torch.IntTensor`
            A tensor of shape `(num_spans_to_keep, max_antecedents)` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)).squeeze(-1)
        # Shape:
        #   (batch_size, num_spans_to_keep) * 3
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep)
        top_span_mention_scores = top_span_mention_scores.unsqueeze(-1)
        top_span_mask = top_span_mask.unsqueeze(-1)

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = util.batched_index_select(
            span_embeddings, top_span_indices, flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        (
            valid_antecedent_indices,
            valid_antecedent_offsets,
            valid_antecedent_log_mask,
        ) = self._generate_valid_antecedents(  # noqa
            num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings,
            top_span_mention_scores,
            candidate_antecedent_mention_scores,
            valid_antecedent_log_mask,
        )

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": predicted_antecedents,
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = (coreference_log_probs +
                                            gold_antecedent_labels.log())
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
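The coreference loss keeps, for each surviving span, only the antecedent choices consistent with the gold clustering and sums their probabilities in log space. A toy sketch (plain log-softmax instead of the masked version, hypothetical gold indicators):

import torch

num_spans_to_keep, num_antecedent_choices = 4, 3  # choice 0 plays the "no antecedent" role
coreference_scores = torch.randn(1, num_spans_to_keep, num_antecedent_choices)
coreference_log_probs = torch.log_softmax(coreference_scores, dim=-1)

# Hypothetical gold indicators: 1 where an antecedent choice is consistent with the
# gold clustering (including the dummy "no antecedent" column), 0 otherwise.
gold_antecedent_labels = torch.tensor([[[1., 0., 0.],
                                        [0., 1., 0.],
                                        [1., 0., 0.],
                                        [0., 1., 1.]]])

# log(0) = -inf removes inconsistent antecedents from the logsumexp.
correct_log_probs = coreference_log_probs + gold_antecedent_labels.log()
negative_marginal_log_likelihood = -torch.logsumexp(correct_log_probs, dim=-1).sum()
print(negative_marginal_log_likelihood)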
    def forward(
            self,  # type: ignore
            question_passage: Dict[str, torch.LongTensor],
            number_indices: torch.LongTensor,
            mask_indices: torch.LongTensor,
            #num_spans: torch.LongTensor = None,
            impossible_answer: torch.LongTensor = None,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            answer_as_expressions: torch.LongTensor = None,
            answer_as_expressions_extra: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # Shape: (batch_size, seqlen)
        question_passage_tokens = question_passage["tokens"]
        # Shape: (batch_size, seqlen)
        pad_mask = question_passage["mask"]
        # Shape: (batch_size, seqlen)
        seqlen_ids = question_passage["tokens-type-ids"]

        max_seqlen = question_passage_tokens.shape[-1]
        batch_size = question_passage_tokens.shape[0]

        # Shape: (batch_size, 3)
        mask = mask_indices.squeeze(-1)
        # Shape: (batch_size, seqlen)
        cls_sep_mask = torch.ones(pad_mask.shape, device=pad_mask.device).long().scatter(
            1, mask, torch.zeros(mask.shape, device=mask.device).long())
        # Shape: (batch_size, seqlen)
        passage_mask = seqlen_ids * pad_mask * cls_sep_mask
        # Shape: (batch_size, seqlen)
        question_mask = (1 - seqlen_ids) * pad_mask * cls_sep_mask

        # Shape: (batch_size, seqlen, bert_dim)
        bert_out, _ = self.BERT(question_passage_tokens,
                                seqlen_ids,
                                pad_mask,
                                output_all_encoded_layers=False)
        # Shape: (batch_size, qlen, bert_dim)
        question_end = max(mask[:, 1])
        question_out = bert_out[:, :question_end]
        # Shape: (batch_size, qlen)
        question_mask = question_mask[:, :question_end]
        # Shape: (batch_size, out)
        question_vector = self.summary_vector(question_out, question_mask,
                                              "question")

        passage_out = bert_out
        del bert_out

        # Shape: (batch_size, bert_dim)
        passage_vector = self.summary_vector(passage_out, passage_mask)

        if "arithmetic" in self.answering_abilities and self.arithmetic == "advanced":
            arithmetic_summary = self.summary_vector(passage_out, pad_mask,
                                                     "arithmetic")
            #             arithmetic_summary = self.summary_vector(question_out, question_mask, "arithmetic")

            # Shape: (batch_size, # of numbers in the passage)
            if number_indices.dim() == 3:
                number_indices = number_indices[:, :, 0].long()
            number_mask = (number_indices != -1).long()
            clamped_number_indices = util.replace_masked_values(
                number_indices, number_mask, 0)
            encoded_numbers = torch.gather(
                passage_out, 1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, passage_out.size(-1)))
            op_mask = torch.ones((batch_size, self.num_ops + 1),
                                 device=number_mask.device).long()
            options_mask = torch.cat([op_mask, number_mask], -1)
            ops = self.op_embeddings(
                torch.arange(self.num_ops + 1,
                             device=number_mask.device).expand(batch_size, -1))
            options = torch.cat([self.Wo(ops), self.Wc(encoded_numbers)], 1)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1))
            answer_ability_log_probs = torch.nn.functional.log_softmax(
                answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            count_number_log_probs, best_count_number = self._count_module(
                passage_vector)

        if "passage_span_extraction" in self.answering_abilities:
            passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span = \
                self._passage_span_module(passage_out, passage_mask)

        if "question_span_extraction" in self.answering_abilities:
            question_span_start_log_probs, question_span_end_log_probs, best_question_span = \
                self._question_span_module(passage_vector, question_out, question_mask)

        if "arithmetic" in self.answering_abilities:
            if self.arithmetic == "base":
                number_mask = (number_indices[:, :, 0].long() != -1).long()
                number_sign_log_probs, best_signs_for_numbers, number_mask = \
                    self._base_arithmetic_module(passage_vector, passage_out, number_indices, number_mask)
            else:
                arithmetic_logits, best_expression = \
                    self._adv_arithmetic_module(arithmetic_summary, self.max_explen, options, options_mask, \
                                                   passage_out, pad_mask)
                shapes = arithmetic_logits.shape
                # NaN check: (x != x) is True only for NaN entries. If any of the
                # logits are NaN, fall back to random logits so the loss stays finite.
                if (1 - (arithmetic_logits != arithmetic_logits)).sum() != (
                        shapes[0] * shapes[1] * shapes[2]):
                    print("bad logits")
                    arithmetic_logits = torch.rand(
                        shapes,
                        device=arithmetic_logits.device,
                        requires_grad=True)

        output_dict = {}
        if self.training:
            # If answer is given, compute the loss.
            if answer_as_passage_spans is not None or answer_as_question_spans is not None \
                    or answer_as_expressions is not None or answer_as_counts is not None:

                log_marginal_likelihood_list = []

                for answering_ability in self.answering_abilities:
                    if answering_ability == "passage_span_extraction":
                        log_marginal_likelihood_for_passage_span = \
                            self._passage_span_log_likelihood(answer_as_passage_spans,
                                                              passage_span_start_log_probs,
                                                              passage_span_end_log_probs)
                        log_marginal_likelihood_list.append(
                            log_marginal_likelihood_for_passage_span)

                    elif answering_ability == "question_span_extraction":
                        log_marginal_likelihood_for_question_span = \
                            self._question_span_log_likelihood(answer_as_question_spans,
                                                               question_span_start_log_probs,
                                                               question_span_end_log_probs)
                        log_marginal_likelihood_list.append(
                            log_marginal_likelihood_for_question_span)

                    elif answering_ability == "arithmetic":
                        if self.arithmetic == "base":
                            log_marginal_likelihood_for_arithmetic = \
                                self._base_arithmetic_log_likelihood(answer_as_expressions,
                                                                     number_sign_log_probs,
                                                                     number_mask,
                                                                     answer_as_expressions_extra, metadata)
                        else:
                            max_explen = answer_as_expressions.shape[-1]
                            possible_exps = answer_as_expressions.shape[1]
                            limit = min(possible_exps, 1000)
                            log_marginal_likelihood_for_arithmetic = \
                                self._adv_arithmetic_log_likelihood(arithmetic_logits[:,:max_explen,:],
                                                                    answer_as_expressions[:,:limit,:].long())
                        log_marginal_likelihood_list.append(
                            log_marginal_likelihood_for_arithmetic)

                    elif answering_ability == "counting":
                        log_marginal_likelihood_for_count = \
                            self._count_log_likelihood(answer_as_counts,
                                                       count_number_log_probs)
                        log_marginal_likelihood_list.append(
                            log_marginal_likelihood_for_count)
                    elif answering_ability == "answer_exists":
                        impossible_answer[impossible_answer == 0] = -1e7
                        impossible_answer[impossible_answer == 1] = 0
                        log_marginal_likelihood_list.append(
                            impossible_answer.type_as(passage_out))
                    else:
                        raise ValueError(
                            f"Unsupported answering ability: {answering_ability}"
                        )

                if len(self.answering_abilities) > 1:
                    # Add the ability probabilities if there is more than one ability
                    all_log_marginal_likelihoods = torch.stack(
                        log_marginal_likelihood_list, dim=-1)
                    all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
                    marginal_log_likelihood = util.logsumexp(
                        all_log_marginal_likelihoods)
                else:
                    marginal_log_likelihood = log_marginal_likelihood_list[0]

                output_dict["loss"] = -marginal_log_likelihood.mean()
        else:

            with torch.no_grad():
                # Compute the metrics and add the tokenized input to the output.
                if metadata is not None:
                    output_dict["question_id"] = []
                    output_dict["answer"] = []
                    for i in range(batch_size):
                        if len(self.answering_abilities) > 1:
                            predicted_ability_str = self.answering_abilities[
                                best_answer_ability[i]]
                        else:
                            predicted_ability_str = self.answering_abilities[0]
                        answer_json: Dict[str, Any] = {}

                        # We did not consider multi-mention answers here
                        if predicted_ability_str == "passage_span_extraction":
                            answer_json["answer_type"] = "passage_span"
                            answer_json["value"], answer_json["spans"] = \
                                self._span_prediction(question_passage_tokens[i], best_passage_span[i])
                        elif predicted_ability_str == "question_span_extraction":
                            answer_json["answer_type"] = "question_span"
                            answer_json["value"], answer_json["spans"] = \
                                self._span_prediction(question_passage_tokens[i], best_question_span[i])
                        elif predicted_ability_str == "arithmetic":  # plus_minus combination answer
                            answer_json["answer_type"] = "arithmetic"
                            original_numbers = metadata[i]['original_numbers']
                            if self.arithmetic == "base":
                                answer_json["value"], answer_json["numbers"] = \
                                    self._base_arithmetic_prediction(original_numbers, number_indices[i], best_signs_for_numbers[i])
                            else:
                                answer_json["value"], answer_json["expression"] = \
                                    self._adv_arithmetic_prediction(original_numbers, best_expression[i])
                        elif predicted_ability_str == "counting":
                            answer_json["answer_type"] = "count"
                            answer_json["value"], answer_json["count"] = \
                                self._count_prediction(best_count_number[i])
                        elif predicted_ability_str == "answer_exists":
                            answer_json["answer_type"] = "passage_span"
                            answer_json["value"] = "impossible"

                        output_dict["question_id"].append(
                            metadata[i]["question_id"])
                        output_dict["answer"].append(answer_json)

                        output_dict["prediction"] = answer_json["value"]

        return output_dict
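With several answering abilities, the loss is a mixture: each ability's log marginal likelihood is shifted by that ability's log selection probability and the results are combined with logsumexp. A minimal sketch with hypothetical values:

import torch

batch_size, num_abilities = 2, 3
# Hypothetical per-head log marginal likelihoods and head-selection logits.
log_marginal_likelihoods = torch.randn(batch_size, num_abilities)
answer_ability_logits = torch.randn(batch_size, num_abilities)
answer_ability_log_probs = torch.log_softmax(answer_ability_logits, dim=-1)

# log sum_h p(head = h) * p(answer | head = h), computed in log space.
marginal_log_likelihood = torch.logsumexp(
    log_marginal_likelihoods + answer_ability_log_probs, dim=-1)
loss = -marginal_log_likelihood.mean()
print(loss)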
Example #9
    def _gather_final_log_probs(
            self, generation_log_probs: torch.Tensor,
            copy_log_probs: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Combine copy probabilities with generation probabilities for matching tokens.

        Parameters
        ----------
        generation_log_probs : ``torch.Tensor``
            Shape: `(group_size, target_vocab_size)`
        copy_log_probs : ``torch.Tensor``
            Shape: `(group_size, trimmed_source_length)`
        state : ``Dict[str, torch.Tensor]``

        Returns
        -------
        torch.Tensor
            Shape: `(group_size, target_vocab_size + trimmed_source_length)`.
        """
        _, trimmed_source_length = state["source_to_target"].size()
        source_token_ids = state["source_token_ids"]

        # shape: [(batch_size, *)]
        modified_log_probs_list: List[torch.Tensor] = [generation_log_probs]
        for i in range(trimmed_source_length):
            # shape: (group_size,)
            copy_log_probs_slice = copy_log_probs[:, i]
            # `source_to_target` is a matrix of shape (group_size, trimmed_source_length)
            # where element (i, j) is the vocab index of the target token that matches the jth
            # source token in the ith group, if there is one, or the index of the OOV symbol otherwise.
            # We'll use this to add copy scores to corresponding generation scores.
            # shape: (group_size,)
            source_to_target_slice = state["source_to_target"][:, i]
            # The OOV index in the source_to_target_slice indicates that the source
            # token is not in the target vocab, so we don't want to add that copy score
            # to the OOV token.
            copy_log_probs_to_add_mask = (source_to_target_slice !=
                                          self._oov_index).float()
            copy_log_probs_to_add = copy_log_probs_slice + (
                copy_log_probs_to_add_mask + 1e-45).log()
            # shape: (batch_size, 1)
            copy_log_probs_to_add = copy_log_probs_to_add.unsqueeze(-1)
            # shape: (batch_size, 1)
            selected_generation_log_probs = generation_log_probs.gather(
                1, source_to_target_slice.unsqueeze(-1))
            combined_scores = util.logsumexp(
                torch.cat(
                    (selected_generation_log_probs, copy_log_probs_to_add),
                    dim=1))
            generation_log_probs.scatter_(-1,
                                          source_to_target_slice.unsqueeze(-1),
                                          combined_scores.unsqueeze(-1))
            # We have to combine copy scores for duplicate source tokens so that
            # we can find the overall most likely source token. So, if this is the first
            # occurrence of this particular source token, we add the log_probs from all other
            # occurrences, otherwise we zero it out since it was already accounted for.
            if i < (trimmed_source_length - 1):
                # Sum copy scores from future occurrences of the source token.
                # shape: (group_size, trimmed_source_length - i)
                source_future_occurences = (source_token_ids[:, (i + 1):] == source_token_ids[:, i].unsqueeze(-1)).float()  # pylint: disable=line-too-long
                # shape: (group_size, trimmed_source_length - i)
                future_copy_log_probs = copy_log_probs[:, (i + 1):] + (
                    source_future_occurences + 1e-45).log()
                # shape: (group_size, 1 + trimmed_source_length - i)
                combined = torch.cat((copy_log_probs_slice.unsqueeze(-1),
                                      future_copy_log_probs),
                                     dim=-1)
                # shape: (group_size,)
                copy_log_probs_slice = util.logsumexp(combined)
            if i > 0:
                # Remove copy log_probs that we have already accounted for.
                # shape: (group_size, i)
                source_previous_occurences = (
                    source_token_ids[:, 0:i] == source_token_ids[:, i].unsqueeze(-1))
                # shape: (group_size,)
                duplicate_mask = (source_previous_occurences.sum(
                    dim=-1) == 0).float()
                copy_log_probs_slice = copy_log_probs_slice + (duplicate_mask +
                                                               1e-45).log()

            # Finally, we zero-out copy scores that we added to the generation scores
            # above so that we don't double-count them.
            # shape: (group_size,)
            left_over_copy_log_probs = copy_log_probs_slice + (
                1.0 - copy_log_probs_to_add_mask + 1e-45).log()
            modified_log_probs_list.append(
                left_over_copy_log_probs.unsqueeze(-1))

        # shape: (group_size, target_vocab_size + trimmed_source_length)
        modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)

        return modified_log_probs
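For a source token that also appears in the target vocabulary, its copy probability is merged with the corresponding generation probability by a logsumexp of the two log-probabilities before being scattered back into the distribution. A single-token sketch with made-up numbers:

import torch

# Hypothetical log-probabilities for one group element.
generation_log_prob = torch.tensor([-2.3])  # P(token) from the vocabulary softmax
copy_log_prob = torch.tensor([-1.6])        # P(copy this source position)

# log(P_generate + P_copy), computed stably in log space.
combined = torch.logsumexp(
    torch.stack([generation_log_prob, copy_log_prob], dim=-1), dim=-1)
print(combined.exp(), generation_log_prob.exp() + copy_log_prob.exp())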
Example #10
    def decode(
        self, initial_state: State, transition_function: TransitionFunction,
        supervision: Tuple[torch.Tensor,
                           torch.Tensor]) -> Dict[str, torch.Tensor]:

        targets, target_mask = supervision
        # batch_size x inter_size x action_size x index_size (unused)
        assert len(targets.size()) == 4
        # -> batch_size * inter_size x action_size
        batch_size, inter_size, _, _ = targets.size()

        # TODO: we must keep the shape because of the loss_mask
        targets = targets.reshape(batch_size * inter_size, -1)

        target_mask = target_mask.reshape(batch_size * inter_size, -1)

        inter_mask = target_mask.sum(dim=1).ne(0)

        # unsqueeze the beam search dimension
        targets = targets.unsqueeze(dim=1)
        target_mask = target_mask.unsqueeze(dim=1)

        beam_search = ConstrainedBeamSearch(self._beam_size, targets,
                                            target_mask)
        finished_states: Dict[int, List[State]] = beam_search.search(
            initial_state, transition_function)

        inter_count = inter_mask.view(batch_size,
                                      inter_size).sum(dim=0).float()
        if 0 not in inter_count:
            inter_ratio = 1.0 / inter_count
        else:
            inter_ratio = torch.ones_like(inter_count)

        loss = 0

        for iter_ind, instance_states in finished_states.items():
            scores = [state.score[0].view(-1) for state in instance_states]
            lens = [len(state.action_history[0]) for state in instance_states]
            if not len(lens):
                continue
            # the i-th round of an interaction, starting from 0
            cur_inter = iter_ind % inter_size
            if self._re_weight:
                loss_coefficient = inter_ratio[cur_inter]
            else:
                loss_coefficient = 1.0

            if self._loss_mask <= cur_inter:
                continue

            cur_loss = -util.logsumexp(
                torch.cat(scores)) / statistics.mean(lens)
            loss += loss_coefficient * cur_loss

        if self._re_weight:
            return {'loss': loss / len(inter_count)}
        elif self._loss_mask < inter_size:
            valid_counts = inter_count[:self._loss_mask].sum()
            return {'loss': loss / valid_counts}
        else:
            return {'loss': loss / len(finished_states)}
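Here each instance's negative marginal log-likelihood is additionally normalized by the mean length of its finished action histories (and optionally re-weighted per interaction round). A hedged sketch of the per-instance term with made-up scores and lengths:

import statistics
import torch

# Hypothetical finished-state scores and action-history lengths for one instance.
scores = [torch.tensor([-3.1]), torch.tensor([-2.4]), torch.tensor([-4.0])]
lens = [5, 6, 5]

# Negative marginal log-likelihood, normalized by the mean candidate length.
cur_loss = -torch.logsumexp(torch.cat(scores), dim=0) / statistics.mean(lens)
print(cur_loss)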
    def forward(
            self,  # type: ignore
            question_passage_tokens: torch.LongTensor,
            question_passage_token_type_ids: torch.LongTensor,
            question_passage_special_tokens_mask: torch.LongTensor,
            question_passage_pad_mask: torch.LongTensor,
            first_wordpiece_mask: torch.LongTensor,
            metadata: List[Dict[str, Any]],
            wordpiece_indices: torch.LongTensor = None,
            number_indices: torch.LongTensor = None,
            answer_as_expressions: torch.LongTensor = None,
            answer_as_expressions_extra: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            answer_as_text_to_disjoint_bios: torch.LongTensor = None,
            answer_as_list_of_bios: torch.LongTensor = None,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            span_bio_labels: torch.LongTensor = None,
            is_bio_mask: torch.LongTensor = None) -> Dict[str, Any]:
        # pylint: disable=arguments-differ
        question_passage_special_tokens_mask = (
            1 - question_passage_special_tokens_mask)

        batch_size = question_passage_tokens.shape[0]
        head_count = len(self._heads)

        # TODO: (not important) Create a new field that is converted to Dict[str, torch.LongTensor]
        gold_answer_representations = {
            'answer_as_expressions': answer_as_expressions,
            'answer_as_expressions_extra': answer_as_expressions_extra,
            'answer_as_passage_spans': answer_as_passage_spans,
            'answer_as_question_spans': answer_as_question_spans,
            'answer_as_counts': answer_as_counts,
            'answer_as_text_to_disjoint_bios': answer_as_text_to_disjoint_bios,
            'answer_as_list_of_bios': answer_as_list_of_bios,
            'span_bio_labels': span_bio_labels
        }

        has_answer = False
        for answer_representation in gold_answer_representations.values():
            if answer_representation is not None:
                has_answer = True
                break

        # Shape: (batch_size, seqlen)
        passage_mask = question_passage_token_type_ids * question_passage_pad_mask * question_passage_special_tokens_mask
        # Shape: (batch_size, seqlen)
        question_mask = \
            (1 - question_passage_token_type_ids) * question_passage_pad_mask * question_passage_special_tokens_mask
        question_and_passage_mask = question_mask | passage_mask

        # Use pre-trained model to compute the representations of the input data
        # Shape: (batch_size, seqlen, bert_dim)
        token_type_ids = (question_passage_token_type_ids
                          if not self._pretrained_model.startswith('roberta-') else None)
        token_representations = self._transformers_model(
            question_passage_tokens,
            token_type_ids=token_type_ids,
            attention_mask=question_passage_pad_mask)[0]

        # if desired, compute the passage summary vector
        if self._passage_summary_vector_module is not None:
            # Shape: (batch_size, bert_dim)
            passage_summary_vector = self.summary_vector(
                token_representations, passage_mask, 'passage')
        else:
            passage_summary_vector = None

        # if desired, compute the question summary vector
        if self._question_summary_vector_module is not None:
            # Shape: (batch_size, bert_dim)
            question_summary_vector = self.summary_vector(
                token_representations, question_mask, 'question')
        else:
            question_summary_vector = None

        if head_count > 1:
            # use the head_predictor with the summary vectors
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._head_predictor(torch.cat([passage_summary_vector, question_summary_vector], -1))
            answer_ability_log_probs = torch.nn.functional.log_softmax(
                answer_ability_logits, -1)
            top_answer_abilities = torch.argsort(answer_ability_log_probs,
                                                 descending=True)
        else:
            top_answer_abilities = torch.zeros(batch_size, 1, dtype=torch.int)

        kwargs = {
            'token_representations': token_representations,
            'passage_summary_vector': passage_summary_vector,
            'question_summary_vector': question_summary_vector,
            'gold_answer_representations': gold_answer_representations,
            'question_and_passage_mask': question_and_passage_mask,
            'first_wordpiece_mask': first_wordpiece_mask,
            'is_bio_mask': is_bio_mask,
            'wordpiece_indices': wordpiece_indices,
            'number_indices': number_indices,
            'passage_mask': passage_mask,
            'question_mask': question_mask,
            'question_passage_special_tokens_mask': question_passage_special_tokens_mask
        }

        head_outputs = {}
        for head_name, head in self._heads.items():
            head_outputs[head_name] = head(**kwargs)

        output_dict = {}
        # If answer is given, compute the loss.
        if has_answer:
            log_marginal_likelihood_list = []
            for head_name, head in self._heads.items():
                # The marginal log likelihood is calculated for each head separately
                log_marginal_likelihood = head.gold_log_marginal_likelihood(
                    **kwargs, **head_outputs[head_name])
                """ log probability for each head to be selected is added
                (which is like AND/multiplication, but in logspace). """
                log_marginal_likelihood_list.append(log_marginal_likelihood)

            if head_count > 1:
                # Add the ability probabilities if there is more than one ability
                """ all the likelihoods are combined by summation 
                (this is like OR, as we want to maximize the probability that any of the heads is right, 
                and not that all of them are right).  """
                all_log_marginal_likelihoods = torch.stack(
                    log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
                marginal_log_likelihood = util.logsumexp(
                    all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]

            # Finally, compute the mean loss across the batch elements and put it in the output dictionary.
            output_dict['loss'] = -1 * marginal_log_likelihood.mean()

        with torch.no_grad():
            # Compute the metrics and add fields to the output
            if metadata is not None and self._training_evaluation:
                if not self.training:
                    output_dict['passage_id'] = []
                    output_dict['query_id'] = []
                    output_dict['answer'] = []
                    output_dict['predicted_ability'] = []
                    output_dict['maximizing_ground_truth'] = []
                    output_dict['em'] = []
                    output_dict['f1'] = []
                    output_dict['max_passage_length'] = []
                    if self._output_all_answers:
                        output_dict['all_answers'] = []

                i = 0
                no_fallback = False
                ordered_lookup_index = 0
                while i < batch_size:
                    predicting_head_index = top_answer_abilities[i][ordered_lookup_index].item()
                    predicting_head_name = self.heads_indices()[predicting_head_index]
                    predicting_head = self._heads[predicting_head_name]

                    # construct the arguments to be used for a batch instance prediction
                    instance_kwargs = {
                        'q_text': metadata[i]['original_question'],
                        'p_text': metadata[i]['original_passage'],
                        'qp_tokens': metadata[i]['question_passage_tokens'],
                        'question_passage_wordpieces': metadata[i]['question_passage_wordpieces'],
                        'original_numbers': metadata[i].get('original_numbers'),
                    }

                    # Keys that cannot be passed on, either because they are None or
                    # because they are not batched along their first dimension
                    unpassable_keys = ['gold_answer_representations']

                    for key, value in instance_kwargs.items():
                        if value is None:
                            unpassable_keys.append(key)
                    for key in unpassable_keys:
                        if key in instance_kwargs:
                            del instance_kwargs[key]

                    for key, value in kwargs.items():
                        if value is not None and key not in unpassable_keys:
                            instance_kwargs[key] = value[i]
                    for key, value in head_outputs[predicting_head_name].items():
                        if key not in unpassable_keys:
                            instance_kwargs[key] = value[i]

                    # get prediction for an instance in the batch
                    answer_json = predicting_head.decode_answer(
                        **instance_kwargs)

                    if len(answer_json['value']) != 0 or no_fallback:
                        # for the next in the batch
                        ordered_lookup_index = 0
                        no_fallback = False
                    else:
                        if not self.training:
                            logger.info(
                                "Answer was empty for head: %s, query_id: %s",
                                predicting_head_name,
                                metadata[i]['question_id'])
                        ordered_lookup_index += 1
                        if ordered_lookup_index == head_count:
                            no_fallback = True
                            ordered_lookup_index = 0
                        continue

                    maximizing_ground_truth = None
                    em, f1 = None, None
                    answer_annotations = metadata[i].get(
                        'answer_annotations', [])
                    if answer_annotations:
                        (em, f1), maximizing_ground_truth = self._metrics.call(
                            answer_json['value'], answer_annotations,
                            predicting_head_name)

                    if not self.training:
                        output_dict['passage_id'].append(
                            metadata[i]['passage_id'])
                        output_dict['query_id'].append(
                            metadata[i]['question_id'])
                        output_dict['answer'].append(answer_json)
                        output_dict['predicted_ability'].append(
                            predicting_head_name)
                        output_dict['maximizing_ground_truth'].append(
                            maximizing_ground_truth)
                        output_dict['em'].append(em)
                        output_dict['f1'].append(f1)
                        output_dict['max_passage_length'].append(
                            metadata[i]['max_passage_length'])

                        if self._output_all_answers:
                            answers_dict = {}
                            output_dict['all_answers'].append(answers_dict)
                            for j in range(len(self._heads)):
                                predicting_head_index = top_answer_abilities[i][j].item()
                                predicting_head_name = self.heads_indices()[predicting_head_index]
                                predicting_head = self._heads[predicting_head_name]

                                # construct the arguments to be used for a batch instance prediction
                                instance_kwargs = {
                                    'q_text': metadata[i]['original_question'],
                                    'p_text': metadata[i]['original_passage'],
                                    'qp_tokens': metadata[i]['question_passage_tokens'],
                                    'question_passage_wordpieces': metadata[i]['question_passage_wordpieces'],
                                    'original_numbers': metadata[i].get('original_numbers'),
                                }

                                # Keys that cannot be passed on, either because they are None or
                                # because they are not batched along their first dimension
                                unpassable_keys = ['gold_answer_representations']

                                for key, value in instance_kwargs.items():
                                    if value is None:
                                        unpassable_keys.append(key)
                                for key in unpassable_keys:
                                    if key in instance_kwargs:
                                        del instance_kwargs[key]

                                for key, value in kwargs.items():
                                    if value is not None and key not in unpassable_keys:
                                        instance_kwargs[key] = value[i]
                                for key, value in head_outputs[predicting_head_name].items():
                                    if key not in unpassable_keys:
                                        instance_kwargs[key] = value[i]

                                # get prediction for an instance in the batch
                                answer_json = predicting_head.decode_answer(
                                    **instance_kwargs)
                                answer_json['probability'] = torch.nn.functional.softmax(
                                    answer_ability_logits, -1)[i][predicting_head_index].item()
                                answers_dict[predicting_head_name] = answer_json

                    i += 1

        return output_dict
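As a reference for how the loss above is assembled, here is a minimal sketch (not taken from the example; it uses plain torch.logsumexp instead of util.logsumexp and made-up tensors) of combining per-head answer likelihoods with the head-selection distribution in log space:

import torch

def combined_log_likelihood(per_head_log_likelihoods: torch.Tensor,
                            head_selection_log_probs: torch.Tensor) -> torch.Tensor:
    # Adding log probabilities multiplies probabilities (AND); logsumexp over the
    # head dimension sums them (OR), giving log P(answer) under the mixture.
    return torch.logsumexp(per_head_log_likelihoods + head_selection_log_probs, dim=-1)

# Two heads: P(answer) = 0.7 * 0.9 + 0.3 * 0.1 = 0.66
per_head = torch.log(torch.tensor([[0.9, 0.1]]))
selector = torch.log(torch.tensor([[0.7, 0.3]]))
print(combined_log_likelihood(per_head, selector).exp())  # ~0.66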
Exemple #12
0
    def forward(
        self,  # type: ignore
        question_field: Dict[str, torch.LongTensor],
        visual_feat: torch.Tensor,
        pos: torch.Tensor,
        image_id: List[str],
        gold_question_attentions: torch.Tensor = None,
        identifier: List[str] = None,
        logical_form: List[str] = None,
        actions: List[List[ProductionRule]] = None,
        target_action_sequence: torch.LongTensor = None,
        gold_object_choices: torch.Tensor = None,
        denotation: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        batch_size, obj_num, feat_size = visual_feat.size()
        assert obj_num == 36 and feat_size == 2048
        text_masks = util.get_text_field_mask(question_field)
        (l_orig, v_orig, text, vis_only), x_orig = self._encoder(
            question_field[self._tokens_namespace], text_masks, visual_feat,
            pos)

        text_masks = text_masks.float()
        # NOTE: Taking the lxmert output before cross modality layer (which is the same for both images)
        # Can also try concatenating (dim=-1) the two encodings
        encoded_sentence = text

        initial_state = self._get_initial_state(encoded_sentence, text_masks,
                                                actions)
        initial_state.debug_info = [[] for _ in range(batch_size)]

        if target_action_sequence is not None:
            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            target_action_sequence = target_action_sequence.squeeze(-1)
            target_mask = target_action_sequence != self._action_padding_index
        else:
            target_mask = None

        outputs: Dict[str, torch.Tensor] = {}
        losses = []
        if (self.training or self._use_gold_program_for_eval
            ) and target_action_sequence is not None:
            # target_action_sequence is of shape (batch_size, 1, sequence_length) here after we
            # unsqueeze it for the MML trainer.
            search = ConstrainedBeamSearch(
                beam_size=None,
                allowed_sequences=target_action_sequence.unsqueeze(1),
                allowed_sequence_mask=target_mask.unsqueeze(1),
            )
            final_states = search.search(initial_state,
                                         self._transition_function)
            if self._training_batches_so_far < self._num_parse_only_batches:
                for batch_index in range(batch_size):
                    if not final_states[batch_index]:
                        logger.error(
                            f"No pogram found for batch index {batch_index}")
                        continue
                    losses.append(-final_states[batch_index][0].score[0])
        else:
            final_states = self._beam_search.search(
                self._max_decoding_steps,
                initial_state,
                self._transition_function,
                keep_final_unfinished_states=False,
            )

        action_mapping = {}
        for action_index, action in enumerate(actions[0]):
            action_mapping[action_index] = action[0]

        outputs: Dict[str, Any] = {"action_mapping": action_mapping}
        outputs["best_action_sequence"] = []
        outputs["debug_info"] = []

        if self._nmn_settings["mask_non_attention"]:
            zero_one_mult = torch.zeros_like(gold_question_attentions)
            zero_one_mult.copy_(gold_question_attentions)
            zero_one_mult[:, :, 0] = 1.0
            # sep_indices = text_masks.argmax(1).long()
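            # The position of the last non-padding token ([SEP]) per instance: multiplying
            # the mask by (1 + position) makes argmax return the largest unmasked index.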
            sep_indices = (
                (text_masks.long() *
                 (1 + torch.arange(text_masks.shape[1]).unsqueeze(0).repeat(
                     batch_size, 1).to(text_masks.device))).argmax(1).long())
            sep_indices = (sep_indices.unsqueeze(1).repeat(
                1, gold_question_attentions.shape[2]).unsqueeze(1).repeat(
                    1, gold_question_attentions.shape[1], 1))
            indices_dim2 = (torch.arange(
                gold_question_attentions.shape[2]).unsqueeze(0).repeat(
                    gold_question_attentions.shape[0],
                    gold_question_attentions.shape[1],
                    1,
                ).to(sep_indices.device).long())
            zero_one_mult = torch.where(
                sep_indices == indices_dim2,
                torch.ones_like(zero_one_mult),
                zero_one_mult,
            ).float()
            reshaped_questions = (
                question_field[self._tokens_namespace].unsqueeze(1).repeat(
                    1, gold_question_attentions.shape[1],
                    1).view(-1, gold_question_attentions.shape[-1]))
            reshaped_visual_feat = (visual_feat.unsqueeze(1).repeat(
                1, gold_question_attentions.shape[1], 1,
                1).view(-1, obj_num, visual_feat.shape[-1]))
            reshaped_pos = (pos.unsqueeze(1).repeat(
                1, gold_question_attentions.shape[1], 1,
                1).view(-1, obj_num, pos.shape[-1]))
            zero_one_mult = zero_one_mult.view(
                -1, gold_question_attentions.shape[-1])
            q_att_filter = zero_one_mult.sum(1) > 2
            (l_relevant, v_relevant, _, _), x_relevant = self._encoder(
                reshaped_questions[q_att_filter, :],
                zero_one_mult[q_att_filter, :],
                reshaped_visual_feat[q_att_filter, :, :],
                reshaped_pos[q_att_filter, :, :],
            )
            l = [{} for _ in range(batch_size)]
            v = [{} for _ in range(batch_size)]
            x = [{} for _ in range(batch_size)]
            count = 0
            batch_index = -1
            for i in range(zero_one_mult.shape[0]):
                module_num = i % target_action_sequence.shape[1]
                if module_num == 0:
                    batch_index += 1
                if q_att_filter[i].item():
                    l[batch_index][module_num] = l_relevant[count]
                    v[batch_index][module_num] = v_relevant[count]
                    x[batch_index][module_num] = x_relevant[count]
                    count += 1
        else:
            l = l_orig
            v = v_orig
            x = x_orig

        for batch_index in range(batch_size):
            if (self.training and self._training_batches_so_far <
                    self._num_parse_only_batches):
                continue
            if not final_states[batch_index]:
                logger.error(f"No pogram found for batch index {batch_index}")
                outputs["best_action_sequence"].append([])
                outputs["debug_info"].append([])
                continue
            world = VisualReasoningGqaLanguage(
                l[batch_index],
                v[batch_index],
                x[batch_index],
                # initial_state.rnn_state[batch_index].encoder_outputs[batch_index],
                self._language_parameters,
                pos[batch_index],
                nmn_settings=self._nmn_settings,
            )

            denotation_log_prob_list = []
            # TODO(mattg): maybe we want to limit the number of states we evaluate (programs we
            # execute) at test time, just for efficiency.
            for state_index, state in enumerate(final_states[batch_index]):
                action_indices = state.action_history[0]
                action_strings = [
                    action_mapping[action_index]
                    for action_index in action_indices
                ]
                # Shape: (num_denotations,)
                assert len(action_strings) == len(state.debug_info[0])
                # Plug in gold question attentions
                for i in range(len(state.debug_info[0])):
                    if gold_question_attentions[batch_index, i, :].sum() > 0:
                        state.debug_info[0][i]["question_attention"] = (
                            gold_question_attentions[batch_index,
                                                     i, :].float() /
                            gold_question_attentions[batch_index, i, :].sum())
                    elif self._nmn_settings["mask_non_attention"] and (
                            action_strings[i][-4:] == "find"
                            or action_strings[i][-6:] == "filter"
                            or action_strings[i][-13:] == "with_relation"):
                        state.debug_info[0][i]["question_attention"] = (
                            torch.ones_like(
                                gold_question_attentions[batch_index,
                                                         i, :]).float() /
                            gold_question_attentions[batch_index,
                                                     i, :].numel())
                        l[batch_index][i] = l_orig[batch_index]
                        v[batch_index][i] = v_orig[batch_index]
                        x[batch_index][i] = x_orig[batch_index]
                        world = VisualReasoningGqaLanguage(
                            l[batch_index],
                            v[batch_index],
                            x[batch_index],
                            # initial_state.rnn_state[batch_index].encoder_outputs[batch_index],
                            self._language_parameters,
                            pos[batch_index],
                            nmn_settings=self._nmn_settings,
                        )
                # print(action_strings)
                state_denotation_log_probs = world.execute_action_sequence(
                    action_strings, state.debug_info[0])
                # prob2 = world.execute_action_sequence(action_strings, state.debug_info[0])

                # P(denotation | parse) * P(parse | question)
                denotation_log_prob_list.append(state_denotation_log_probs)

                if not self._use_gold_program_for_eval:
                    denotation_log_prob_list[-1] += state.score[0]
                if state_index == 0:
                    outputs["best_action_sequence"].append(action_strings)
                    outputs["debug_info"].append(state.debug_info[0])
                    if target_action_sequence is not None:
                        targets = target_action_sequence[batch_index].data
                        program_correct = self._action_history_match(
                            action_indices, targets)
                        self._program_accuracy(program_correct)

            # P(denotation | parse) * P(parse | question) for the all programs on the beam.
            # Shape: (beam_size, num_denotations)
            denotation_log_probs = torch.stack(denotation_log_prob_list)
            # \Sum_parse P(denotation | parse) * P(parse | question) = P(denotation | question)
            # Shape: (num_denotations,)
            marginalized_denotation_log_probs = util.logsumexp(
                denotation_log_probs, dim=0)
            if denotation is not None:
                loss = (self.loss(
                    state_denotation_log_probs.unsqueeze(0),
                    denotation[batch_index].unsqueeze(0).float(),
                ).view(1) * self._denotation_loss_multiplier)
                losses.append(loss)
                self._denotation_accuracy(
                    torch.tensor([
                        1 - state_denotation_log_probs,
                        state_denotation_log_probs
                    ]).to(denotation.device),
                    denotation[batch_index],
                )
                if gold_object_choices is not None:
                    gold_objects = gold_object_choices[batch_index, :, :]
                    predicted_objects = torch.zeros_like(gold_objects)
                    for index in world.object_scores:
                        predicted_objects[index, :] = world.object_scores[index]
                    obj_exists = gold_objects.max(1)[0] > 0
                    # Only look at modules where at least one of the proposals has the object of interest
                    predicted_objects = predicted_objects[obj_exists, :]
                    gold_objects = gold_objects[obj_exists, :]
                    gold_objects = gold_objects.view(-1)
                    predicted_objects = predicted_objects.view(-1)
                    if gold_objects.numel() > 0:
                        loss += self._obj_loss_multiplier * self.loss(
                            predicted_objects, (gold_objects.float() + 1) / 2)
                        self._proposal_accuracy(
                            torch.cat(
                                (
                                    1.0 - predicted_objects.view(-1, 1),
                                    predicted_objects.view(-1, 1),
                                ),
                                dim=-1,
                            ),
                            (gold_objects + 1) / 2,
                        )
        if losses:
            outputs["loss"] = torch.stack(losses).mean()
        if self.training:
            self._training_batches_so_far += 1
        return outputs
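The marginalization in the snippet above can be reduced to a few lines; the sketch below uses hypothetical tensors and plain torch.logsumexp to show how P(denotation | question) is obtained by summing P(denotation | parse) * P(parse | question) over the beam in log space:

import torch

beam_size, num_denotations = 4, 2
# log P(denotation | parse) for each parse on the beam, and log P(parse | question).
log_p_denotation_given_parse = torch.randn(beam_size, num_denotations).log_softmax(dim=-1)
log_p_parse_given_question = torch.randn(beam_size).log_softmax(dim=-1)

# Joint log probability, then marginalize out the parse with logsumexp over the beam.
joint = log_p_denotation_given_parse + log_p_parse_given_question.unsqueeze(-1)
log_p_denotation = torch.logsumexp(joint, dim=0)  # Shape: (num_denotations,)
print(log_p_denotation.exp().sum())  # ~1.0: a proper distribution over denotations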
Exemple #13
0
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        if self._use_gold_mentions:
            device = text_embeddings.device

            s = [
                torch.as_tensor(pair, dtype=torch.long, device=device)
                for cluster in metadata[0]["clusters"] for pair in cluster
            ]
            gm = torch.stack(s, dim=0).unsqueeze(0).unsqueeze(1)
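            # `gm` stacks every gold mention (start, end) pair; below, a candidate span is
            # kept only if it matches one of these pairs exactly in both coordinates.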

            span_mask = (spans.unsqueeze(2) - gm)
            span_mask = (span_mask[:, :, :, 0] == 0) + (span_mask[:, :, :, 1] == 0)
            span_mask, _ = (span_mask == 2).max(-1)
            num_spans = span_mask.sum().item()
            span_mask = span_mask.float()
        else:
            span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
            num_spans = spans.size(1)
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": predicted_antecedents
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            coreference_log_probs = util.last_dim_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
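The coreference loss above marginalizes over all correct antecedents of each kept span. A small sketch with made-up scores (using torch.logsumexp rather than util.logsumexp):

import torch

# Scores over (dummy "no antecedent", antecedent_1, antecedent_2) for two spans.
coreference_scores = torch.tensor([[1.0, 2.0, 0.5],
                                   [0.3, 0.1, 0.2]])
coreference_log_probs = coreference_scores.log_softmax(dim=-1)

# Span 0 has one correct antecedent (antecedent_1); span 1 has none, so the
# dummy column is marked correct instead.
gold_antecedent_labels = torch.tensor([[0.0, 1.0, 0.0],
                                       [1.0, 0.0, 0.0]])

# log(0) = -inf drops incorrect antecedents; logsumexp sums the mass of the correct ones.
correct_log_probs = coreference_log_probs + gold_antecedent_labels.log()
negative_marginal_log_likelihood = -torch.logsumexp(correct_log_probs, dim=-1).sum()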
Exemple #14
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            question_and_passage: Dict[str, torch.LongTensor],
            answer_as_passage_spans: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ, unused-argument

        # logger.info("="*10)
        # logger.info([len(metadata[i]["passage_tokens"]) for i in range(len(metadata))])
        # logger.info([len(metadata[i]["question_tokens"]) for i in range(len(metadata))])
        # logger.info(question_and_passage["bert"].shape)

        # The segment labels should be as following:
        # <CLS> + question_word_pieces + <SEP> + passage_word_pieces + <SEP>
        # 0                0               0              1              1
        # We get this in a tricky way here
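        # Zero-padding the question ids to the full length and subtracting them leaves
        # nonzero values only on passage word pieces, so `> 0` yields the segment labels.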
        expanded_question_bert_tensor = torch.zeros_like(
            question_and_passage["bert"])
        expanded_question_bert_tensor[:, :question["bert"].
                                      shape[1]] = question["bert"]
        segment_labels = (question_and_passage["bert"] -
                          expanded_question_bert_tensor > 0).long()
        question_and_passage["segment_labels"] = segment_labels
        embedded_question_and_passage = self._text_field_embedder(
            question_and_passage)

        # We also get the passage mask for the concatenated question and passage in a similar way
        expanded_question_mask = torch.zeros_like(question_and_passage["mask"])
        # We shift the 1s to one column right here, to mask the [SEP] token in the middle
        expanded_question_mask[:, 1:question["mask"].shape[1] +
                               1] = question["mask"]
        expanded_question_mask[:, 0] = 1
        passage_mask = question_and_passage["mask"] - expanded_question_mask

        batch_size = embedded_question_and_passage.size(0)

        span_start_logits = self._span_start_predictor(
            embedded_question_and_passage).squeeze(-1)
        span_end_logits = self._span_end_predictor(
            embedded_question_and_passage).squeeze(-1)

        # Shape: (batch_size, passage_length)
        passage_span_start_log_probs = util.masked_log_softmax(
            span_start_logits, passage_mask)
        passage_span_end_log_probs = util.masked_log_softmax(
            span_end_logits, passage_mask)

        passage_span_start_logits = util.replace_masked_values(
            span_start_logits, passage_mask, -1e32)
        passage_span_end_logits = util.replace_masked_values(
            span_end_logits, passage_mask, -1e32)
        best_passage_span = get_best_span(passage_span_start_logits,
                                          passage_span_end_logits)

        output_dict = {
            "passage_span_start_probs": passage_span_start_log_probs.exp(),
            "passage_span_end_probs": passage_span_end_log_probs.exp()
        }

        # If answer is given, compute the loss for training.
        if answer_as_passage_spans is not None:
            # Shape: (batch_size, # of answer spans)
            gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
            gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
            # Some spans are padded with index -1,
            # so we clamp those paddings to 0 and then mask after `torch.gather()`.
            gold_passage_span_mask = (gold_passage_span_starts != -1).long()
            clamped_gold_passage_span_starts = util.replace_masked_values(
                gold_passage_span_starts, gold_passage_span_mask, 0)
            clamped_gold_passage_span_ends = util.replace_masked_values(
                gold_passage_span_ends, gold_passage_span_mask, 0)
            # Shape: (batch_size, # of answer spans)
            log_likelihood_for_passage_span_starts = \
                torch.gather(passage_span_start_log_probs, 1, clamped_gold_passage_span_starts)
            log_likelihood_for_passage_span_ends = \
                torch.gather(passage_span_end_log_probs, 1, clamped_gold_passage_span_ends)
            # Shape: (batch_size, # of answer spans)
            log_likelihood_for_passage_spans = \
                log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
            # For those padded spans, we set their log probabilities to be very small negative value
            log_likelihood_for_passage_spans = \
                util.replace_masked_values(log_likelihood_for_passage_spans, gold_passage_span_mask, -1e32)
            # Shape: (batch_size, )
            log_marginal_likelihood_for_passage_span = util.logsumexp(
                log_likelihood_for_passage_spans)
            output_dict["loss"] = -log_marginal_likelihood_for_passage_span.mean()

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                # We did not consider multi-mention answers here
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['passage_token_offsets']
                predicted_span = tuple(
                    best_passage_span[i].detach().cpu().numpy())
                # Remove the offsets of question tokens and the [SEP] token
                predicted_span = (predicted_span[0] -
                                  len(metadata[i]['question_tokens']) - 1,
                                  predicted_span[1] -
                                  len(metadata[i]['question_tokens']) - 1)
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_answer_str = passage_str[start_offset:end_offset]
                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(best_answer_str)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(best_answer_str, answer_annotations)
        return output_dict
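The span loss above marginalizes over several gold answer spans. A compact sketch with hypothetical values, using plain torch operations (masked_fill and clamp stand in for util.replace_masked_values):

import torch

start_log_probs = torch.randn(1, 10).log_softmax(dim=-1)  # (batch_size, passage_length)
end_log_probs = torch.randn(1, 10).log_softmax(dim=-1)
gold_starts = torch.tensor([[2, 5, -1]])                   # third span is padding
gold_ends = torch.tensor([[3, 7, -1]])

span_mask = gold_starts != -1
span_log_likelihood = (torch.gather(start_log_probs, 1, gold_starts.clamp(min=0)) +
                       torch.gather(end_log_probs, 1, gold_ends.clamp(min=0)))
# Padded spans get a very small log probability so they do not affect the logsumexp.
span_log_likelihood = span_log_likelihood.masked_fill(~span_mask, -1e32)
loss = -torch.logsumexp(span_log_likelihood, dim=-1).mean()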
Exemple #15
0
    def forward(
            self,  # type: ignore
            question_passage: Dict[str, torch.LongTensor],
            number_indices: torch.LongTensor,
            mask_indices: torch.LongTensor,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            answer_as_expressions: torch.LongTensor = None,
            answer_as_expressions_extra: torch.LongTensor = None,
            answer_as_unit_spans: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            answer_as_text_to_disjoint_bios: torch.LongTensor = None,
            answer_as_list_of_bios: torch.LongTensor = None,
            answer_as_yesno: torch.LongTensor = None,
            span_bio_labels: torch.LongTensor = None,
            bio_wordpiece_mask: torch.LongTensor = None,
            is_bio_mask: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # Shape: (batch_size, seqlen)
        question_passage_tokens = question_passage["tokens"]
        # Shape: (batch_size, seqlen)
        pad_mask = question_passage["mask"]
        # Shape: (batch_size, seqlen)
        seqlen_ids = question_passage["tokens-type-ids"]

        max_seqlen = question_passage_tokens.shape[-1]
        batch_size = question_passage_tokens.shape[0]

        # Shape: (batch_size, 3)
        mask = mask_indices.squeeze(-1)
        # Shape: (batch_size, seqlen)
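        # scatter() writes zeros at the [CLS]/[SEP] positions given by `mask`, so
        # cls_sep_mask is 1 everywhere except at the special tokens.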
        cls_sep_mask = \
            torch.ones(pad_mask.shape, device=pad_mask.device).long().scatter(
                1, mask, torch.zeros(mask.shape, device=mask.device).long())
        # Shape: (batch_size, seqlen)
        passage_mask = seqlen_ids * pad_mask * cls_sep_mask
        # Shape: (batch_size, seqlen)
        question_mask = (1 - seqlen_ids) * pad_mask * cls_sep_mask

        question_and_passage_mask = question_mask | passage_mask
        if bio_wordpiece_mask is None or not self.multispan_use_bio_wordpiece_mask:
            multispan_mask = question_and_passage_mask
        else:
            multispan_mask = question_and_passage_mask * bio_wordpiece_mask

        # Shape: (batch_size, seqlen, bert_dim)
        bert_out, _ = self.BERT(question_passage_tokens, seqlen_ids, pad_mask)
        # Shape: (batch_size, qlen, bert_dim)
        question_end = max(mask[:, 1])
        question_out = bert_out[:, :question_end]
        # Shape: (batch_size, qlen)
        question_mask = question_mask[:, :question_end]
        # Shape: (batch_size, bert_dim)
        question_vector = self.summary_vector(question_out, question_mask,
                                              "question")

        passage_out = bert_out
        del bert_out

        # Shape: (batch_size, bert_dim)
        passage_vector = self.summary_vector(passage_out, passage_mask)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1))
            answer_ability_log_probs = torch.nn.functional.log_softmax(
                answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)
            top_two_answer_abilities = torch.topk(answer_ability_log_probs,
                                                  k=2,
                                                  dim=1)

        if "counting" in self.answering_abilities:
            count_number_log_probs, best_count_number = self._count_module(
                passage_vector)

        if "passage_span_extraction" in self.answering_abilities:
            passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span = \
                self._passage_span_module(passage_out, passage_mask)

        if "question_span_extraction" in self.answering_abilities:
            question_span_start_log_probs, question_span_end_log_probs, best_question_span = \
                self._question_span_module(passage_vector, question_out, question_mask)

        if "arithmetic" in self.answering_abilities or "counting" in self.answering_abilities:
            unit_span_start_log_probs, unit_span_end_log_probs, best_unit_span = \
                self._unit_span_module(passage_vector, question_out, question_mask)

        if "multiple_spans" in self.answering_abilities:
            if self.multispan_head_name == "flexible_loss":
                multispan_log_probs, multispan_logits = self._multispan_module(
                    passage_out, seq_mask=multispan_mask)
            else:
                multispan_log_probs, multispan_logits = self._multispan_module(
                    passage_out)

        if "arithmetic" in self.answering_abilities:
            number_mask = (number_indices[:, :, 0].long() != -1).long()
            number_sign_log_probs, best_signs_for_numbers, number_mask = \
                self._base_arithmetic_module(passage_vector, passage_out, number_indices, number_mask)

        if "yesno" in self.answering_abilities:
            yesno_log_probs, best_yesno = self._yesno_module(passage_vector)

        output_dict = {}
        del passage_out, question_out
        # If answer is given, compute the loss.
        if answer_as_passage_spans is not None or answer_as_question_spans is not None \
                or answer_as_expressions is not None or answer_as_counts is not None \
                or answer_as_yesno is not None or span_bio_labels is not None:

            log_marginal_likelihood_list = []

            ###
            log_marginal_likelihood_for_unit_span = \
                    self._question_span_log_likelihood(answer_as_unit_spans,
                                                unit_span_start_log_probs,
                                                unit_span_end_log_probs)
            ###

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    log_marginal_likelihood_for_passage_span = \
                        self._passage_span_log_likelihood(answer_as_passage_spans,
                                                          passage_span_start_log_probs,
                                                          passage_span_end_log_probs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_passage_span)

                elif answering_ability == "question_span_extraction":
                    log_marginal_likelihood_for_question_span = \
                        self._question_span_log_likelihood(answer_as_question_spans,
                                                           question_span_start_log_probs,
                                                           question_span_end_log_probs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_question_span)

                elif answering_ability == "arithmetic":
                    log_marginal_likelihood_for_arithmetic = \
                        self._base_arithmetic_log_likelihood(answer_as_expressions,
                                                                number_sign_log_probs,
                                                                number_mask,
                                                                answer_as_expressions_extra)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_arithmetic +
                        log_marginal_likelihood_for_unit_span * 0.5)

                elif answering_ability == "counting":
                    log_marginal_likelihood_for_count = \
                        self._count_log_likelihood(answer_as_counts,
                                                   count_number_log_probs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_count +
                        log_marginal_likelihood_for_unit_span * 0.5)

                elif answering_ability == "multiple_spans":
                    if self.multispan_head_name == "flexible_loss":
                        log_marginal_likelihood_for_multispan = \
                            self._multispan_log_likelihood(answer_as_text_to_disjoint_bios,
                                                        answer_as_list_of_bios,
                                                        span_bio_labels,
                                                        multispan_log_probs,
                                                        multispan_logits,
                                                        multispan_mask,
                                                        bio_wordpiece_mask,
                                                        is_bio_mask)
                    else:
                        log_marginal_likelihood_for_multispan = \
                            self._multispan_log_likelihood(span_bio_labels,
                                                        multispan_log_probs,
                                                        multispan_mask,
                                                        is_bio_mask,
                                                        logits=multispan_logits)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_multispan)
                elif answering_ability == "yesno":
                    log_marginal_likelihood_for_yesno = \
                        self._yesno_log_likelihood(answer_as_yesno,
                                                   yesno_log_probs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_yesno)
                else:
                    raise ValueError(
                        f"Unsupported answering ability: {answering_ability}")

            if len(self.answering_abilities) > 1:
                # import pdb; pdb.set_trace()
                # Add the ability probabilities if there is more than one ability
                all_log_marginal_likelihoods = torch.stack(
                    log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
                marginal_log_likelihood = util.logsumexp(
                    all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]

            output_dict["loss"] = -marginal_log_likelihood.mean()
        with torch.no_grad():
            # Compute the metrics and add the tokenized input to the output.
            if metadata is not None:
                if not self.training:
                    output_dict["passage_id"] = []
                    output_dict["query_id"] = []
                    output_dict["answer"] = []
                    output_dict["predicted_ability"] = []
                    output_dict["maximizing_ground_truth"] = []
                    output_dict["em"] = []
                    output_dict["f1"] = []
                    output_dict["invalid_spans"] = []
                    output_dict["max_passage_length"] = []

                i = 0
                while i < batch_size:
                    if len(self.answering_abilities) > 1:
                        predicted_ability_str = self.answering_abilities[
                            best_answer_ability[i]]
                    else:
                        predicted_ability_str = self.answering_abilities[0]

                    answer_json: Dict[str, Any] = {}

                    invalid_spans = []

                    q_text = metadata[i]['original_question']
                    p_text = metadata[i]['original_passage']
                    qp_tokens = metadata[i]['question_passage_tokens']

                    ###

                    if predicted_ability_str == "passage_span_extraction":
                        answer_json["answer_type"] = "passage_span"
                        answer_json["value"], answer_json["spans"] = \
                            self._span_prediction(qp_tokens, best_passage_span[i], p_text, q_text, 'p')
                    elif predicted_ability_str == "question_span_extraction":
                        answer_json["answer_type"] = "question_span"
                        answer_json["value"], answer_json["spans"] = \
                            self._span_prediction(qp_tokens, best_question_span[i], p_text, q_text, 'q')
                        # import pdb; pdb.set_trace()
                    elif predicted_ability_str == "arithmetic":  # plus_minus combination answer
                        answer_json["answer_type"] = "arithmetic"
                        original_numbers = metadata[i]['original_numbers']
                        answer_json["value"], answer_json["numbers"] = \
                            self._base_arithmetic_prediction(original_numbers, number_indices[i], best_signs_for_numbers[i])
                    elif predicted_ability_str == "counting":
                        answer_json["answer_type"] = "count"
                        answer_json["value"], answer_json["count"] = \
                            self._count_prediction(best_count_number[i])
                    elif predicted_ability_str == "multiple_spans":
                        answer_json["answer_type"] = "multiple_spans"
                        if self.multispan_head_name == "flexible_loss":
                            answer_json["value"], answer_json["spans"], invalid_spans = \
                                self._multispan_prediction(multispan_log_probs[i], multispan_logits[i], qp_tokens, p_text, q_text,
                                                        multispan_mask[i], bio_wordpiece_mask[i], self.multispan_use_prediction_beam_search and not self.training)
                        else:
                            answer_json["value"], answer_json["spans"], invalid_spans = \
                                self._multispan_prediction(multispan_log_probs[i], multispan_logits[i], qp_tokens, p_text, q_text,
                                                        multispan_mask[i])
                        if self._unique_on_multispan:
                            answer_json["value"] = list(
                                OrderedDict.fromkeys(answer_json["value"]))

                            if self._dont_add_substrings_to_ms:
                                answer_json["value"] = remove_substring_from_prediction(
                                    answer_json["value"])

                        if len(answer_json["value"]) == 0:
                            best_answer_ability[i] = top_two_answer_abilities.indices[i][1]
                            continue
                    elif predicted_ability_str == "yesno":
                        answer_json["answer_type"] = "yesno"
                        answer_json["value"], answer_json["yesno"] = \
                            self._yesno_prediction(best_yesno[i])
                    else:
                        raise ValueError(
                            f"Unsupported answer ability: {predicted_ability_str}"
                        )

                    if predicted_ability_str == "counting" or predicted_ability_str == "arithmetic":
                        answer_json["unit_value"], answer_json["unit_spans"] = \
                            self._span_prediction(qp_tokens, best_unit_span[i], p_text, q_text, 'q')
                        answer_json["value"] = answer_json[
                            "value"] + answer_json["unit_value"]

                    maximizing_ground_truth = None
                    em, f1 = None, None
                    answer_annotations = metadata[i].get(
                        'answer_annotations', [])

                    if answer_annotations:
                        annotations_with_units = [
                            {key: (answer_annotation[key] + answer_annotation['unit']
                                   if key == 'number' else answer_annotation[key])
                             for key in answer_annotation}
                            for answer_annotation in answer_annotations
                        ]
                        (em, f1), maximizing_ground_truth = self._drop_metrics.call(
                            answer_json["value"], annotations_with_units,
                            predicted_ability_str)

                    if not self.training:
                        output_dict["passage_id"].append(
                            metadata[i]["passage_id"])
                        output_dict["query_id"].append(
                            metadata[i]["question_id"])
                        output_dict["answer"].append(answer_json)
                        output_dict["predicted_ability"].append(
                            predicted_ability_str)
                        output_dict["maximizing_ground_truth"].append(
                            maximizing_ground_truth)
                        output_dict["em"].append(em)
                        output_dict["f1"].append(f1)
                        output_dict["invalid_spans"].append(invalid_spans)
                        output_dict["max_passage_length"].append(
                            metadata[i]["max_passage_length"])

                    i += 1

        return output_dict
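One detail worth isolating from the prediction loop above is the fallback when the best-scoring ability yields an empty answer; a sketch with made-up tensors:

import torch

answer_ability_log_probs = torch.log_softmax(torch.randn(2, 4), dim=-1)
best_answer_ability = torch.argmax(answer_ability_log_probs, dim=1)
top_two_answer_abilities = torch.topk(answer_ability_log_probs, k=2, dim=1)

i = 0
multispan_prediction_is_empty = True  # e.g. the multiple_spans head produced no spans
if multispan_prediction_is_empty:
    # Fall back to the second-best ability for this instance.
    best_answer_ability[i] = top_two_answer_abilities.indices[i][1]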
Exemple #16
0
    def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        doc_span_offsets: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        doc_truth_spans: torch.IntTensor = None,
        doc_spans_in_truth: torch.IntTensor = None,
        doc_relation_labels: torch.Tensor = None,
        truth_spans: List[Set[Tuple[int, int]]] = None,
        # doc_relations = None,
        doc_ner_labels: torch.IntTensor = None,
        **metadata: Dict[str, List[Any]]
    ) -> Dict[str, torch.Tensor]:  # add matrix from datareader
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : ``Dict[str, List[Any]]``, required.
            Metadata for the batch (e.g. ``doc_tokens``), used for decoding and
            evaluation.
        doc_ner_labels : ``torch.IntTensor``.
            A tensor of shape # TODO,
            ...
        doc_span_offsets : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
            ...
        doc_truth_spans : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_truth_spans, 1),
            ...
        doc_spans_in_truth : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
            ...
        doc_relation_labels : ``torch.Tensor``.
            A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans),
            ...

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        batch_size = len(spans)
        document_length = text_embeddings.size(1)
        max_sentence_length = max(
            len(sentence_tokens) for doc_tokens in metadata['doc_tokens']
            for sentence_tokens in doc_tokens)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # TODO features dropout
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_relex_spans_to_keep = int(
            math.floor(self._relex_spans_per_word * max_sentence_length))

        # Shapes:
        # (batch_size, num_spans_to_keep, span_dim),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # Shape: (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = dict()

        # Store raw text and tokens for decoding step
        output_dict["flat_tokens"] = metadata["flat_tokens"]
        output_dict["flat_text"] = metadata["flat_text"]

        output_dict["top_spans"] = top_spans
        output_dict["antecedent_indices"] = valid_antecedent_indices
        output_dict["predicted_antecedents"] = predicted_antecedents

        if metadata is not None:
            output_dict["document"] = metadata["original_text"]

        # Accumulated scalar loss (kept as a plain Python int until a loss term is added).
        loss = 0

        # Shape: (batch_size, max_sentences, max_spans)
        doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
        # Shape: (batch_size, max_sentences, num_spans, span_dim)
        doc_span_embeddings = util.batched_index_select(
            span_embeddings,
            doc_span_offsets.squeeze(-1).long().clamp(min=0))

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        pruned = self._relex_mention_pruner(
            doc_span_embeddings,
            doc_span_mask,
            num_items_to_keep=num_relex_spans_to_keep,
            pass_through=['num_items_to_keep'])
        (top_relex_span_embeddings, top_relex_span_mask,
         top_relex_span_indices, top_relex_span_mention_scores) = pruned

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

        # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)  # TODO: do we need a mask here?
        doc_spans = util.batched_index_select(
            spans,
            doc_span_offsets.clamp(0).squeeze(-1))

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
        top_relex_spans = nd_batched_index_select(doc_spans,
                                                  top_relex_span_indices)

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
        (relex_span_pair_embeddings,
         relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
             top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
        relex_scores = self._compute_relex_scores(
            relex_span_pair_embeddings, top_relex_span_mention_scores)
        output_dict['relex_scores'] = relex_scores
        output_dict['top_relex_spans'] = top_relex_spans

        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)
            antecedent_labels_ = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long(
            )

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log(
            )
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs)
            negative_marginal_log_likelihood *= top_span_mask.squeeze(
                -1).float()
            negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum(
            )

            # TODO Modify metadata format
            # self._mention_recall(top_spans, metadata)
            # self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            coref_loss = negative_marginal_log_likelihood
            output_dict['coref_loss'] = coref_loss
            loss += self._loss_coref_weight * coref_loss

        if doc_relation_labels is not None:

            # The adjacency matrix for relation extraction is very sparse.
            # It is not just sparse but row/column-sparse: only a few rows and
            # columns are non-zero, and those rows/columns are themselves dense.
            # We therefore use our own compressed representation for this case.
            # Below, the indices of the gold (truth) spans and a mapping between
            # them are used to project the prediction matrix onto the gold matrix.
            # TODO Add teacher forcing support.

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep),
            relative_indices = top_relex_span_indices
            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1),
            compressed_indices = nd_batched_padded_index_select(
                doc_spans_in_truth, relative_indices)

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
            gold_pruned_rows = nd_batched_padded_index_select(
                doc_relation_labels,
                compressed_indices.squeeze(-1),
                padding_value=0)
            gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3,
                                                        2).contiguous()

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
            gold_pruned_matrices = nd_batched_padded_index_select(
                gold_pruned_rows,
                compressed_indices.squeeze(-1),
                padding_value=0)  # pad with epsilon
            gold_pruned_matrices = gold_pruned_matrices.permute(
                0, 1, 3, 2).contiguous()

            # TODO log_mask relex score before passing
            relex_loss = nd_cross_entropy_with_logits(relex_scores,
                                                      gold_pruned_matrices,
                                                      relex_span_pair_mask)
            output_dict['relex_loss'] = relex_loss

            self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2),
                                       truth_spans)

            # To calculate the F1 score, we need to call the decode step
            output_dict = self.decode(output_dict)
            self._compute_relex_metrics(output_dict['raw_interactions'],
                                        metadata['doc_raw_relations'])

            loss += self._loss_relex_weight * relex_loss

        if doc_ner_labels is not None:
            # Shape: (batch_size, max_sentences, num_spans, num_ner_classes)
            ner_scores = self._ner_scorer(doc_span_embeddings)
            output_dict['ner_scores'] = ner_scores

            ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels,
                                                    doc_span_mask)
            output_dict['ner_loss'] = ner_loss
            loss += self._loss_ner_weight * ner_loss

        if not isinstance(loss, int):  # Only set a loss if at least one loss term was added
            output_dict["loss"] = loss

        return output_dict
Example #17
0
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            number_indices: torch.LongTensor,
            answer_as_passage_spans: torch.LongTensor = None,
            answer_as_question_spans: torch.LongTensor = None,
            answer_as_add_sub_expressions: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        question_mask = util.get_text_field_mask(question).float()
        passage_mask = util.get_text_field_mask(passage).float()
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(
            self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(
            self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(
            embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(
            embedded_passage)

        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)

        # Shape: (batch_size, passage_length, passage_length)
        passage_attention_over_attention = torch.bmm(
            passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(
            encoded_passage, passage_attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat([
                encoded_passage, passage_question_vectors,
                encoded_passage * passage_question_vectors,
                encoded_passage * passage_passage_vectors
            ],
                      dim=-1))

        # The recurrent modeling layers. Since these layers share the same parameters,
        # we don't construct them conditioned on answering abilities.
        modeled_passage_list = [
            self._modeling_proj_layer(merged_passage_attention_vectors)
        ]
        for _ in range(4):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)
        # Pop the first one, which is input
        modeled_passage_list.pop(0)

        # The first modeling layer is used to calculate the vector representation of passage
        passage_weights = self._passage_weights_predictor(
            modeled_passage_list[0]).squeeze(-1)
        passage_weights = masked_softmax(passage_weights, passage_mask)
        passage_vector = util.weighted_sum(modeled_passage_list[0],
                                           passage_weights)
        # The vector representation of the question is calculated based on the unmatched encoding,
        # because we may want to infer the answer ability only based on the question words.
        question_weights = self._question_weights_predictor(
            encoded_question).squeeze(-1)
        question_weights = masked_softmax(question_weights, question_mask)
        question_vector = util.weighted_sum(encoded_question, question_weights)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1))
            answer_ability_log_probs = torch.nn.functional.log_softmax(
                answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            # Shape: (batch_size, 10)
            count_number_logits = self._count_number_predictor(passage_vector)
            count_number_log_probs = torch.nn.functional.log_softmax(
                count_number_logits, -1)
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1)
            best_count_log_prob = \
                torch.gather(count_number_log_probs, 1, best_count_number.unsqueeze(-1)).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[:, self.
                                                                _counting_index]

        if "passage_span_extraction" in self.answering_abilities:
            # Shape: (batch_size, passage_length, modeling_dim * 2)
            passage_for_span_start = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[1]], dim=-1)
            # Shape: (batch_size, passage_length)
            passage_span_start_logits = self._passage_span_start_predictor(
                passage_for_span_start).squeeze(-1)
            # Shape: (batch_size, passage_length, modeling_dim * 2)
            passage_for_span_end = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[2]], dim=-1)
            # Shape: (batch_size, passage_length)
            passage_span_end_logits = self._passage_span_end_predictor(
                passage_for_span_end).squeeze(-1)
            # Shape: (batch_size, passage_length)
            passage_span_start_log_probs = util.masked_log_softmax(
                passage_span_start_logits, passage_mask)
            passage_span_end_log_probs = util.masked_log_softmax(
                passage_span_end_logits, passage_mask)

            # Info about the best passage span prediction
            passage_span_start_logits = util.replace_masked_values(
                passage_span_start_logits, passage_mask, -1e7)
            passage_span_end_logits = util.replace_masked_values(
                passage_span_end_logits, passage_mask, -1e7)
            # Shape: (batch_size, 2)
            best_passage_span = get_best_span(passage_span_start_logits,
                                              passage_span_end_logits)
            # Shape: (batch_size,)
            best_passage_start_log_probs = \
                torch.gather(passage_span_start_log_probs, 1, best_passage_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_passage_end_log_probs = \
                torch.gather(passage_span_end_log_probs, 1, best_passage_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_passage_span_log_prob = best_passage_start_log_probs + best_passage_end_log_probs
            if len(self.answering_abilities) > 1:
                best_passage_span_log_prob += answer_ability_log_probs[:, self.
                                                                       _passage_span_extraction_index]

        if "question_span_extraction" in self.answering_abilities:
            # Shape: (batch_size, question_length)
            encoded_question_for_span_prediction = \
                torch.cat([encoded_question,
                           passage_vector.unsqueeze(1).repeat(1, encoded_question.size(1), 1)], -1)
            question_span_start_logits = \
                self._question_span_start_predictor(encoded_question_for_span_prediction).squeeze(-1)
            # Shape: (batch_size, question_length)
            question_span_end_logits = \
                self._question_span_end_predictor(encoded_question_for_span_prediction).squeeze(-1)
            question_span_start_log_probs = util.masked_log_softmax(
                question_span_start_logits, question_mask)
            question_span_end_log_probs = util.masked_log_softmax(
                question_span_end_logits, question_mask)

            # Info about the best question span prediction
            question_span_start_logits = \
                util.replace_masked_values(question_span_start_logits, question_mask, -1e7)
            question_span_end_logits = \
                util.replace_masked_values(question_span_end_logits, question_mask, -1e7)
            # Shape: (batch_size, 2)
            best_question_span = get_best_span(question_span_start_logits,
                                               question_span_end_logits)
            # Shape: (batch_size,)
            best_question_start_log_probs = \
                torch.gather(question_span_start_log_probs, 1, best_question_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_question_end_log_probs = \
                torch.gather(question_span_end_log_probs, 1, best_question_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_question_span_log_prob = best_question_start_log_probs + best_question_end_log_probs
            if len(self.answering_abilities) > 1:
                best_question_span_log_prob += answer_ability_log_probs[:,
                                                                        self.
                                                                        _question_span_extraction_index]

        if "addition_subtraction" in self.answering_abilities:
            # Shape: (batch_size, # of numbers in the passage)
            number_indices = number_indices.squeeze(-1)
            number_mask = (number_indices != -1).long()
            clamped_number_indices = util.replace_masked_values(
                number_indices, number_mask, 0)
            encoded_passage_for_numbers = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[3]], dim=-1)
            # Shape: (batch_size, # of numbers in the passage, encoding_dim)
            encoded_numbers = torch.gather(
                encoded_passage_for_numbers, 1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, encoded_passage_for_numbers.size(-1)))
            # Shape: (batch_size, # of numbers in the passage, modeling_dim * 3)
            encoded_numbers = torch.cat([
                encoded_numbers,
                passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1),
                                                   1)
            ], -1)

            # Shape: (batch_size, # of numbers in the passage, 3)
            number_sign_logits = self._number_sign_predictor(encoded_numbers)
            number_sign_log_probs = torch.nn.functional.log_softmax(
                number_sign_logits, -1)

            # Shape: (batch_size, # of numbers in passage).
            best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
            # For padding numbers, the best sign is masked to 0 (i.e. not included).
            best_signs_for_numbers = util.replace_masked_values(
                best_signs_for_numbers, number_mask, 0)
            # Shape: (batch_size, # of numbers in passage)
            best_signs_log_probs = torch.gather(
                number_sign_log_probs, 2,
                best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
            # The probs of the masked positions should be 1 so that they do not affect the joint probability.
            # TODO: this is not quite right, since if there are many numbers in the passage,
            # TODO: the joint probability would be very small.
            best_signs_log_probs = util.replace_masked_values(
                best_signs_log_probs, number_mask, 0)
            # Shape: (batch_size,)
            best_combination_log_prob = best_signs_log_probs.sum(-1)
            if len(self.answering_abilities) > 1:
                best_combination_log_prob += answer_ability_log_probs[:, self.
                                                                      _addition_subtraction_index]

        output_dict = {}

        # If answer is given, compute the loss.
        if answer_as_passage_spans is not None or answer_as_question_spans is not None \
                or answer_as_add_sub_expressions is not None or answer_as_counts is not None:

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                    gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_passage_span_mask = (gold_passage_span_starts !=
                                              -1).long()
                    clamped_gold_passage_span_starts = \
                        util.replace_masked_values(gold_passage_span_starts, gold_passage_span_mask, 0)
                    clamped_gold_passage_span_ends = \
                        util.replace_masked_values(gold_passage_span_ends, gold_passage_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_span_starts = \
                        torch.gather(passage_span_start_log_probs, 1, clamped_gold_passage_span_starts)
                    log_likelihood_for_passage_span_ends = \
                        torch.gather(passage_span_end_log_probs, 1, clamped_gold_passage_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_spans = \
                        log_likelihood_for_passage_span_starts + log_likelihood_for_passage_span_ends
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_passage_spans = \
                        util.replace_masked_values(log_likelihood_for_passage_spans, gold_passage_span_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_passage_span = util.logsumexp(
                        log_likelihood_for_passage_spans)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_passage_span)

                elif answering_ability == "question_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_question_span_starts = answer_as_question_spans[:, :,
                                                                         0]
                    gold_question_span_ends = answer_as_question_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_question_span_mask = (gold_question_span_starts !=
                                               -1).long()
                    clamped_gold_question_span_starts = \
                        util.replace_masked_values(gold_question_span_starts, gold_question_span_mask, 0)
                    clamped_gold_question_span_ends = \
                        util.replace_masked_values(gold_question_span_ends, gold_question_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_span_starts = \
                        torch.gather(question_span_start_log_probs, 1, clamped_gold_question_span_starts)
                    log_likelihood_for_question_span_ends = \
                        torch.gather(question_span_end_log_probs, 1, clamped_gold_question_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_spans = \
                        log_likelihood_for_question_span_starts + log_likelihood_for_question_span_ends
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_question_spans = \
                        util.replace_masked_values(log_likelihood_for_question_spans,
                                                   gold_question_span_mask,
                                                   -1e7)
                    # Shape: (batch_size, )
                    # pylint: disable=invalid-name
                    log_marginal_likelihood_for_question_span = \
                        util.logsumexp(log_likelihood_for_question_spans)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_question_span)

                elif answering_ability == "addition_subtraction":
                    # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                    # Shape: (batch_size, # of combinations)
                    gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1)
                                         > 0).float()
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    gold_add_sub_signs = answer_as_add_sub_expressions.transpose(
                        1, 2)
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    log_likelihood_for_number_signs = torch.gather(
                        number_sign_log_probs, 2, gold_add_sub_signs)
                    # The log likelihood of the masked positions should be 0
                    # so that they do not affect the joint probability.
                    log_likelihood_for_number_signs = \
                        util.replace_masked_values(log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0)
                    # Shape: (batch_size, # of combinations)
                    log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(
                        1)
                    # For those padded combinations, we set their log probabilities to be very small negative value
                    log_likelihood_for_add_subs = \
                        util.replace_masked_values(log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_add_sub = util.logsumexp(
                        log_likelihood_for_add_subs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_add_sub)

                elif answering_ability == "counting":
                    # Count answers are padded with label -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    # Shape: (batch_size, # of count answers)
                    gold_count_mask = (answer_as_counts != -1).long()
                    # Shape: (batch_size, # of count answers)
                    clamped_gold_counts = util.replace_masked_values(
                        answer_as_counts, gold_count_mask, 0)
                    log_likelihood_for_counts = torch.gather(
                        count_number_log_probs, 1, clamped_gold_counts)
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_counts = \
                        util.replace_masked_values(log_likelihood_for_counts, gold_count_mask, -1e7)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_count = util.logsumexp(
                        log_likelihood_for_counts)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_count)

                else:
                    raise ValueError(
                        f"Unsupported answering ability: {answering_ability}")

            if len(self.answering_abilities) > 1:
                # Add the ability probabilities if there is more than one ability
                all_log_marginal_likelihoods = torch.stack(
                    log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
                marginal_log_likelihood = util.logsumexp(
                    all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]

            output_dict["loss"] = -marginal_log_likelihood.mean()

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])

                if len(self.answering_abilities) > 1:
                    predicted_ability_str = self.answering_abilities[
                        best_answer_ability[i].detach().cpu().numpy()]
                else:
                    predicted_ability_str = self.answering_abilities[0]

                answer_json: Dict[str, Any] = {}

                # We did not consider multi-mention answers here
                if predicted_ability_str == "passage_span_extraction":
                    answer_json["answer_type"] = "passage_span"
                    passage_str = metadata[i]['original_passage']
                    offsets = metadata[i]['passage_token_offsets']
                    predicted_span = tuple(
                        best_passage_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    predicted_answer = passage_str[start_offset:end_offset]
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = [(start_offset, end_offset)]
                elif predicted_ability_str == "question_span_extraction":
                    answer_json["answer_type"] = "question_span"
                    question_str = metadata[i]['original_question']
                    offsets = metadata[i]['question_token_offsets']
                    predicted_span = tuple(
                        best_question_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    predicted_answer = question_str[start_offset:end_offset]
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = [(start_offset, end_offset)]
                elif predicted_ability_str == "addition_subtraction":  # plus_minus combination answer
                    answer_json["answer_type"] = "arithmetic"
                    original_numbers = metadata[i]['original_numbers']
                    sign_remap = {0: 0, 1: 1, 2: -1}
                    predicted_signs = [
                        sign_remap[it] for it in
                        best_signs_for_numbers[i].detach().cpu().numpy()
                    ]
                    result = sum([
                        sign * number for sign, number in zip(
                            predicted_signs, original_numbers)
                    ])
                    predicted_answer = str(result)
                    offsets = metadata[i]['passage_token_offsets']
                    number_indices = metadata[i]['number_indices']
                    number_positions = [
                        offsets[index] for index in number_indices
                    ]
                    answer_json['numbers'] = []
                    for offset, value, sign in zip(number_positions,
                                                   original_numbers,
                                                   predicted_signs):
                        answer_json['numbers'].append({
                            'span': offset,
                            'value': value,
                            'sign': sign
                        })
                    if number_indices[-1] == -1:
                        # There is a dummy 0 number at position -1 added in some cases; we are
                        # removing that here.
                        answer_json["numbers"].pop()
                    answer_json["value"] = result
                elif predicted_ability_str == "counting":
                    answer_json["answer_type"] = "count"
                    predicted_count = best_count_number[i].detach().cpu(
                    ).numpy()
                    predicted_answer = str(predicted_count)
                    answer_json["count"] = predicted_count
                else:
                    raise ValueError(
                        f"Unsupported answer ability: {predicted_ability_str}")

                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(answer_json)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(predicted_answer, answer_annotations)
            # This is used for the demo.
            output_dict[
                "passage_question_attention"] = passage_question_attention
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict
    def forward(self,  # type: ignore
                question_passage: Dict[str, torch.LongTensor],
                number_indices: torch.LongTensor,
                mask_indices: torch.LongTensor,
                num_spans: torch.LongTensor = None,
                answer_as_passage_spans: torch.LongTensor = None,
                answer_as_question_spans: torch.LongTensor = None,
                answer_as_expressions: torch.LongTensor = None,
                answer_as_expressions_extra: torch.LongTensor = None,
                answer_as_counts: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # Shape: (batch_size, seqlen)
        question_passage_tokens = question_passage["tokens"]
        # Shape: (batch_size, seqlen)
        pad_mask = question_passage["mask"] 
        # Shape: (batch_size, seqlen)
        seqlen_ids = question_passage["tokens-type-ids"]
        
        max_seqlen = question_passage_tokens.shape[-1]
        batch_size = question_passage_tokens.shape[0]
                
        # Shape: (batch_size, 3)
        mask = mask_indices.squeeze(-1)
        # Shape: (batch_size, seqlen)
        cls_sep_mask = \
            torch.ones(pad_mask.shape, device=pad_mask.device).long().scatter(1, mask, torch.zeros(mask.shape, device=mask.device).long())
        # Shape: (batch_size, seqlen)
        passage_mask = seqlen_ids * pad_mask * cls_sep_mask
        # Shape: (batch_size, seqlen)
        question_mask = (1 - seqlen_ids) * pad_mask * cls_sep_mask
        
        # Shape: (batch_size, seqlen, bert_dim)
        bert_out, _ = self.BERT(question_passage_tokens, seqlen_ids, pad_mask, output_all_encoded_layers=False)
        # Shape: (batch_size, qlen, bert_dim)
        question_end = max(mask[:,1])
        question_out = bert_out[:,:question_end]
        # Shape: (batch_size, qlen)
        question_mask = question_mask[:,:question_end]
        # Shape: (batch_size, out)
        question_vector = self.summary_vector(question_out, question_mask, "question")
        
        passage_out = bert_out
        del bert_out
        
        # Shape: (batch_size, bert_dim)
        passage_vector = self.summary_vector(passage_out, passage_mask)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._answer_ability_predictor(torch.cat([passage_vector, question_vector], -1))
            answer_ability_log_probs = torch.nn.functional.log_softmax(answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            count_passage_vector = self.summary_vector(passage_out, passage_mask, "count_passage")
            count_number_log_probs, best_count_number = self._count_module(count_passage_vector)

        if "passage_span_extraction" in self.answering_abilities:
            passage_span_start_log_probs, passage_span_end_log_probs, best_passage_span = \
                self._passage_span_module(passage_out, passage_mask)

        if "question_span_extraction" in self.answering_abilities:
            qspan_passage_vector = self.summary_vector(passage_out, passage_mask, "qspan_passage")
            question_span_start_log_probs, question_span_end_log_probs, best_question_span = \
                self._question_span_module(qspan_passage_vector, question_out, question_mask)
            
        if "arithmetic" in self.answering_abilities:
            arithmetic_passage_vector = self.summary_vector(passage_out, passage_mask, "arithmetic_passage")
            arithmetic_question_vector = self.summary_vector(question_out, question_mask, "arithmetic_question")
            
            arithmetic_template_logits = \
                self._arithmetic_template_predictor(torch.cat([arithmetic_passage_vector, arithmetic_question_vector], -1))
            arithmetic_template_log_probs = arithmetic_template_logits.log_softmax(-1)
            arithmetic_best_templates = arithmetic_template_log_probs.argmax(-1)
            
            number_mask = (number_indices[:,:,0].long() != -1).long()
            
            arithmetic_template_slot_log_probs, arithmetic_best_template_slots, number_mask = \
                self._arithmetic_module(arithmetic_passage_vector, passage_out, number_indices, number_mask)

            
        output_dict = {}
        del passage_out, question_out
        # If answer is given, compute the loss.
        if answer_as_passage_spans is not None or answer_as_question_spans is not None \
                or answer_as_expressions is not None or answer_as_counts is not None:

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    log_marginal_likelihood_for_passage_span = \
                        self._passage_span_log_likelihood(answer_as_passage_spans,
                                                          passage_span_start_log_probs,
                                                          passage_span_end_log_probs)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_passage_span)

                elif answering_ability == "question_span_extraction":
                    log_marginal_likelihood_for_question_span = \
                        self._question_span_log_likelihood(answer_as_question_spans,
                                                           question_span_start_log_probs,
                                                           question_span_end_log_probs)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_question_span)

                elif answering_ability == "arithmetic":
                    log_marginal_likelihood_for_arithmetic = \
                        self._arithmetic_log_likelihood(answer_as_expressions,
                                                        arithmetic_template_slot_log_probs, 
                                                        arithmetic_template_log_probs)                                  
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_arithmetic)

                elif answering_ability == "counting":
                    log_marginal_likelihood_for_count = \
                        self._count_log_likelihood(answer_as_counts, 
                                                   count_number_log_probs)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_count)

                else:
                    raise ValueError(f"Unsupported answering ability: {answering_ability}")

            if len(self.answering_abilities) > 1:
                # Add the ability probabilities if there is more than one ability
                all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = all_log_marginal_likelihoods + answer_ability_log_probs
                marginal_log_likelihood = util.logsumexp(all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]
        
            output_dict["loss"] = - marginal_log_likelihood.mean()
        with torch.no_grad():
            # Compute the metrics and add the tokenized input to the output.
            if metadata is not None:
                output_dict["question_id"] = []
                output_dict["answer"] = []
                question_tokens = []
                passage_tokens = []
                for i in range(batch_size):
                    if len(self.answering_abilities) > 1:
                        predicted_ability_str = self.answering_abilities[best_answer_ability[i]]
                    else:
                        predicted_ability_str = self.answering_abilities[0]
                    answer_json: Dict[str, Any] = {}

                    # We did not consider multi-mention answers here
                    if predicted_ability_str == "passage_span_extraction":
                        answer_json["answer_type"] = "passage_span"
                        answer_json["value"], answer_json["spans"] = \
                            self._span_prediction(question_passage_tokens[i], best_passage_span[i])
                    elif predicted_ability_str == "question_span_extraction":
                        answer_json["answer_type"] = "question_span"
                        answer_json["value"], answer_json["spans"] = \
                            self._span_prediction(question_passage_tokens[i], best_question_span[i])
                    elif predicted_ability_str == "arithmetic":  
                        answer_json["answer_type"] = "arithmetic"
                        original_numbers = metadata[i]['original_numbers']
                        answer_json["value"], answer_json["indices"], answer_json["numbers"] = \
                            self._arithmetic_prediction(original_numbers, 
                                                             arithmetic_best_templates[i],
                                                             arithmetic_best_template_slots[i])
                        answer_json['template'] = arithmetic_best_templates[i].item()
                    elif predicted_ability_str == "counting":
                        answer_json["answer_type"] = "count"
                        answer_json["value"], answer_json["count"] = \
                            self._count_prediction(best_count_number[i])
                    else:
                        raise ValueError(f"Unsupported answer ability: {predicted_ability_str}")

                    output_dict["question_id"].append(metadata[i]["question_id"])
                    output_dict["answer"].append(answer_json)
                    answer_annotations = metadata[i].get('answer_annotations', [])
                    if answer_annotations:
                        self._drop_metrics(answer_json["value"], answer_annotations)

        return output_dict
Example #19
0
    def _get_ll_contrib(
            self, generation_scores: torch.Tensor,
            generation_scores_mask: torch.Tensor, copy_scores: torch.Tensor,
            target_tokens: torch.Tensor, target_to_source: torch.Tensor,
            copy_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the log-likelihood contribution from a single timestep.

        Parameters
        ----------
        generation_scores : ``torch.Tensor``
            Shape: `(batch_size, target_vocab_size)`
        generation_scores_mask : ``torch.Tensor``
            Shape: `(batch_size, target_vocab_size)`. This is just a tensor of 1's.
        copy_scores : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        target_tokens : ``torch.Tensor``
            Shape: `(batch_size,)`
        target_to_source : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        copy_mask : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Shape: `(batch_size,), (batch_size, max_input_sequence_length)`
        """
        _, target_size = generation_scores.size()

        # The point of this mask is just to mask out all source token scores
        # that represent padding. We apply the mask to the concatenation
        # of the generation scores and the copy scores to normalize the scores
        # correctly during the softmax.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
        # Normalize generation and copy scores.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        log_probs = util.masked_log_softmax(all_scores, mask)
        # Calculate the log probability (`copy_log_probs`) for each token in the source sentence
        # that matches the current target token. We use the sum of these copy probabilities
        # for matching tokens in the source sentence to get the total probability
        # for the target token. We also need to normalize the individual copy probabilities
        # to create `selective_weights`, which are used in the next timestep to create
        # a selective read state.
        # shape: (batch_size, trimmed_source_length)
        copy_log_probs = log_probs[:, target_size:] + (
            target_to_source.float() + 1e-45).log()
        # Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
        # we use a non-log softmax to get the normalized non-log copy probabilities.
        selective_weights = util.masked_softmax(log_probs[:, target_size:],
                                                target_to_source)
        # This mask ensures that an item in the batch has non-zero generation
        # probability for this timestep only when the gold target token is not OOV
        # or there are no matching tokens in the source sentence.
        # shape: (batch_size, 1)
        gen_mask = ((target_tokens != self._oov_index) |
                    (target_to_source.sum(-1) == 0)).float()
        log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
        # Now we get the generation score for the gold target token.
        # shape: (batch_size, 1)
        generation_log_probs = log_probs.gather(
            1, target_tokens.unsqueeze(1)) + log_gen_mask
        # ... and add the copy score to get the step log likelihood.
        # shape: (batch_size, 1 + trimmed_source_length)
        combined_gen_and_copy = torch.cat(
            (generation_log_probs, copy_log_probs), dim=-1)
        # shape: (batch_size,)
        step_log_likelihood = util.logsumexp(combined_gen_and_copy)

        return step_log_likelihood, selective_weights
Example #20
0
    def forward(
            self,  # type: ignore
            passage_question: Dict[str, torch.LongTensor],
            number_indices: torch.LongTensor,
            answer_type=None,
            answer_as_spans: torch.LongTensor = None,
            answer_as_add_sub_expressions: torch.LongTensor = None,
            answer_as_counts: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:

        passage_question_mask = passage_question["mask"].float()
        embedded_passage_question = self._dropout(
            self._text_field_embedder(passage_question))  # Encode with bert

        batch_size = embedded_passage_question.size(0)

        encoded_passage_question = embedded_passage_question
        """
        passage_vactor 用 [CLS]对应的代替
        """

        passage_question_vector = encoded_passage_question[:, 0]

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = \
                self._answer_ability_predictor(passage_question_vector)
            answer_ability_log_probs = torch.nn.functional.log_softmax(
                answer_ability_logits, -1)
            #best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            # Shape: (batch_size, 10)
            count_number_logits = self._count_number_predictor(
                passage_question_vector)
            count_number_log_probs = torch.nn.functional.log_softmax(
                count_number_logits, -1)
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1)
            best_count_log_prob = \
                torch.gather(count_number_log_probs, 1, best_count_number.unsqueeze(-1)).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[:, self._counting_index]

        if "span_extraction" in self.answering_abilities:
            # Shape: (batch_size, passage_length)
            span_start_logits = self._span_start_predictor(
                encoded_passage_question).squeeze(-1)
            # Shape: (batch_size, passage_length)
            span_end_logits = self._span_end_predictor(
                encoded_passage_question).squeeze(-1)
            # Shape: (batch_size, passage_length)
            span_start_log_probs = util.masked_log_softmax(
                span_start_logits, passage_question_mask)
            span_end_log_probs = util.masked_log_softmax(
                span_end_logits, passage_question_mask)

            # Info about the best passage span prediction
            span_start_logits = util.replace_masked_values(
                span_start_logits, passage_question_mask, -1e7)
            span_end_logits = util.replace_masked_values(
                span_end_logits, passage_question_mask, -1e7)
            # Shape: (batch_size, 2)
            best_span = get_best_span(span_start_logits, span_end_logits)
            # Shape: (batch_size, 2)
            best_start_log_probs = \
                torch.gather(span_start_log_probs, 1, best_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_end_log_probs = \
                torch.gather(span_end_log_probs, 1, best_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_span_log_prob = best_start_log_probs + best_end_log_probs
            if len(self.answering_abilities) > 1:
                best_span_log_prob += answer_ability_log_probs[:, self._span_extraction_index]

        if "addition_subtraction" in self.answering_abilities:
            # Shape: (batch_size, # of numbers in the passage)
            number_indices = number_indices.squeeze(-1)
            number_mask = (number_indices != -1).long()

            clamped_number_indices = util.replace_masked_values(
                number_indices, number_mask, 0)
            #encoded_passage_for_numbers = torch.cat([modeled_passage_list[0], modeled_passage_list[3]], dim=-1)
            # Shape: (batch_size, # of numbers in the passage, encoding_dim)
            encoded_numbers = torch.gather(
                encoded_passage_question, 1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, encoded_passage_question.size(-1)))

            #self._external_number_embedding = self._external_number_embedding.cuda(device)

            #encoded_numbers = self.self_attention(encoded_numbers,number_mask)
            encoded_numbers = self.Concat_attention(encoded_numbers,
                                                    passage_question_vector,
                                                    number_mask)
            # Shape: (batch_size, # of numbers in the passage)
            #encoded_numbers = torch.cat(
            #        [encoded_numbers, passage_question_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1)], -1)

            # Shape: (batch_size, # of numbers in the passage, 3)
            number_sign_logits = self._number_sign_predictor(encoded_numbers)
            number_sign_log_probs = torch.nn.functional.log_softmax(
                number_sign_logits, -1)

            # Shape: (batch_size, # of numbers in passage).
            best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
            # For padding numbers, the best sign is masked to 0 (not included).
            best_signs_for_numbers = util.replace_masked_values(
                best_signs_for_numbers, number_mask, 0)
            # Shape: (batch_size, # of numbers in passage)
            best_signs_log_probs = torch.gather(
                number_sign_log_probs, 2,
                best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
            # The probs of the masked positions should be 1 so that they do not affect the joint probability.
            # TODO: this is not quite right, since if there are many numbers in the passage,
            # TODO: the joint probability would be very small.
            best_signs_log_probs = util.replace_masked_values(
                best_signs_log_probs, number_mask, 0)
            # Shape: (batch_size,)

            if len(self.answering_abilities) > 1:
                # batch_size
                best_combination_log_prob = best_signs_log_probs.sum(-1)
                best_combination_log_prob += answer_ability_log_probs[:, self._addition_subtraction_index]

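        # Note: stacking these three log probs assumes all three answering abilities
        # ("span_extraction", "addition_subtraction", "counting") are configured, since
        # each best_*_log_prob is only defined inside its own branch above.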
        best_answer_ability = torch.argmax(
            torch.stack([
                best_span_log_prob, best_combination_log_prob,
                best_count_log_prob
            ], -1), 1)

        output_dict = {}

        # If answer is given, compute the loss.
        if answer_as_spans is not None or answer_as_add_sub_expressions is not None or answer_as_counts is not None:

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_span_starts = answer_as_spans[:, :, 0]
                    gold_span_ends = answer_as_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_span_mask = (gold_span_starts != -1).long()
                    clamped_gold_span_starts = \
                        util.replace_masked_values(gold_span_starts, gold_span_mask, 0)
                    clamped_gold_span_ends = \
                        util.replace_masked_values(gold_span_ends, gold_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_span_starts = \
                        torch.gather(span_start_log_probs, 1, clamped_gold_span_starts)
                    log_likelihood_for_span_ends = \
                        torch.gather(span_end_log_probs, 1, clamped_gold_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_spans = \
                        log_likelihood_for_span_starts + log_likelihood_for_span_ends
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_spans = \
                        util.replace_masked_values(log_likelihood_for_spans, gold_span_mask, -1e7)
                    # Shape: (batch_size, )
                    #                    log_marginal_likelihood_for_span = torch.sum(log_likelihood_for_spans,-1)
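                    # Marginalize over the alternative gold spans: logsumexp sums their
                    # probabilities and returns the log of that sum.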
                    log_marginal_likelihood_for_span = util.logsumexp(
                        log_likelihood_for_spans)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_span)

                elif answering_ability == "addition_subtraction":
                    # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                    # Shape: (batch_size, # of combinations)
                    gold_add_sub_mask = (answer_as_add_sub_expressions.sum(-1)
                                         > 0).float()
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    gold_add_sub_signs = answer_as_add_sub_expressions.transpose(
                        1, 2)
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    log_likelihood_for_number_signs = torch.gather(
                        number_sign_log_probs, 2, gold_add_sub_signs)
                    # the log likelihood of the masked positions should be 0
                    # so that it will not affect the joint probability
                    log_likelihood_for_number_signs = \
                        util.replace_masked_values(log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0)
                    # Shape: (batch_size, # of combinations)
                    log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(
                        1)
                    # For those padded combinations, we set their log probabilities to be very small negative value
                    log_likelihood_for_add_subs = \
                        util.replace_masked_values(log_likelihood_for_add_subs, gold_add_sub_mask, -1e7)
                    # Shape: (batch_size, )

                    #log_marginal_likelihood_for_add_sub =  torch.sum(log_likelihood_for_add_subs,-1)
                    #log_marginal_likelihood_for_add_sub = util.logsumexp(log_likelihood_for_add_subs)
                    #log_marginal_likelihood_list.append(log_marginal_likelihood_for_add_sub)

                    log_marginal_likelihood_for_add_sub = util.logsumexp(
                        log_likelihood_for_add_subs)

                    #log_marginal_likelihood_for_external = util.logsumexp(log_likelihood_for_externals)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_add_sub)

                elif answering_ability == "counting":
                    # Count answers are padded with label -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    # Shape: (batch_size, # of count answers)
                    gold_count_mask = (answer_as_counts != -1).long()
                    # Shape: (batch_size, # of count answers)
                    clamped_gold_counts = util.replace_masked_values(
                        answer_as_counts, gold_count_mask, 0)
                    log_likelihood_for_counts = torch.gather(
                        count_number_log_probs, 1, clamped_gold_counts)
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_counts = \
                        util.replace_masked_values(log_likelihood_for_counts, gold_count_mask, -1e7)
                    # Shape: (batch_size, )
                    #log_marginal_likelihood_for_count =  torch.sum(log_likelihood_for_counts,-1)
                    log_marginal_likelihood_for_count = util.logsumexp(
                        log_likelihood_for_counts)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_count)

                else:
                    raise ValueError(
                        f"Unsupported answering ability: {answering_ability}")

            if len(self.answering_abilities) > 1:
                # Add the ability probabilities if there is more than one ability
                all_log_marginal_likelihoods = torch.stack(
                    log_marginal_likelihood_list, dim=-1)
                loss_for_type = -(torch.sum(
                    answer_ability_log_probs * answer_type, -1)).mean()
                loss_for_answer = -(torch.sum(all_log_marginal_likelihoods,
                                              -1)).mean()
                loss = loss_for_type + loss_for_answer
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]
                loss = -marginal_log_likelihood.mean()
            output_dict["loss"] = loss

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            passage_question_tokens = []
            for i in range(batch_size):
                passage_question_tokens.append(
                    metadata[i]['passage_question_tokens'])

                if len(self.answering_abilities) > 1:
                    predicted_ability_str = self.answering_abilities[
                        best_answer_ability[i].detach().cpu().numpy()]
                else:
                    predicted_ability_str = self.answering_abilities[0]

                answer_json: Dict[str, Any] = {}

                # We did not consider multi-mention answers here
                if predicted_ability_str == "span_extraction":
                    answer_json["answer_type"] = "span"
                    passage_question_token = metadata[i][
                        'passage_question_tokens']
                    #offsets = metadata[i]['passage_token_offsets']
                    predicted_span = tuple(best_span[i].detach().cpu().numpy())
                    start_offset = predicted_span[0]
                    end_offset = predicted_span[1]
                    predicted_answer = " ".join([
                        token for token in
                        passage_question_token[start_offset:end_offset + 1]
                        if token != "[SEP]"
                    ]).replace(" ##", "")
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = [(start_offset, end_offset)]
                elif predicted_ability_str == "counting":
                    answer_json["answer_type"] = "count"
                    predicted_count = best_count_number[i].detach().cpu().numpy()
                    predicted_answer = str(predicted_count)
                    answer_json["count"] = predicted_count
                elif predicted_ability_str == "addition_subtraction":
                    answer_json["answer_type"] = "arithmetic"
                    original_numbers = metadata[i]['original_numbers']
                    sign_remap = {0: 0, 1: 1, 2: -1}
                    predicted_signs = [
                        sign_remap[it] for it in
                        best_signs_for_numbers[i].detach().cpu().numpy()
                    ]
                    result = 0
                    for j, number in enumerate(original_numbers):
                        sign = predicted_signs[j]
                        if sign != 0:
                            result += sign * number

                    predicted_answer = str(result)
                    #offsets = metadata[i]['passage_token_offsets']
                    number_indices = metadata[i]['number_indices']
                    #number_positions = [offsets[index] for index in number_indices]
                    answer_json['numbers'] = []
                    for indice, value, sign in zip(number_indices,
                                                   original_numbers,
                                                   predicted_signs):
                        answer_json['numbers'].append({
                            'span': indice,
                            'value': str(value),
                            'sign': sign
                        })
                    if number_indices[-1] == -1:
                        # There is a dummy 0 number at position -1 added in some cases; we are
                        # removing that here.
                        answer_json["numbers"].pop()
                    answer_json["value"] = str(result)
                else:
                    raise ValueError(
                        f"Unsupported answer ability: {predicted_ability_str}")

                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(answer_json)
                answer_annotations = metadata[i].get('answer_annotations', [])
                if answer_annotations:
                    self._drop_metrics(predicted_answer, answer_annotations)
            # This is used for the demo.
            #output_dict["passage_question_attention"] = passage_question_attention
            output_dict["passage_question_tokens"] = passage_question_tokens
            #output_dict["passage_tokens"] = passage_tokens
        return output_dict
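
The marginalization pattern used above for spans, add/sub expressions, and counts is always the same: padded gold answers get a log-likelihood of -1e7, and the marginal is taken with logsumexp over the gold dimension. A small sketch of why the padding is harmless (toy numbers; plain torch.logsumexp in place of util.logsumexp):

import torch

# Log-likelihoods of two real gold answers plus one padded slot (masked to -1e7).
log_likelihood_for_spans = torch.tensor([[-1.2, -0.7, -1e7]])

# exp(-1e7) underflows to zero, so the padded slot contributes nothing to the marginal.
log_marginal_likelihood = torch.logsumexp(log_likelihood_for_spans, dim=-1)

expected = torch.log(torch.tensor(-1.2).exp() + torch.tensor(-0.7).exp())
assert torch.allclose(log_marginal_likelihood, expected)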
Example #21
0
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
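            # gold_antecedent_labels is a 0/1 indicator, so its log is 0 for antecedents in
            # the gold cluster and -inf otherwise; the logsumexp below therefore sums
            # probability mass over the correct antecedents only.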
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
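
The coreference loss above is another instance of the same idea. A minimal sketch (toy scores, plain torch calls, and an unmasked log_softmax instead of util.masked_log_softmax) of how adding log(gold_labels) before the logsumexp restricts the marginal to antecedents in the gold cluster:

import torch

# Scores for one kept span over "no antecedent" plus two candidate antecedents.
coreference_scores = torch.tensor([[[2.0, 0.5, 1.0]]])  # (batch_size, spans_kept, 1 + max_antecedents)
coreference_log_probs = torch.log_softmax(coreference_scores, dim=-1)

# 0/1 gold labels: both candidates are in the gold cluster, "no antecedent" is not.
gold_antecedent_labels = torch.tensor([[[0.0, 1.0, 1.0]]])

# log(0) = -inf removes the wrong choices; logsumexp then sums probability mass over
# the correct antecedents, and the loss is the negative of that marginal.
correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
negative_marginal_log_likelihood = -torch.logsumexp(correct_antecedent_log_probs, dim=-1).sum()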
Example #22
0
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            span_labels: torch.IntTensor = None,
            keep_antecedent_alternatives: Optional[ScatterableList] = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        keep_antecedent_alternatives : ``Optional[ScatterableList]``, optional (default = None)
            If non-``None`` and any contained value is ``True``, the output dictionary will
            contain antecedent scores and an antecedent mask (see below).

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        antecedent_scores : ``torch.FloatTensor``, optional
            A tensor of shape ``(batch_size, num_spans_to_keep, max_antecedents+1)`` giving
            the antecedent scores for each mention.  Each i-th batch element is associated with a
            matrix whose j-th row contains the antecedent scores for the j-th mention of that
            batch element (corresponding to top_spans), and whose k-th column contains the score for
            the antecedent_indices[k - 1]-th mention being the antecedent of the j-th mention.
            The first column (index k = 0) contains the score for the j-th mention having
            no antecedent.
        antecedent_mask : ``torch.FloatTensor``, optional
            A tensor of shape ``(batch_size, num_spans_to_keep, max_antecedents)``.  An entry
            is 1 if the corresponding entry of ``antecedent_scores`` gives a valid antecedent
            score and 0 otherwise. This is necessary because, for example, the first mention of
            a document has no antecedents to score.

        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        # In order to get the correct shape, we collapse the last dimension (it should only be one
        # index long). We then reshape it to make sure the shape is correct in edge cases (namely
        # when there is exactly one input mention).
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).reshape((spans.shape[0], spans.shape[1])) \
            .float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = min(
            num_spans, int(math.floor(self._spans_per_word * document_length)))

        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        batch_size = top_spans.shape[0]
        output_dict = {
            "top_spans":
            top_spans,
            # because antecedent_indices is the same for all batch elements (since it
            # doesn't depend on the batch content), we need to expand it to have
            # batch_size as its first dimension or else model.forward_on_instances
            # will discard it
            "antecedent_indices":
            valid_antecedent_indices.expand(batch_size,
                                            valid_antecedent_indices.shape[0],
                                            valid_antecedent_indices.shape[1]),
            "predicted_antecedents":
            predicted_antecedents
        }
        if keep_antecedent_alternatives and any(keep_antecedent_alternatives):
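            # valid_antecedent_log_mask holds 0 for valid prior antecedents and -inf for
            # invalid ones, so comparing it with 0 recovers a boolean mask.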
            output_dict["antecedent_scores"] = coreference_scores
            output_dict["antecedent_mask"] = (valid_antecedent_log_mask >= 0)

        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
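
One detail shared by both coreference examples is that the antecedent validity mask is applied additively in log space. A small sketch (toy values, plain torch) of why adding a {0, -inf} log mask to the scores behaves like masking the resulting probabilities:

import torch

scores = torch.tensor([1.0, 2.0, 3.0])
# 0 marks a valid prior antecedent, -inf an invalid one (e.g. for the first mention).
valid_antecedent_log_mask = torch.tensor([0.0, float('-inf'), 0.0])

# exp(score + (-inf)) == 0, so the invalid entry gets zero probability and the
# remaining entries renormalize among themselves.
masked_probs = torch.softmax(scores + valid_antecedent_log_mask, dim=-1)
reference = torch.softmax(torch.tensor([1.0, 3.0]), dim=-1)
assert torch.allclose(masked_probs[[0, 2]], reference)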