Ejemplo n.º 1
0
def nd_batched_padded_index_select(target: torch.Tensor,
                                   indices: torch.IntTensor,
                                   padding_value: Any = -1):
    """
    Index-select from `target` using `indices` that may contain padding.

    Negative entries of `indices` are treated as padding. They are clamped
    to 0 for the lookup performed by `nd_batched_index_select`, and the
    values retrieved at those positions are then replaced with
    `padding_value` through `util.replace_masked_values`.
    """
    # 1.0 where the index is a real one, 0.0 where it is padding. A trailing
    # unit dimension is appended before the mask is handed to
    # `util.replace_masked_values` together with the selected values.
    is_real_index = (indices >= 0).float().unsqueeze(-1)

    # Padding indices are negative; clamp them so the lookup receives only
    # non-negative indices.
    safe_indices = indices.clamp(min=0)
    selected = nd_batched_index_select(target, safe_indices)

    return util.replace_masked_values(selected, is_real_index, padding_value)
Ejemplo n.º 2
0
    def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        doc_span_offsets: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        doc_truth_spans: torch.IntTensor = None,
        doc_spans_in_truth: torch.IntTensor = None,
        doc_relation_labels: torch.Tensor = None,
        truth_spans: List[Set[Tuple[int, int]]] = None,
        doc_relations=None,
        doc_ner_labels: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:  # add matrix from datareader
        # pylint: disable=arguments-differ
        """
        Run the joint coreference / relation extraction / NER forward pass.

        Coreference predictions and relation scores are always computed.
        Each loss term is only computed when its gold input is provided
        (``span_labels`` for coreference, ``doc_relations`` for relation
        extraction, ``doc_ner_labels`` for NER).

        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document. Padding spans are negative.
        metadata : ``List[Dict[str, Any]]``, required.
            One dictionary per document in the batch. This method reads the
            ``'doc_tokens'`` entry (an iterable of sentences, used to find the
            longest sentence) and the ``'original_text'`` entry (copied to the
            output). The list is also passed on to the coreference metrics.
        doc_span_offsets : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
            used as indices into the ``num_spans`` dimension of ``spans``.
            Negative values are treated as padding.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
            When given, the coreference loss and metrics are computed.
        doc_truth_spans : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, max_sentences, max_truth_spans, 1).
            Not read by this method.
        doc_spans_in_truth : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1).
            Only read when ``doc_relations`` is given. The values selected for
            the kept spans are used as row/column indices into
            ``doc_relation_labels``, with negative values treated as padding.
        doc_relation_labels : ``torch.Tensor``, optional (default = None)
            A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans)
            holding the gold relation labels between truth spans. Only read
            when ``doc_relations`` is given.
        truth_spans : ``List[Set[Tuple[int, int]]]``, optional (default = None)
            Gold spans, passed to the relation mention recall metric. Only
            read when ``doc_relations`` is given.
        doc_relations : optional (default = None)
            Gold relations, passed to ``self._compute_relex_metrics``. When
            given, the relation extraction loss and metrics are computed.
        doc_ner_labels : ``torch.IntTensor``, optional (default = None)
            Gold NER labels, used as the target for the scores produced over
            ``doc_span_embeddings``. A tensor of shape # TODO.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        document : ``List``
            The ``'original_text'`` entry of every metadata dictionary.
        relex_scores : ``torch.Tensor``
            Relation scores for every pair of spans kept by the relation pruner.
        top_relex_spans : ``torch.Tensor``
            The spans kept by the relation pruner.
        coref_loss, relex_loss, ner_loss : optional
            The individual loss terms, present only when the corresponding gold
            inputs were provided.
        ner_scores : optional
            NER scores, present only when ``doc_ner_labels`` was provided.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised: the weighted sum of the loss terms
            that were computed. Absent when no loss term was computed.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        batch_size = len(spans)
        document_length = text_embeddings.size(1)
        max_sentence_length = max(
            len(sentence) for document in metadata
            for sentence in document['doc_tokens'])
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        # Indexing with `0` already removes the last dimension. An extra
        # `squeeze(-1)` here would also drop the span dimension whenever
        # num_spans == 1, so it is deliberately not applied.
        span_mask = (spans[:, :, 0] >= 0).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # TODO features dropout
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores. The coreference budget scales with
        # the document length, the relation budget with the longest sentence.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_relex_spans_to_keep = int(
            math.floor(self._relex_spans_per_word * max_sentence_length))

        # Shapes:
        # (batch_size, num_spans_to_keep, span_dim),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # Shape: (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {}

        output_dict["top_spans"] = top_spans
        output_dict["antecedent_indices"] = valid_antecedent_indices
        output_dict["predicted_antecedents"] = predicted_antecedents

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]

        # Starts as the plain int 0 and becomes a tensor as soon as any
        # loss term is added; the final check below relies on this.
        loss = 0

        # Shape: (batch_size, max_sentences, max_spans_per_sentence)
        doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
        # Per-sentence indices into the span dimension. Padding offsets are
        # negative, so they are clamped to 0 to keep the lookups in range;
        # `doc_span_mask` records which positions are real. Computed once as
        # a long tensor and shared by both lookups below.
        # Shape: (batch_size, max_sentences, max_spans_per_sentence)
        doc_span_indices = doc_span_offsets.squeeze(-1).long().clamp(min=0)
        # Shape: (batch_size, max_sentences, max_spans_per_sentence, span_dim)
        doc_span_embeddings = util.batched_index_select(
            span_embeddings, doc_span_indices)

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        pruned = self._relex_mention_pruner(
            doc_span_embeddings,
            doc_span_mask,
            num_items_to_keep=num_relex_spans_to_keep,
            pass_through=['num_items_to_keep'])
        (top_relex_span_embeddings, top_relex_span_mask,
         top_relex_span_indices, top_relex_span_mention_scores) = pruned

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

        # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)  # TODO do we need for a mask?
        doc_spans = util.batched_index_select(spans, doc_span_indices)

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
        top_relex_spans = nd_batched_index_select(doc_spans,
                                                  top_relex_span_indices)

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
        (relex_span_pair_embeddings,
         relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
             top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
        relex_scores = self._compute_relex_scores(
            relex_span_pair_embeddings, top_relex_span_mention_scores)
        output_dict['relex_scores'] = relex_scores
        output_dict['top_relex_spans'] = top_relex_spans

        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)
            antecedent_labels_ = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long(
            )

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability x to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log(
            )
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs)
            # Zero out the contribution of masked (padding) spans.
            negative_marginal_log_likelihood *= top_span_mask.squeeze(
                -1).float()
            negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum(
            )

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            coref_loss = negative_marginal_log_likelihood
            output_dict['coref_loss'] = coref_loss
            loss += self._loss_coref_weight * coref_loss

        if doc_relations is not None:

            # The adjacency matrix for relation extraction is very sparse.
            # As it is not just sparse, but row/column sparse (only few
            # rows and columns are non-zero and in that case these rows/columns
            # are not sparse), we implemented our own matrix for the case.
            # Here we have indices of truth spans and mapping, using which
            # we map prediction matrix on truth matrix.
            # TODO Add teacher forcing support.

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep),
            relative_indices = top_relex_span_indices
            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1),
            compressed_indices = nd_batched_padded_index_select(
                doc_spans_in_truth, relative_indices)

            # Select the gold rows for the kept spans, then transpose so that
            # the same selection can be applied to the columns.
            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
            gold_pruned_rows = nd_batched_padded_index_select(
                doc_relation_labels,
                compressed_indices.squeeze(-1),
                padding_value=0)
            gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3,
                                                        2).contiguous()

            # Select the columns, then transpose back to row-major order.
            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
            gold_pruned_matrices = nd_batched_padded_index_select(
                gold_pruned_rows,
                compressed_indices.squeeze(-1),
                padding_value=0)  # pad with epsilon
            gold_pruned_matrices = gold_pruned_matrices.permute(
                0, 1, 3, 2).contiguous()

            # TODO log_mask relex score before passing
            relex_loss = nd_cross_entropy_with_logits(relex_scores,
                                                      gold_pruned_matrices,
                                                      relex_span_pair_mask)
            output_dict['relex_loss'] = relex_loss

            self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2),
                                       truth_spans)
            self._compute_relex_metrics(output_dict, doc_relations)

            loss += self._loss_relex_weight * relex_loss

        if doc_ner_labels is not None:
            # Shape: (batch_size, max_sentences, max_spans_per_sentence, num_ner_classes)
            ner_scores = self._ner_scorer(doc_span_embeddings)
            output_dict['ner_scores'] = ner_scores

            ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels,
                                                    doc_span_mask)
            output_dict['ner_loss'] = ner_loss
            loss += self._loss_ner_weight * ner_loss

        # `loss` is still the initial int 0 when no loss term was added; only
        # report a loss when at least one term was computed.
        if not isinstance(loss, int):
            output_dict["loss"] = loss

        return output_dict