Example #1
    def forward(
        self,
        span: torch.Tensor,  # SHAPE: (batch_size, num_spans, span_dim)
        span_pairs: torch.LongTensor  # SHAPE: (batch_size, num_span_pairs, 2)
    ):
        span1 = span2 = span
        if self.dim_reduce_layer1 is not None:
            span1 = self.dim_reduce_layer1(span)
        if self.dim_reduce_layer2 is not None:
            span2 = self.dim_reduce_layer2(span)

        if not self.pair:
            return span1, span2

        num_spans = span.size(1)

        # get span pair embedding
        span_pairs_p = span_pairs[:, :, 0]
        span_pairs_c = span_pairs[:, :, 1]
        # SHAPE: (batch_size * num_span_pairs)
        flat_span_pairs_p = util.flatten_and_batch_shift_indices(
            span_pairs_p, num_spans)
        flat_span_pairs_c = util.flatten_and_batch_shift_indices(
            span_pairs_c, num_spans)
        # SHAPE: (batch_size, num_span_pairs, span_dim)
        span_pair_p_emb = util.batched_index_select(span1, span_pairs_p,
                                                    flat_span_pairs_p)
        span_pair_c_emb = util.batched_index_select(span2, span_pairs_c,
                                                    flat_span_pairs_c)
        if self.combine == 'concat':
            # SHAPE: (batch_size, num_span_pairs, span_dim * 2)
            span_pair_emb = torch.cat([span_pair_p_emb, span_pair_c_emb], -1)
        elif self.combine == 'coref':
            # Use the index gap as the distance; this requires the indices to be consistent
            # with the order in which the spans appear in the sentence.
            distance = span_pairs_p - span_pairs_c
            # SHAPE: (batch_size, num_span_pairs, dist_emb_dim)
            distance_embeddings = self.distance_embedding(
                util.bucket_values(
                    distance, num_total_buckets=self.num_distance_buckets))
            # SHAPE: (batch_size, num_span_pairs, span_dim * 3 + dist_emb_dim)
            span_pair_emb = torch.cat([
                span_pair_p_emb, span_pair_c_emb,
                span_pair_p_emb * span_pair_c_emb, distance_embeddings
            ], -1)

        if self.repr_layer is not None:
            # SHAPE: (batch_size, num_span_pairs, out_dim)
            span_pair_emb = self.repr_layer(span_pair_emb)

        return span_pair_emb
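The two AllenNLP utilities used above do the index bookkeeping: `util.flatten_and_batch_shift_indices` adds a per-batch offset so the indices address a tensor flattened over the batch and span dimensions, and `util.batched_index_select` gathers rows with those flat indices. Below is a minimal stand-alone sketch of that behaviour in plain PyTorch; the names `flatten_and_shift` and `gather_rows` are hypothetical stand-ins, not AllenNLP's API.

import torch

def flatten_and_shift(indices: torch.LongTensor, num_items: int) -> torch.LongTensor:
    # indices: (batch_size, ...) positions into a per-batch dimension of size num_items.
    # Batch entry b gets an offset of b * num_items, so the result can index a view
    # flattened to (batch_size * num_items, ...).
    batch_size = indices.size(0)
    offsets = torch.arange(batch_size, dtype=torch.long).view(batch_size, *([1] * (indices.dim() - 1)))
    return (indices + offsets * num_items).view(-1)

def gather_rows(target: torch.Tensor, indices: torch.LongTensor, flat_indices: torch.LongTensor) -> torch.Tensor:
    # target: (batch_size, num_items, dim) -> (batch_size, *indices.shape[1:], dim).
    flat_target = target.reshape(-1, target.size(-1))      # (batch_size * num_items, dim)
    selected = flat_target.index_select(0, flat_indices)   # pick the shifted rows
    return selected.view(*indices.size(), target.size(-1))

# Toy check: two batch entries, three items each, two indices per entry.
target = torch.arange(12.).view(2, 3, 2)        # (batch=2, num_items=3, dim=2)
pairs = torch.tensor([[0, 2], [1, 1]])
flat = flatten_and_shift(pairs, num_items=3)    # tensor([0, 2, 4, 4])
print(gather_rows(target, pairs, flat).shape)   # torch.Size([2, 2, 2])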
Example #2
    def forward(
            self,
            sequence_tensor: torch.FloatTensor,
            span_indices: torch.LongTensor,
            span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(
            max_batch_span_width,
            util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(
            raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(
            span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor,
                                                    span_indices,
                                                    flat_span_indices)

        #  text_embeddings = span_embeddings * span_mask.unsqueeze(-1)
        batch_size, num_spans, max_batch_span_width, _ = span_embeddings.size()

        view_text_embeddings = span_embeddings.view(batch_size * num_spans,
                                                    max_batch_span_width, -1)
        span_mask = span_mask.view(batch_size * num_spans,
                                   max_batch_span_width)
        cnn_text_embeddings = self.cnn(view_text_embeddings, span_mask)
        cnn_text_embeddings = cnn_text_embeddings.view(batch_size, num_spans,
                                                       self._output_dim)
        return cnn_text_embeddings
Example #3
    def prune_top_spans(
        self,
        mention_scores_output,
        spans,
    ):
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             **mention_scores_output)

        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, spans.size(1))
        return (
            top_span_embeddings,
            top_span_mask,
            top_span_indices,
            top_span_mention_scores,
            flat_top_span_indices,
        )
Example #4
    def inference_coref(self, batch, embedded_text_input_relation, mask):
        submodel = self.model._tagger_coref

        ### Fast inference of coreference ###
        spans = batch["spans"]

        document_length = mask.size(1)
        num_spans = spans.size(1)

        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        spans = F.relu(spans.float()).long()

        encoded_text_coref = submodel._context_layer(
            embedded_text_input_relation, mask)
        endpoint_span_embeddings = submodel._endpoint_span_extractor(
            encoded_text_coref, spans)
        attended_span_embeddings = submodel._attentive_span_extractor(
            embedded_text_input_relation, spans)

        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)
        num_spans_to_keep = int(
            math.floor(submodel._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = submodel._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        max_antecedents = min(submodel._max_antecedents, num_spans_to_keep)

        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            submodel._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(mask))
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        span_pair_embeddings = submodel._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        coreference_scores = submodel._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        _, predicted_antecedents = coreference_scores.max(2)
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": predicted_antecedents
        }

        return output_dict
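`_generate_valid_antecedents` is only called here, not shown. As a rough sketch of the usual pattern (an illustration in plain PyTorch, not code taken from this submodel): the candidate antecedents of the span at position i are simply the `max_antecedents` preceding positions, with out-of-range candidates clamped to index 0 and masked out in log space.

import torch

def generate_valid_antecedents(num_spans_to_keep: int, max_antecedents: int):
    # Span i may take positions i-1, ..., i-max_antecedents as antecedents.
    target_indices = torch.arange(num_spans_to_keep).unsqueeze(1)            # (k, 1)
    antecedent_offsets = torch.arange(1, max_antecedents + 1).unsqueeze(0)   # (1, a)
    raw_indices = target_indices - antecedent_offsets                        # (k, a)
    # Candidates before the start of the document are invalid: mask them in log
    # space (log(0) = -inf) and clamp the indices so they can still be gathered.
    log_mask = (raw_indices >= 0).float().unsqueeze(0).log()                 # (1, k, a)
    valid_indices = raw_indices.clamp(min=0)
    return valid_indices, antecedent_offsets, log_mask

indices, offsets, log_mask = generate_valid_antecedents(4, 2)
print(indices.tolist())   # [[0, 0], [0, 0], [1, 0], [2, 1]]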
Example #5
    def test_flatten_and_batch_shift_indices(self):
        indices = numpy.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 9, 9, 9]],
                               [[2, 1, 0, 7], [7, 7, 2, 3], [0, 0, 4, 2]]])
        indices = Variable(torch.LongTensor(indices))
        shifted_indices = util.flatten_and_batch_shift_indices(indices, 10)
        numpy.testing.assert_array_equal(
            shifted_indices.data.numpy(),
            numpy.array([
                1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9, 12, 11, 10, 17, 17, 17, 12,
                13, 10, 10, 14, 12
            ]))
Example #6
    def test_flatten_and_batch_shift_indices(self):
        indices = numpy.array([[[1, 2, 3, 4],
                                [5, 6, 7, 8],
                                [9, 9, 9, 9]],
                               [[2, 1, 0, 7],
                                [7, 7, 2, 3],
                                [0, 0, 4, 2]]])
        indices = torch.tensor(indices, dtype=torch.long)
        shifted_indices = util.flatten_and_batch_shift_indices(indices, 10)
        numpy.testing.assert_array_equal(shifted_indices.data.numpy(),
                                         numpy.array([1, 2, 3, 4, 5, 6, 7, 8, 9,
                                                      9, 9, 9, 12, 11, 10, 17, 17,
                                                      17, 12, 13, 10, 10, 14, 12]))
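The expected array in these tests can be verified by hand: the first batch entry keeps its raw indices, while every index in the second entry is shifted by batch_index * sequence_length = 1 * 10. A quick stand-alone check in plain PyTorch (no AllenNLP required):

import torch

indices = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 9, 9, 9]],
                        [[2, 1, 0, 7], [7, 7, 2, 3], [0, 0, 4, 2]]])
sequence_length = 10
# Batch entry b is offset by b * sequence_length: 0 for the first entry, 10 for the second.
offsets = torch.arange(indices.size(0)).view(-1, 1, 1) * sequence_length
shifted = (indices + offsets).view(-1)
print(shifted.tolist())
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9, 12, 11, 10, 17, 17, 17, 12, 13, 10, 10, 14, 12]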
Example #7
    def _prune_spans(self, spans, span_mask, span_embeddings, sentence_lengths):
        # Prune
        num_spans = spans.size(1)  # Max number of spans for the minibatch.

        # Keep different number of spans for each minibatch entry.
        num_spans_to_keep = torch.ceil(sentence_lengths.float() * self._spans_per_word).long()

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores, num_spans_kept) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)

        top_span_mask = top_span_mask.unsqueeze(-1)

        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        return top_span_embeddings, top_span_mention_scores, num_spans_to_keep, top_span_mask, top_span_indices, top_spans
Example #8
def generate_embeddings_for_pooling(sequence_tensor, span_starts, span_ends):
    # Shapes: sequence_tensor (B, L, E); span_starts, span_ends (B, num_spans)
    span_starts = span_starts.unsqueeze(-1)
    # Subtract 1 so the (exclusive) end indices become inclusive, matching the masking logic below.
    span_ends = (span_ends - 1).unsqueeze(-1)
    span_widths = span_ends - span_starts
    max_batch_span_width = span_widths.max().item() + 1

    # Shape: (1, 1, max_batch_span_width)
    max_span_range_indices = util.get_range_vector(
        max_batch_span_width,
        util.get_device_of(sequence_tensor)).view(1, 1, -1)
    # Shape: (batch_size, num_spans, max_batch_span_width)
    # This is a broadcasted comparison - for each span we are considering,
    # we are creating a range vector of size max_span_width, but masking values
    # which are greater than the actual length of the span.
    #
    # We're using <= here (and for the mask below) because the span ends are
    # inclusive, so we want to include indices which are equal to span_widths rather
    # than using it as a non-inclusive upper bound.
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # We also don't want to include span indices which are less than zero,
    # which happens because some spans near the beginning of the sequence
    # have an end index < max_batch_span_width, so we add this to the mask here.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_batch_span_width)
    flat_span_indices = util.flatten_and_batch_shift_indices(
        span_indices, sequence_tensor.size(1))

    # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
    span_embeddings = util.batched_index_select(sequence_tensor, span_indices,
                                                flat_span_indices)

    return span_embeddings, span_mask
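The same span-index and span-mask construction appears in Examples #2, #8, #13 and #14. A small worked example in plain PyTorch shows what the intermediate tensors look like for two inclusive spans, [2, 5] and [0, 1]:

import torch

span_starts = torch.tensor([[[2], [0]]])              # (batch=1, num_spans=2, 1), inclusive starts
span_ends = torch.tensor([[[5], [1]]])                # inclusive ends
span_widths = span_ends - span_starts                 # [[[3], [1]]]
max_batch_span_width = span_widths.max().item() + 1   # 4

max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)   # (1, 1, 4)
span_mask = (max_span_range_indices <= span_widths).float()
raw_span_indices = span_ends - max_span_range_indices
span_mask = span_mask * (raw_span_indices >= 0).float()
span_indices = torch.relu(raw_span_indices.float()).long()

print(raw_span_indices.tolist())   # [[[5, 4, 3, 2], [1, 0, -1, -2]]]
print(span_mask.tolist())          # [[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0]]]
print(span_indices.tolist())       # [[[5, 4, 3, 2], [1, 0, 0, 0]]] - clamped slots are masked out anyway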
Example #9
    def forward(
        self,  # type: ignore
        source_spans: torch.IntTensor,
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for source sentence representation.
            Comes from a ``ListField[SpanField]`` of indices into the source sentence.
        source_tokens : Dict[str, torch.LongTensor]
           The output of ``TextField.as_array()`` applied on the source ``TextField``. This will be
           passed through a ``TextFieldEmbedder`` and then through an encoder.
        target_tokens : Dict[str, torch.LongTensor], optional (default = None)
           Output of ``Textfield.as_array()`` applied on target ``TextField``. We assume that the
           target tokens are also represented as a ``TextField``.
        """
        # (batch_size, input_sequence_length, encoder_output_dim)
        embedded_input = self._source_embedder(source_tokens)

        num_spans = source_spans.size(1)
        source_length = embedded_input.size(1)
        batch_size, _, _ = embedded_input.size()

        # (batch_size, source_length)
        source_mask = get_text_field_mask(source_tokens)

        # Shape: (batch_size, num_spans)
        span_mask = (source_spans[:, :, 0] >= 0).squeeze(-1).float()

        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(source_spans.float()).long()

        # Contextualized word embeddings; Shape: (batch_size, source_length, embedding_dim)
        contextualized_word_embeddings = self._encoder(embedded_input,
                                                       source_mask)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        span_embeddings = self._span_extractor(contextualized_word_embeddings,
                                               spans)

        # Prune based on feedforward scorer
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * source_length))

        # Shape: see return section of SpanPruner docs
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_scores) = self._span_pruner(span_embeddings, span_mask,
                                              num_spans_to_keep)

        # Shape: (batch_size * num_spans_to_keep)
        flat_top_span_indices = flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = batched_index_select(spans, top_span_indices,
                                         flat_top_span_indices)

        # Define what we will initialise the first hidden state of the decoder with.
        # Shape: (batch_size, encoder_output_dim)
        summary_of_encoded_source = contextualized_word_embeddings[:, -1]
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end symbol. Either way, we
            # don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Condition decoder on encoder
        # Here we just derive and append one more dummy embedding feature (sum) to match dimensions later
        # Shape: (batch_size, encoder_output_dim + 1)
        decoder_hidden = torch.cat(
            (summary_of_encoded_source,
             summary_of_encoded_source.sum(1).unsqueeze(1)), 1)
        decoder_context = Variable(top_span_embeddings.data.new().resize_(
            batch_size, self._decoder_output_dim).fill_(0))
        last_predictions = None
        step_logits = []
        step_probabilities = []
        step_predictions = []
        step_attention_weights = []
        for timestep in range(num_decoding_steps):
            if self.training and all(
                    torch.rand(1) >= self._scheduled_sampling_ratio):
                input_choices = targets[:, timestep]
            else:
                if timestep == 0:
                    # For the first timestep, when we do not have targets, we input start symbols.
                    # (batch_size,)
                    input_choices = Variable(
                        source_mask.data.new().resize_(batch_size).fill_(
                            self._start_index))
                else:
                    input_choices = last_predictions

            # We append span scores to the span embedding features to make SpanPrune trainable
            # Shape: (batch_size, num_spans_to_keep, span_embedding_dim + 1)
            top_span_embeddings_scores = torch.cat(
                (top_span_embeddings, top_span_scores), 2)
            # Shape: (batch_size, decoder_input_dim)
            decoder_input, attention_weights = self._prepare_decode_step_input(
                input_choices, decoder_hidden, top_span_embeddings_scores,
                top_span_mask)

            if attention_weights is not None:
                step_attention_weights.append(attention_weights)

            # Shape: both (batch_size, decoder_output_dim),
            decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input, (decoder_hidden, decoder_context))
            # (batch_size, num_classes)
            output_projections = self._output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            step_probabilities.append(class_probabilities.unsqueeze(1))
            last_predictions = predicted_classes
            # (batch_size, 1)
            step_predictions.append(last_predictions.unsqueeze(1))
        # step_logits is a list containing tensors of shape (batch_size, 1, num_classes)
        # This is (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        class_probabilities = torch.cat(step_probabilities, 1)
        all_predictions = torch.cat(step_predictions, 1)

        # step_attention_weights is a list containing tensors of shape (batch_size, num_encoder_outputs)
        # This is (batch_size, num_decoding_steps, num_encoder_outputs)
        if len(step_attention_weights) > 0:
            attention_matrix = torch.cat(step_attention_weights, 0)
            attention_matrix.unsqueeze_(0)
        else:
            # Guard against the no-attention case so the dict below does not raise a NameError.
            attention_matrix = None
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities,
            "predictions": all_predictions,
            "top_spans": top_spans,
            "attention_matrix": attention_matrix,
            "top_spans_scores": top_span_scores
        }
        if target_tokens:
            target_mask = get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss  # + top_span_scores.squeeze().view(-1).index_select(0, top_span_mask.view(-1).long()).sum()
        return output_dict
Example #10
    def _compute_representations_doc(
            self,  # type: ignore
            spans_batched: torch.IntTensor,
            span_mask_batched,
            span_embeddings_batched,  # TODO(dwadden) add type.
            sentence_lengths,
            ix,
            coref_labels_batched: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Run the forward pass for a single document.

        Important: This function assumes that sentences are going to be passed in in sorted order,
        from the same document.
        """
        # TODO(dwadden) How to handle case where only one span from a cluster makes it into the
        # minibatch? Should I get rid of the cluster?
        # TODO(dwadden) Write quick unit tests for correctness, time permitting.
        span_ix = span_mask_batched.view(-1).nonzero(
            as_tuple=False).squeeze()  # Indices of the spans to keep.
        spans, span_embeddings = self._flatten_spans(spans_batched, span_ix,
                                                     span_embeddings_batched,
                                                     sentence_lengths)
        coref_labels = self._flatten_coref_labels(coref_labels_batched,
                                                  span_ix)

        document_length = sentence_lengths.sum().item()
        num_spans = spans.size(1)

        # Prune based on mention scores. Make sure we keep at least 1.
        num_spans_to_keep = max(
            2, int(math.ceil(self._spans_per_word * document_length)))

        # Since there's only one minibatch, there aren't any masked spans for us. The span mask is
        # always 1.
        span_mask = torch.ones(num_spans,
                               device=spans_batched.device).unsqueeze(0)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores,
         num_items_kept) = self._mention_pruner(span_embeddings, span_mask,
                                                num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(span_embeddings))

        coreference_scores = self.get_coref_scores(top_span_embeddings,
                                                   top_span_mention_scores,
                                                   valid_antecedent_indices,
                                                   valid_antecedent_offsets,
                                                   valid_antecedent_log_mask)

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "valid_antecedent_log_mask": valid_antecedent_log_mask,
            "valid_antecedent_offsets": valid_antecedent_offsets,
            "top_span_indices": top_span_indices,
            "top_span_mask": top_span_mask,
            "top_span_embeddings": top_span_embeddings,
            "flat_top_span_indices": flat_top_span_indices,
            "coref_labels": coref_labels,
            "coreference_scores": coreference_scores,
            "sentence_lengths": sentence_lengths,
            "span_ix": span_ix,
            "metadata": metadata
        }

        return output_dict
Example #11
    def forward(
        self,  # pylint: disable=arguments-differ
        embeddings: torch.FloatTensor,
        mask: torch.LongTensor,
        num_items_to_keep: Union[int, torch.LongTensor]
    ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor,
               torch.FloatTensor]:
        """
        Extracts the top-k scoring items with respect to the scorer. We additionally return
        the indices of the top-k in their original order, not ordered by score, so that downstream
        components can rely on the original ordering (e.g., for knowing what spans are valid
        antecedents in a coreference resolution model). May use the same k for all sentences in
        minibatch, or different k for each.

        Parameters
        ----------
        embeddings : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
            each item in the list that we want to prune.
        mask : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_items), denoting unpadded elements of
            ``embeddings``.
        num_items_to_keep : ``Union[int, torch.LongTensor]``, required.
            If a tensor of shape (batch_size), specifies the number of items to keep for each
            individual sentence in minibatch.
            If an int, keep the same number of items for all sentences.

        Returns
        -------
        top_embeddings : ``torch.FloatTensor``
            The representations of the top-k scoring items.
            Has shape (batch_size, max_num_items_to_keep, embedding_size).
        top_mask : ``torch.LongTensor``
            The corresponding mask for ``top_embeddings``.
            Has shape (batch_size, max_num_items_to_keep).
        top_indices : ``torch.IntTensor``
            The indices of the top-k scoring items into the original ``embeddings``
            tensor. This is returned because it can be useful to retain pointers to
            the original items, if each item is being scored by multiple distinct
            scorers, for instance. Has shape (batch_size, max_num_items_to_keep).
        top_item_scores : ``torch.FloatTensor``
            The values of the top-k scoring items.
            Has shape (batch_size, max_num_items_to_keep, 1).
        """
        # If an int was given for number of items to keep, construct tensor by repeating the value.
        if isinstance(num_items_to_keep, int):
            batch_size = mask.size(0)
            # Put the tensor on same device as the mask.
            num_items_to_keep = num_items_to_keep * torch.ones(
                [batch_size], dtype=torch.long, device=mask.device)

        max_items_to_keep = num_items_to_keep.max()

        mask = mask.unsqueeze(-1)
        num_items = embeddings.size(1)
        # Shape: (batch_size, num_items, 1)
        scores = self._scorer(embeddings)

        if scores.size(-1) != 1 or scores.dim() != 3:
            raise ValueError(
                f"The scorer passed to Pruner must produce a tensor of shape"
                f"(batch_size, num_items, 1), but found shape {scores.size()}")
        # Make sure that we don't select any masked items by setting their scores to be very
        # negative.  These are logits, typically, so -1e20 should be plenty negative.
        scores = util.replace_masked_values(scores, mask, -1e20)

        # Shape: (batch_size, max_num_items_to_keep, 1)
        _, top_indices = scores.topk(max_items_to_keep, 1)

        # Mask based on number of items to keep for each sentence.
        # Shape: (batch_size, max_num_items_to_keep)
        top_indices_mask = util.get_mask_from_sequence_lengths(
            num_items_to_keep, max_items_to_keep)
        top_indices_mask = top_indices_mask.byte()

        # Shape: (batch_size, max_num_items_to_keep)
        top_indices = top_indices.squeeze(-1)

        # Fill all masked indices with largest "top" index for that sentence, so that all masked
        # indices will be sorted to the end.
        # Shape: (batch_size, 1)
        fill_value, _ = top_indices.max(dim=1)
        fill_value = fill_value.unsqueeze(-1)
        # Shape: (batch_size, max_num_items_to_keep)
        top_indices = torch.where(top_indices_mask, top_indices, fill_value)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``embeddings`` tensor).
        top_indices, _ = torch.sort(top_indices, 1)

        # Shape: (batch_size * max_num_items_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select items for each element in the batch.
        flat_top_indices = util.flatten_and_batch_shift_indices(
            top_indices, num_items)

        # Shape: (batch_size, max_num_items_to_keep, embedding_size)
        top_embeddings = util.batched_index_select(embeddings, top_indices,
                                                   flat_top_indices)

        # Combine the masks on spans that are out-of-bounds, and the mask on spans that are outside
        # the top k for each sentence.
        # Shape: (batch_size, max_num_items_to_keep)
        sequence_mask = util.batched_index_select(mask, top_indices,
                                                  flat_top_indices)
        sequence_mask = sequence_mask.squeeze(-1).byte()
        top_mask = top_indices_mask & sequence_mask
        top_mask = top_mask.long()

        # Shape: (batch_size, max_num_items_to_keep, 1)
        top_scores = util.batched_index_select(scores, top_indices,
                                               flat_top_indices)

        return top_embeddings, top_mask, top_indices, top_scores
Example #12
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape (batch_size, sequence_length, 1)
        global_attention_logits = self._global_attention(sequence_tensor)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = util.batched_index_select(global_attention_logits,
                                                          span_indices,
                                                          flat_span_indices).squeeze(-1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = util.masked_softmax(span_attention_logits, span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

        if span_indices_mask is not None:
            # Above we were masking the widths of spans with respect to the max
            # span width in the batch. Here we are masking the spans which were
            # originally passed in as padding.
            return attended_text_embeddings * span_indices_mask.unsqueeze(-1).float()

        return attended_text_embeddings
Example #13
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        # shape (batch_size, num_spans)
        # span_starts, span_ends = span_indices.split(1, dim=-1)
        # batch_size, max_seq_len, _ = sequence_tensor.shape
        # max_span_num = span_indices.shape[1]
        # range_vector = util.get_range_vector(max_seq_len, util.get_device_of(sequence_tensor)).repeat(
        #     (batch_size, max_span_num, 1))
        # att_mask = (span_ends >= range_vector) - (span_starts > range_vector)
        # att_mask = att_mask * span_mask.unsqueeze(-1)
        # res = self._attention(sequence_tensor.repeat((max_span_num,1,1)), att_mask)

        # combined_tensors = util.combine_tensors(self._combination, [start_embeddings, end_embeddings])

        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape (batch_size, sequence_length, 1)
        # global_attention_logits = self._global_attention(sequence_tensor)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)

        span_embeddings = span_embeddings * span_mask.unsqueeze(-1)
        span_embeddings = span_embeddings.max(2)[0]

        if self._span_width_embedding is not None:
            # Embed the span widths and concatenate to the rest of the representations.
            if self._bucket_widths:
                span_widths = util.bucket_values(span_ends - span_starts,
                                                 num_total_buckets=self._num_width_embeddings)
            else:
                span_widths = span_ends - span_starts
            span_widths = span_widths.squeeze(-1)
            span_width_embeddings = self._span_width_embedding(span_widths)
            combined_tensors = torch.cat([span_embeddings, span_width_embeddings], -1)
        else:
            combined_tensors = span_embeddings
        if span_indices_mask is not None:
            return combined_tensors * span_indices_mask.unsqueeze(-1).float()

        return combined_tensors
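Example #13 swaps the attentive pooling for a max-pool over the masked span token embeddings. A tiny sketch of just that pooling step (assuming the same (batch_size, num_spans, max_batch_span_width, dim) layout) highlights a subtlety: padding positions are zeroed before the max, so a zero can win in dimensions whose in-span values are all negative.

import torch

# One span of width 2 plus one padding slot: (batch=1, num_spans=1, max_batch_span_width=3, dim=2).
span_embeddings = torch.tensor([[[[0.5, -1.0], [2.0, -3.0], [9.0, 9.0]]]])
span_mask = torch.tensor([[[1.0, 1.0, 0.0]]])   # the last slot lies outside the span

pooled = (span_embeddings * span_mask.unsqueeze(-1)).max(2)[0]
print(pooled)   # tensor([[[2., 0.]]]) - the zeroed padding beats the negative values in the second dim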
Example #14
    def _get_entity_span_tokens_embeddings(
            sentence_repr: torch.Tensor,
            entity_span_indices: torch.LongTensor) -> List[torch.Tensor]:
        """
            Most of this code is extracted from the `forward()` method of
            `https://github.com/allenai/allennlp/blob/master/allennlp/modules/span_extractors/self_attentive_span_extractor.py#L45`
        :param sentence_repr: (batch_size, seq_len, embedding_dim)
        :param entity_span_indices: (batch_size, num_spans, 2); last dim `0` for start, `1` for end (inclusive)
        :return: List (length 2) of Tensor:
            span_embeddings: (batch_size, num_spans, max_batch_span_width, embedding_dim)
            span_mask      : (batch_size, num_spans, max_batch_span_width)
        """
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = entity_span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape (batch_size, sequence_length, 1)
        # global_attention_logits = self._global_attention(sentence_repr)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(
            max_batch_span_width,
            util.get_device_of(sentence_repr)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_mask = span_mask * (raw_span_indices >= 0).float()

        span_indices = torch.nn.functional.relu(
            raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(
            span_indices, sentence_repr.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sentence_repr,
                                                    span_indices,
                                                    flat_span_indices)
        return [span_embeddings, span_mask]
Example #15
    def forward(
        self,  # pylint: disable=arguments-differ
        embeddings: torch.FloatTensor,
        mask: torch.LongTensor,
        num_items_to_keep: Union[int, torch.LongTensor],
        class_scores: torch.FloatTensor = None,
        gold_labels: torch.LongTensor = None
    ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor,
               torch.FloatTensor]:
        """
        Extracts the top-k scoring items with respect to the scorer. We additionally return
        the indices of the top-k in their original order, not ordered by score, so that downstream
        components can rely on the original ordering (e.g., for knowing what spans are valid
        antecedents in a coreference resolution model). May use the same k for all sentences in
        minibatch, or different k for each.

        Parameters
        ----------
        embeddings : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
            each item in the list that we want to prune.
        mask : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_items), denoting unpadded elements of
            ``embeddings``.
        num_items_to_keep : ``Union[int, torch.LongTensor]``, required.
            If a tensor of shape (batch_size), specifies the number of items to keep for each
            individual sentence in minibatch.
            If an int, keep the same number of items for all sentences.
        class_scores:
           Class scores to be used with entity beam.
        gold_labels:
           If in debugging mode, use gold labels to get beam.

        Returns
        -------
        top_embeddings : ``torch.FloatTensor``
            The representations of the top-k scoring items.
            Has shape (batch_size, max_num_items_to_keep, embedding_size).
        top_mask : ``torch.LongTensor``
            The corresponding mask for ``top_embeddings``.
            Has shape (batch_size, max_num_items_to_keep).
        top_indices : ``torch.IntTensor``
            The indices of the top-k scoring items into the original ``embeddings``
            tensor. This is returned because it can be useful to retain pointers to
            the original items, if each item is being scored by multiple distinct
            scorers, for instance. Has shape (batch_size, max_num_items_to_keep).
        top_item_scores : ``torch.FloatTensor``
            The values of the top-k scoring items.
            Has shape (batch_size, max_num_items_to_keep, 1).
        num_items_kept
        """
        # If an int was given for number of items to keep, construct tensor by repeating the value.
        if isinstance(num_items_to_keep, int):
            batch_size = mask.size(0)
            # Put the tensor on same device as the mask.
            num_items_to_keep = num_items_to_keep * torch.ones(
                [batch_size], dtype=torch.long, device=mask.device)

        mask = mask.unsqueeze(-1)
        num_items = embeddings.size(1)

        # Shape: (batch_size, num_items, 1)
        # If entity beam is one, use the class scores. Else ignore them and use the scorer.
        if self._entity_beam:
            scores, _ = class_scores.max(dim=-1)
            scores = scores.unsqueeze(-1)
        # If gold beam is one, give a score of 0 wherever the gold label is non-zero (indicating a
        # non-null label), otherwise give a large negative number.
        elif self._gold_beam:
            scores = torch.where(
                gold_labels > 0,
                torch.zeros_like(gold_labels, dtype=torch.float),
                -1e20 * torch.ones_like(gold_labels, dtype=torch.float))
            scores = scores.unsqueeze(-1)
        else:
            scores = self._scorer(embeddings)

        # If we're only keeping items that score above a given threshold, change the number of kept
        # items here.
        if self._min_score_to_keep is not None:
            num_good_items = torch.sum(scores > self._min_score_to_keep,
                                       dim=1).squeeze()
            num_items_to_keep = torch.min(num_items_to_keep, num_good_items)
        # If gold beam is on, keep the gold items.
        if self._gold_beam:
            num_items_to_keep = torch.sum(gold_labels > 0, dim=1)

        # Always keep at least one item to avoid edge case with empty matrix.
        max_items_to_keep = max(num_items_to_keep.max().item(), 1)

        if scores.size(-1) != 1 or scores.dim() != 3:
            raise ValueError(
                f"The scorer passed to Pruner must produce a tensor of shape"
                f"(batch_size, num_items, 1), but found shape {scores.size()}")
        # Make sure that we don't select any masked items by setting their scores to be very
        # negative.  These are logits, typically, so -1e20 should be plenty negative.
        # NOTE(`mask` needs to be a byte tensor now.)
        scores = util.replace_masked_values(scores, mask.byte(), -1e20)

        # Shape: (batch_size, max_num_items_to_keep, 1)
        _, top_indices = scores.topk(max_items_to_keep, 1)

        # Mask based on number of items to keep for each sentence.
        # Shape: (batch_size, max_num_items_to_keep)
        top_indices_mask = util.get_mask_from_sequence_lengths(
            num_items_to_keep, max_items_to_keep)
        top_indices_mask = top_indices_mask.bool()

        # Shape: (batch_size, max_num_items_to_keep)
        top_indices = top_indices.squeeze(-1)

        # Fill all masked indices with largest "top" index for that sentence, so that all masked
        # indices will be sorted to the end.
        # Shape: (batch_size, 1)
        fill_value, _ = top_indices.max(dim=1)
        fill_value = fill_value.unsqueeze(-1)
        # Shape: (batch_size, max_num_items_to_keep)
        top_indices = torch.where(top_indices_mask, top_indices, fill_value)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``embeddings`` tensor).
        top_indices, _ = torch.sort(top_indices, 1)

        # Shape: (batch_size * max_num_items_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select items for each element in the batch.
        flat_top_indices = util.flatten_and_batch_shift_indices(
            top_indices, num_items)

        # Shape: (batch_size, max_num_items_to_keep, embedding_size)
        top_embeddings = util.batched_index_select(embeddings, top_indices,
                                                   flat_top_indices)

        # Combine the masks on spans that are out-of-bounds, and the mask on spans that are outside
        # the top k for each sentence.
        # Shape: (batch_size, max_num_items_to_keep)
        sequence_mask = util.batched_index_select(mask, top_indices,
                                                  flat_top_indices)
        sequence_mask = sequence_mask.squeeze(-1).bool()
        top_mask = top_indices_mask & sequence_mask
        top_mask = top_mask.long()

        # Shape: (batch_size, max_num_items_to_keep, 1)
        top_scores = util.batched_index_select(scores, top_indices,
                                               flat_top_indices)

        return top_embeddings, top_mask, top_indices, top_scores, num_items_to_keep
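Both pruner variants (Examples #11 and #15) follow the same masked top-k recipe: push padded scores to a large negative value, take the per-entry top-k, overwrite the surplus slots of shorter beams with that entry's largest kept index, and sort back into original order before gathering. A compact stand-alone sketch of that recipe in plain PyTorch, using toy numbers:

import torch

scores = torch.tensor([[[0.1], [0.9], [0.3], [0.7]],
                       [[0.8], [0.2], [0.5], [0.4]]])   # (batch=2, num_items=4, 1)
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 1, 0]])                     # last item of the second entry is padding
num_items_to_keep = torch.tensor([3, 2])
max_items_to_keep = int(num_items_to_keep.max())

# Mask out padded items, then take the per-entry top-k of the scores.
masked_scores = scores.masked_fill(mask.unsqueeze(-1) == 0, -1e20)
_, top_indices = masked_scores.topk(max_items_to_keep, 1)
top_indices = top_indices.squeeze(-1)                   # (2, 3)

# Entries keeping fewer than max_items_to_keep get their surplus slots replaced by
# that entry's largest kept index, then everything is sorted back into original order.
keep_mask = torch.arange(max_items_to_keep).unsqueeze(0) < num_items_to_keep.unsqueeze(1)
fill_value, _ = top_indices.max(dim=1, keepdim=True)
top_indices = torch.where(keep_mask, top_indices, fill_value)
top_indices, _ = torch.sort(top_indices, 1)
print(top_indices.tolist())   # [[1, 2, 3], [0, 2, 2]]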
Example #16
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.last_dim_log_softmax(coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood
        return output_dict
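The helper `_generate_valid_antecedents` called above is not shown in this listing. The following is a minimal standalone sketch of what it presumably computes, following the shapes documented in the comments (device handling omitted; the function name and signature here are illustrative only):

import torch
import torch.nn.functional as F

def generate_valid_antecedents(num_spans_to_keep: int, max_antecedents: int):
    # Shape: (num_spans_to_keep, 1) -- the index of each kept span, in document order.
    target_indices = torch.arange(num_spans_to_keep).unsqueeze(1)
    # Shape: (1, max_antecedents) -- offsets 1..max_antecedents look strictly backwards.
    valid_antecedent_offsets = (torch.arange(max_antecedents) + 1).unsqueeze(0)
    # Shape: (num_spans_to_keep, max_antecedents); negative entries run off the start of the document.
    raw_antecedent_indices = target_indices - valid_antecedent_offsets
    # Shape: (1, num_spans_to_keep, max_antecedents); log(0) = -inf marks invalid antecedents.
    valid_antecedent_log_mask = (raw_antecedent_indices >= 0).float().unsqueeze(0).log()
    # Clamp invalid (negative) indices to 0 so they can still be used with index_select.
    valid_antecedent_indices = F.relu(raw_antecedent_indices.float()).long()
    return valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask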
Exemplo n.º 17
0
    def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        doc_span_offsets: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        doc_truth_spans: torch.IntTensor = None,
        doc_spans_in_truth: torch.IntTensor = None,
        doc_relation_labels: torch.Tensor = None,
        truth_spans: List[Set[Tuple[int, int]]] = None,
        doc_relations=None,
        doc_ner_labels: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:  # add matrix from datareader
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : ``List[Dict[str, Any]]``, required.
            A list of metadata dictionaries, one per document in the batch, containing
            the original document text and its tokenised sentences (used below via the
            ``original_text`` and ``doc_tokens`` keys).
        doc_ner_labels : ``torch.IntTensor``.
            A tensor of shape # TODO,
            ...
        doc_span_offsets : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
            ...
        doc_truth_spans : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_truth_spans, 1),
            ...
        doc_spans_in_truth : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
            ...
        doc_relation_labels : ``torch.Tensor``.
            A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans),
            ...

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        batch_size = len(spans)
        document_length = text_embeddings.size(1)
        max_sentence_length = max(
            len(sentence) for document in metadata
            for sentence in document['doc_tokens'])
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # TODO features dropout
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_relex_spans_to_keep = int(
            math.floor(self._relex_spans_per_word * max_sentence_length))

        # Shapes:
        # (batch_size, num_spans_to_keep, span_dim),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # Shape: (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = dict()

        output_dict["top_spans"] = top_spans
        output_dict["antecedent_indices"] = valid_antecedent_indices
        output_dict["predicted_antecedents"] = predicted_antecedents

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]

        # Scalar accumulator for the weighted sum of the losses computed below.
        loss = 0

        # Shape: (batch_size, max_sentences, max_spans)
        doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
        # Shape: (batch_size, max_sentences, num_spans, span_dim)
        doc_span_embeddings = util.batched_index_select(
            span_embeddings,
            doc_span_offsets.squeeze(-1).long().clamp(min=0))

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        pruned = self._relex_mention_pruner(
            doc_span_embeddings,
            doc_span_mask,
            num_items_to_keep=num_relex_spans_to_keep,
            pass_through=['num_items_to_keep'])
        (top_relex_span_embeddings, top_relex_span_mask,
         top_relex_span_indices, top_relex_span_mention_scores) = pruned

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

        # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)  # TODO: do we need a mask here?
        doc_spans = util.batched_index_select(
            spans,
            doc_span_offsets.clamp(0).squeeze(-1))

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
        top_relex_spans = nd_batched_index_select(doc_spans,
                                                  top_relex_span_indices)

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
        (relex_span_pair_embeddings,
         relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
             top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
        relex_scores = self._compute_relex_scores(
            relex_span_pair_embeddings, top_relex_span_mention_scores)
        output_dict['relex_scores'] = relex_scores
        output_dict['top_relex_spans'] = top_relex_spans

        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)
            antecedent_labels_ = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs)
            negative_marginal_log_likelihood *= top_span_mask.squeeze(-1).float()
            negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            coref_loss = negative_marginal_log_likelihood
            output_dict['coref_loss'] = coref_loss
            loss += self._loss_coref_weight * coref_loss

        if doc_relations is not None:

            # The adjacency matrix for relation extraction is very sparse.
            # It is not merely sparse but row/column sparse: only a few rows and
            # columns are non-zero, and those rows/columns are themselves dense.
            # We therefore use our own representation for this case: the indices
            # of the truth spans, plus a mapping with which the prediction matrix
            # is aligned to the truth matrix.
            # TODO Add teacher forcing support.

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep),
            relative_indices = top_relex_span_indices
            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1),
            compressed_indices = nd_batched_padded_index_select(
                doc_spans_in_truth, relative_indices)

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
            gold_pruned_rows = nd_batched_padded_index_select(
                doc_relation_labels,
                compressed_indices.squeeze(-1),
                padding_value=0)
            gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3,
                                                        2).contiguous()

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
            gold_pruned_matrices = nd_batched_padded_index_select(
                gold_pruned_rows,
                compressed_indices.squeeze(-1),
                padding_value=0)  # pad with epsilon
            gold_pruned_matrices = gold_pruned_matrices.permute(
                0, 1, 3, 2).contiguous()

            # TODO log_mask relex score before passing
            relex_loss = nd_cross_entropy_with_logits(relex_scores,
                                                      gold_pruned_matrices,
                                                      relex_span_pair_mask)
            output_dict['relex_loss'] = relex_loss

            self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2),
                                       truth_spans)
            self._compute_relex_metrics(output_dict, doc_relations)

            loss += self._loss_relex_weight * relex_loss

        if doc_ner_labels is not None:
            # Shape: (batch_size, max_sentences, num_spans, num_ner_classes)
            ner_scores = self._ner_scorer(doc_span_embeddings)
            output_dict['ner_scores'] = ner_scores

            ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels,
                                                    doc_span_mask)
            output_dict['ner_loss'] = ner_loss
            loss += self._loss_ner_weight * ner_loss

        if not isinstance(loss, int):  # Only report a loss if one of the objectives above contributed to it
            output_dict["loss"] = loss

        return output_dict
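The marginal log-likelihood described in the comments can be written compactly in plain PyTorch. This is a simplified sketch (names are illustrative; it applies the span mask only at the end rather than inside the softmax, and it assumes the dummy "no antecedent" column is set whenever a span has no gold antecedent, as the gold-label computation above guarantees):

import torch

def coref_marginal_nll(coreference_scores, gold_antecedent_labels, top_span_mask):
    # coreference_scores:     (batch, k, 1 + max_antecedents); column 0 is the "no antecedent" dummy.
    # gold_antecedent_labels: (batch, k, 1 + max_antecedents); 1.0 for every gold-consistent choice.
    # top_span_mask:          (batch, k, 1)
    log_probs = torch.log_softmax(coreference_scores, dim=-1)
    # Adding log(labels) sends non-gold columns to -inf, so the logsumexp below marginalises
    # only over antecedents that are consistent with the gold clustering.
    gold_log_probs = log_probs + gold_antecedent_labels.log()
    marginal_log_likelihood = torch.logsumexp(gold_log_probs, dim=-1)  # (batch, k)
    return -(marginal_log_likelihood * top_span_mask.squeeze(-1)).sum()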
Exemplo n.º 18
0
    def forward(self, # pylint: disable=arguments-differ
                embeddings: torch.FloatTensor,
                mask: torch.LongTensor,
                num_items_to_keep: int) -> Tuple[torch.FloatTensor, torch.LongTensor,
                                                 torch.LongTensor, torch.FloatTensor]:
        """
        Extracts the top-k scoring items with respect to the scorer. We additionally return
        the indices of the top-k in their original order, not ordered by score, so that downstream
        components can rely on the original ordering (e.g., for knowing what spans are valid
        antecedents in a coreference resolution model).

        Parameters
        ----------
        embeddings : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
            each item in the list that we want to prune.
        mask : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_items), denoting unpadded elements of
            ``embeddings``.
        num_items_to_keep : ``int``, required.
            The number of items to keep when pruning.

        Returns
        -------
        top_embeddings : ``torch.FloatTensor``
            The representations of the top-k scoring items.
            Has shape (batch_size, num_items_to_keep, embedding_size).
        top_mask : ``torch.LongTensor``
            The corresponding mask for ``top_embeddings``.
            Has shape (batch_size, num_items_to_keep).
        top_indices : ``torch.IntTensor``
            The indices of the top-k scoring items into the original ``embeddings``
            tensor. This is returned because it can be useful to retain pointers to
            the original items, if each item is being scored by multiple distinct
            scorers, for instance. Has shape (batch_size, num_items_to_keep).
        top_item_scores : ``torch.FloatTensor``
            The values of the top-k scoring items.
            Has shape (batch_size, num_items_to_keep, 1).
        """
        mask = mask.unsqueeze(-1)
        num_items = embeddings.size(1)
        # Shape: (batch_size, num_items, 1)
        scores = self._scorer(embeddings)

        if scores.size(-1) != 1 or scores.dim() != 3:
            raise ValueError(f"The scorer passed to Pruner must produce a tensor of shape"
                             f"(batch_size, num_items, 1), but found shape {scores.size()}")
        # Make sure that we don't select any masked items by setting their scores to be very
        # negative.  These are logits, typically, so -1e20 should be plenty negative.
        scores = util.replace_masked_values(scores, mask, -1e20)

        # Shape: (batch_size, num_items_to_keep, 1)
        _, top_indices = scores.topk(num_items_to_keep, 1)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``embeddings`` tensor).
        top_indices, _ = torch.sort(top_indices, 1)

        # Shape: (batch_size, num_items_to_keep)
        top_indices = top_indices.squeeze(-1)

        # Shape: (batch_size * num_items_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select items for each element in the batch.
        flat_top_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)

        # Shape: (batch_size, num_items_to_keep, embedding_size)
        top_embeddings = util.batched_index_select(embeddings, top_indices, flat_top_indices)
        # Shape: (batch_size, num_items_to_keep)
        top_mask = util.batched_index_select(mask, top_indices, flat_top_indices)

        # Shape: (batch_size, num_items_to_keep, 1)
        top_scores = util.batched_index_select(scores, top_indices, flat_top_indices)

        return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
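The flatten-and-shift trick referred to throughout these examples (torch.index_select only takes 1D indices) boils down to adding a per-batch offset before flattening. A minimal re-implementation sketch of the two util helpers, for illustration only (error checking omitted):

import torch

def flatten_and_batch_shift_indices(indices: torch.Tensor, sequence_length: int) -> torch.Tensor:
    # indices: (batch_size, ...) with values in [0, sequence_length).
    # Add batch_index * sequence_length to every index so that, after flattening the target
    # tensor to (batch_size * sequence_length, embedding_size), each index hits the right row.
    offsets = torch.arange(indices.size(0), device=indices.device) * sequence_length
    offsets = offsets.view(-1, *([1] * (indices.dim() - 1)))
    return (indices + offsets).view(-1)

def batched_index_select(target: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # target: (batch_size, sequence_length, embedding_size); indices: (batch_size, ...).
    flat_indices = flatten_and_batch_shift_indices(indices, target.size(1))
    flat_target = target.reshape(-1, target.size(-1))
    selected = flat_target.index_select(0, flat_indices)
    return selected.view(*indices.size(), target.size(-1))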
Exemplo n.º 19
0
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)
        # span_embeddings = self._span_extractor(text_embeddings, spans, span_indices_mask=span_mask)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, span_embeddings.shape[1])

        # Shape:    (batch_size, num_spans_to_keep, embedding_size + 2 * encoding_dim + feature_size)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Shape: (batch_size, num_spans_to_keep, class_num + 1)
        ne_scores = self._compute_named_entity_scores(top_span_embeddings)

        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_named_entities = ne_scores.max(2)

        output_dict = {
            "top_spans": top_spans,
            "predicted_named_entities": predicted_named_entities
        }
        if labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices).squeeze(-1)
            negative_log_likelihood = F.cross_entropy(
                ne_scores.reshape(-1, self.class_num),
                pruned_gold_labels.reshape(-1))

            pruner_loss = F.binary_cross_entropy_with_logits(
                top_span_mention_scores.reshape(-1),
                (pruned_gold_labels.reshape(-1) != 0).float())
            loss = negative_log_likelihood + pruner_loss
            output_dict["loss"] = loss
            output_dict["pruner_loss"] = pruner_loss
            batch_size, _ = labels.shape
            all_scores = ne_scores.new_zeros(
                [batch_size * num_spans, self.class_num])
            all_scores[:, 0] = 1
            all_scores[flat_top_span_indices] = ne_scores.reshape(
                -1, self.class_num)
            all_scores = all_scores.reshape(
                [batch_size, num_spans, self.class_num])
            self._metric_all(all_scores, labels)
            self._metric_avg(all_scores, labels)
        return output_dict
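The SpanField padding convention mentioned in several of these forward() methods, shown in isolation on toy tensors:

import torch
import torch.nn.functional as F

# SpanFields pad with -1, so a negative start index marks a padded span,
# and F.relu clamps the padded indices to 0 so they can still be used for indexing.
spans = torch.tensor([[[0, 1], [2, 4], [-1, -1]]])   # (batch=1, num_spans=3, 2); last span is padding
span_mask = (spans[:, :, 0] >= 0).float()            # tensor([[1., 1., 0.]])
spans = F.relu(spans.float()).long()                 # the padded span becomes [0, 0]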
Exemplo n.º 20
0
    def forward(
        self,  # pylint: disable=arguments-differ
        embeddings: torch.FloatTensor,
        mask: torch.LongTensor,
        num_items_to_keep: int,
        get_scores: bool = False,
        scores: torch.FloatTensor = None,
        **kwargs
    ) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor,
               torch.FloatTensor]:
        """
        Extracts the top-k scoring items with respect to the scorer. We additionally return
        the indices of the top-k in their original order, not ordered by score, so that downstream
        components can rely on the original ordering (e.g., for knowing what spans are valid
        antecedents in a coreference resolution model).

        Parameters
        ----------
        embeddings : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for
            each item in the list that we want to prune.
        mask : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_items), denoting unpadded elements of
            ``embeddings``.
        num_items_to_keep : ``int``, required.
            The number of items to keep when pruning.

        Returns
        -------
        top_embeddings : ``torch.FloatTensor``
            The representations of the top-k scoring items.
            Has shape (batch_size, num_items_to_keep, embedding_size).
        top_mask : ``torch.LongTensor``
            The corresponding mask for ``top_embeddings``.
            Has shape (batch_size, num_items_to_keep).
        top_indices : ``torch.IntTensor``
            The indices of the top-k scoring items into the original ``embeddings``
            tensor. This is returned because it can be useful to retain pointers to
            the original items, if each item is being scored by multiple distinct
            scorers, for instance. Has shape (batch_size, num_items_to_keep).
        top_item_scores : ``torch.FloatTensor``
            The values of the top-k scoring items.
            Has shape (batch_size, num_items_to_keep, 1).
        """
        mask = mask.unsqueeze(-1)
        num_items = embeddings.size(1)
        if scores is None:
            # Shape: (batch_size, num_items, 1)
            scores = self._scorer(embeddings)

            if scores.size(-1) != 1 or scores.dim() != 3:
                raise ValueError(
                    f"The scorer passed to SpanPruner must produce a tensor of shape"
                    f"(batch_size, num_items, 1), but found shape {scores.size()}"
                )
            # Make sure that we don't select any masked items by setting their scores to be very
            # negative.  These are logits, typically, so -1e20 should be plenty negative.
            scores = util.replace_masked_values(scores, mask, -1e20)

        if get_scores:
            return scores

        # Shape: (batch_size, num_items_to_keep, 1)
        _, top_indices = scores.topk(num_items_to_keep, 1)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``embeddings`` tensor).
        top_indices, _ = torch.sort(top_indices, 1)

        # Shape: (batch_size, num_items_to_keep)
        top_indices = top_indices.squeeze(-1)

        # Shape: (batch_size * num_items_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select items for each element in the batch.
        flat_top_indices = util.flatten_and_batch_shift_indices(
            top_indices, num_items)

        # Shape: (batch_size, num_items_to_keep, embedding_size)
        top_embeddings = util.batched_index_select(embeddings, top_indices,
                                                   flat_top_indices)
        # Shape: (batch_size, num_items_to_keep)
        top_mask = util.batched_index_select(mask, top_indices,
                                             flat_top_indices)

        # Shape: (batch_size, num_items_to_keep, 1)
        top_scores = util.batched_index_select(scores, top_indices,
                                               flat_top_indices)

        return top_embeddings, top_mask.squeeze(-1), top_indices, top_scores
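The top-k-then-sort idiom used by both pruner variants, shown on toy scores: the k best-scoring items are kept, but their indices are returned in the original order, so "previous item" still means "earlier in the document".

import torch

scores = torch.tensor([[[0.7], [0.1], [0.9], [0.5]]])   # (batch=1, num_items=4, 1)
_, top_indices = scores.topk(2, dim=1)                   # ordered by score: [[[2], [0]]]
top_indices, _ = torch.sort(top_indices, dim=1)          # back to document order: [[[0], [2]]]
top_indices = top_indices.squeeze(-1)                    # tensor([[0, 2]])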
Exemplo n.º 21
0
    def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                sequence_mask: torch.LongTensor = None,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        batch_size, num_spans = span_indices.size()[:2]

        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape (batch_size, sequence_length, 1)
        global_attention_logits = self._global_attention(sequence_tensor)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # compute span head weight
        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = util.batched_index_select(global_attention_logits,
                                                          span_indices,
                                                          flat_span_indices).squeeze(-1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_head_weights = util.masked_softmax(span_attention_logits, span_mask)

        # get head words indices
        top_num_heads = min(self._num_heads, max_batch_span_width)
        # Shape: (batch_size, num_spans, num_heads)
        span_head_ind = span_head_weights.topk(top_num_heads, -1)[1]
        # make sure the index is consistent with the original order
        span_head_ind = torch.sort(span_head_ind, -1)[0]
        # Shape: (batch_size * num_spans * num_heads)
        flat_span_head_ind = util.flatten_and_batch_shift_indices(
            span_head_ind.view(-1, top_num_heads), max_batch_span_width)

        # select emb and mask
        # Shape: (batch_size, num_spans, num_heads)
        span_head_ind_external = util.batched_index_select(
            span_indices.view(-1, max_batch_span_width, 1),
            span_head_ind.view(-1, top_num_heads),
            flat_span_head_ind).view(*span_head_ind.size())
        # Shape: (batch_size, num_spans, num_heads)
        span_head_mask = util.batched_index_select(
            span_mask.view(-1, max_batch_span_width, 1),
            span_head_ind.view(-1, top_num_heads),
            flat_span_head_ind).view(*span_head_ind.size())
        # Shape: (batch_size, num_spans, num_heads, emb_dim)
        span_head_emb = util.batched_index_select(sequence_tensor, span_head_ind_external, flattened_indices=None)

        # concat with span token
        # Shape: (batch_size, num_spans, 1, emb_dim)
        span_token_emb = self._span_token_emb.view(1, 1, 1, -1).expand(batch_size, num_spans, -1, -1)
        # Shape: (batch_size, num_spans, num_heads + 1, emb_dim)
        span_head_emb = torch.cat([span_head_emb, span_token_emb], 2)
        # Shape: (batch_size, num_spans, num_heads + 1)
        span_head_mask = nn.ConstantPad1d((1, 0), 1)(span_head_mask)

        # self attention span representation
        span_head_emb = self._stacked_self_attention(
            span_head_emb.view(-1, top_num_heads + 1, self._input_dim),
            span_head_mask.view(-1, top_num_heads + 1))

        # aggregate
        # Shape: (batch_size, num_spans, num_heads + 1, emb_dim)
        span_head_emb = span_head_emb.view(batch_size, num_spans, top_num_heads + 1, self._output_dim)
        # Shape: (batch_size, num_spans, emb_dim)
        span_embeddings = span_head_emb[:, :, 0]

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        #attended_text_embeddings = util.weighted_sum(span_embeddings, span_attention_weights)

        if span_indices_mask is not None:
            # Above we were masking the widths of spans with respect to the max
            # span width in the batch. Here we are masking the spans which were
            # originally passed in as padding.
            return span_embeddings * span_indices_mask.unsqueeze(-1).float()

        return span_embeddings
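A worked example (toy values) of the broadcasted span-index construction used at the start of this extractor; the inclusive end index is why widths are "off by one" and why the comparison uses <=:

import torch

span_starts = torch.tensor([[[2]]])                       # (batch=1, num_spans=1, 1)
span_ends = torch.tensor([[[4]]])                         # inclusive end index
span_widths = span_ends - span_starts                     # 2
max_batch_span_width = span_widths.max().item() + 1       # 3
max_span_range_indices = torch.arange(max_batch_span_width).view(1, 1, -1)   # [[[0, 1, 2]]]
span_mask = (max_span_range_indices <= span_widths).float()                  # [[[1., 1., 1.]]]
raw_span_indices = span_ends - max_span_range_indices     # [[[4, 3, 2]]]: the span's tokens, end first
span_mask = span_mask * (raw_span_indices >= 0).float()   # also drop indices that fell below zero
span_indices = torch.relu(raw_span_indices.float()).long()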
Exemplo n.º 22
0
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        sentence_end: torch.LongTensor,
        spans: torch.LongTensor,
        span_labels: torch.LongTensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
        start = time.time()
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            # token_type_ids=verb_indicator,
            attention_mask=mask,
        )

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        batch_size, sequence_length, _ = embedded_text_input.size()
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            bert_embeddings, spans)

        if self._context_layer is not None:
            contextualized_embeddings = self._context_layer(
                embedded_text_input, mask)
            # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
            endpoint_span_embeddings = self._endpoint_span_extractor(
                contextualized_embeddings, spans)

            # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
            # span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)
            span_embeddings = endpoint_span_embeddings
        else:
            span_embeddings = attended_span_embeddings

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * sequence_length))
        num_spans = spans.shape[1]
        num_spans_to_keep = min(num_spans_to_keep, num_spans)

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)).squeeze(-1)
        # Shape: (batch_size, num_spans) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep)
        verb_index = verb_indicator.argmax(1).unsqueeze(1).unsqueeze(2).repeat(
            1, 1, embedded_text_input.shape[-1])
        verb_embeddings = torch.gather(embedded_text_input, 1, verb_index)
        assert len(
            verb_embeddings.shape) == 3 and verb_embeddings.shape[1] == 1
        verb_embeddings = verb_embeddings.squeeze(1)
        # print(verb_indicator.sum(1, keepdim=True) > 0)
        verb_embeddings = torch.where(
            (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                1, verb_embeddings.shape[-1]), verb_embeddings,
            torch.zeros_like(verb_embeddings))
        # print(verb_embeddings)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, spans.shape[1])
        span_embeddings = util.batched_index_select(span_embeddings,
                                                    top_span_indices,
                                                    flat_top_span_indices)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)
        top_span_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices).squeeze(-1)
        concatenated_span_embeddings = torch.cat(
            (span_embeddings, verb_embeddings.unsqueeze(1).repeat(
                1, span_embeddings.shape[1], 1)),
            dim=2)
        # print(concatenated_span_embeddings[:,:,:])
        hidden = self.hidden_layer(concatenated_span_embeddings)
        # print(hidden[1,:,:])
        # print(top_span_indices)
        # print([[span_mention_scores[i,top_span_indices[i,j]].item() for j in range(top_span_indices.shape[1])] for i in range(top_span_labels.shape[0])])
        # print(top_span_mention_scores, self.vocab.get_token_index("O", namespace="span_labels"))
        predictions = self.output_layer(hidden)
        # predictions += top_span_mention_scores.unsqueeze(-1).repeat(1, 1, self.num_classes-1)
        predictions = torch.cat(
            (torch.zeros_like(predictions[:, :, :1]), predictions), dim=-1)
        # print(top_span_mention_scores.unsqueeze(-1).repeat(1, 1, self.num_classes-1))

        output_dict = {}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            loss = (self._ce_loss(predictions.view(-1, predictions.shape[-1]),
                                  top_span_labels.view(-1)) *
                    top_span_mask.float().view(-1)
                    ).sum() / top_span_mask.float().sum()
            # print(top_span_labels)
            # print(predictions.argmax(-1))
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.get_tags(
                    top_spans, predictions, mask.shape[1], top_span_mask,
                    output_dict)
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                # print('G', batch_bio_gold_tags)
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
        return output_dict
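The predicate-embedding gather used above, demonstrated on toy tensors: the embedding at the verb position is selected, then zeroed out for sentences whose indicator is all zeros (no verbal predicate), since argmax would otherwise silently pick position 0.

import torch

embedded_text = torch.arange(24.).view(2, 4, 3)           # (batch=2, num_tokens=4, emb=3)
verb_indicator = torch.tensor([[0, 0, 1, 0],               # predicate at position 2
                               [0, 0, 0, 0]])              # no predicate in this sentence
verb_index = verb_indicator.argmax(1).unsqueeze(1).unsqueeze(2).repeat(1, 1, embedded_text.shape[-1])
verb_embeddings = torch.gather(embedded_text, 1, verb_index).squeeze(1)      # rows at positions 2 and 0
verb_embeddings = torch.where(
    (verb_indicator.sum(1, keepdim=True) > 0).repeat(1, verb_embeddings.shape[-1]),
    verb_embeddings, torch.zeros_like(verb_embeddings))    # second sentence becomes all zeros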
Exemplo n.º 23
0
    def forward(
            self,  # pylint: disable=arguments-differ
            span_embeddings,
            span_mask,
            num_spans_to_keep):
        u"""
        Extracts the top-k scoring spans with respect to the scorer. We additionally return
        the indices of the top-k in their original order, not ordered by score, so that we
        can rely on the ordering to consider the previous k spans as antecedents for each
        span later.

        Parameters
        ----------
        span_embeddings : ``torch.FloatTensor``, required.
            A tensor of shape (batch_size, num_spans, embedding_size), representing
            the set of embedded span representations.
        span_mask : ``torch.LongTensor``, required.
            A tensor of shape (batch_size, num_spans), denoting unpadded elements
            of ``span_embeddings``.
        num_spans_to_keep : ``int``, required.
            The number of spans to keep when pruning.

        Returns
        -------
        top_span_embeddings : ``torch.FloatTensor``
            The span representations of the top-k scoring spans.
            Has shape (batch_size, num_spans_to_keep, embedding_size).
        top_span_mask : ``torch.LongTensor``
            The corresponding mask for ``top_span_embeddings``.
            Has shape (batch_size, num_spans_to_keep).
        top_span_indices : ``torch.IntTensor``
            The indices of the top-k scoring spans into the original ``span_embeddings``
            tensor. This is returned because it can be useful to retain pointers to
            the original spans, if each span is being scored by multiple distinct
            scorers, for instance. Has shape (batch_size, num_spans_to_keep).
        top_span_scores : ``torch.FloatTensor``
            The values of the top-k scoring spans.
            Has shape (batch_size, num_spans_to_keep, 1).
        """
        span_mask = span_mask.unsqueeze(-1)
        num_spans = span_embeddings.size(1)
        # Shape: (batch_size, num_spans, 1)
        span_scores = self._scorer(span_embeddings)

        if span_scores.size(-1) != 1 or span_scores.dim() != 3:
            raise ValueError(
                "The scorer passed to SpanPruner must produce a tensor of shape"
                "(batch_size, num_spans, 1), but found shape {span_scores.size()}"
            )
        # Make sure that we don't select any masked spans by
        # setting their scores to be -inf.
        span_scores += span_mask.log()

        # Shape: (batch_size, num_spans_to_keep, 1)
        _, top_span_indices = span_scores.topk(num_spans_to_keep, 1)

        # Now we order the selected indices in increasing order with
        # respect to their indices (and hence, with respect to the
        # order they originally appeared in the ``span_embeddings`` tensor).
        top_span_indices, _ = torch.sort(top_span_indices, 1)

        # Shape: (batch_size, num_spans_to_keep)
        top_span_indices = top_span_indices.squeeze(-1)

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = util.batched_index_select(
            span_embeddings, top_span_indices, flat_top_span_indices)
        # Shape: (batch_size, num_spans_to_keep)
        top_span_mask = util.batched_index_select(span_mask, top_span_indices,
                                                  flat_top_span_indices)

        # Shape: (batch_size, num_spans_to_keep, 1)
        top_span_scores = util.batched_index_select(span_scores,
                                                    top_span_indices,
                                                    flat_top_span_indices)

        return top_span_embeddings, top_span_mask.squeeze(
            -1), top_span_indices, top_span_scores
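The additive log-mask used by this pruner (`span_scores += span_mask.log()`), in isolation: adding log(mask) leaves unpadded scores unchanged (log 1 = 0) and pushes padded scores to -inf, so topk can never select them.

import torch

span_scores = torch.tensor([[1.0, 2.0, 3.0]])
span_mask = torch.tensor([[1.0, 1.0, 0.0]])
masked_scores = span_scores + span_mask.log()   # tensor([[1., 2., -inf]])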
Exemplo n.º 24
0
    def _create_attended_span_representations(
            self, head_scores: torch.FloatTensor,
            text_embeddings: torch.FloatTensor, span_ends: torch.IntTensor,
            span_widths: torch.IntTensor) -> torch.FloatTensor:
        """
        Given a tensor of unnormalized attention scores for each word in the document, compute
        distributions over every span with respect to these scores by normalising the headedness
        scores for words inside the span.

        Given these headedness distributions over every span, weight the corresponding vector
        representations of the words in the span by this distribution, returning a weighted
        representation of each span.

        Parameters
        ----------
        head_scores : ``torch.FloatTensor``, required.
            Unnormalized headedness scores for every word. This score is shared for every
            candidate. The only way in which the headedness scores differ over different
            spans is in the set of words over which they are normalized.
        text_embeddings: ``torch.FloatTensor``, required.
            The embeddings with shape  (batch_size, document_length, embedding_size)
            over which we are computing a weighted sum.
        span_ends: ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 1), representing the end indices
            of each span.
        span_widths : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 1) representing the width of each
            candidate span.
        Returns
        -------
        attended_text_embeddings : ``torch.FloatTensor``
            A tensor of shape (batch_size, num_spans, embedding_dim) - the result of
            applying attention over all words within each candidate span.
        """
        # Shape: (1, 1, max_span_width)
        max_span_range_indices = util.get_range_vector(
            self._max_span_width, text_embeddings.is_cuda).view(1, 1, -1)

        # Shape: (batch_size, num_spans, max_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the document
        # are of a smaller width than max_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        # Clamp negative indices to zero so they can safely be used for indexing.
        span_indices = F.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(
            span_indices, text_embeddings.size(1))

        # Shape: (batch_size, num_spans, max_span_width, embedding_dim)
        span_text_embeddings = util.batched_index_select(
            text_embeddings, span_indices, flat_span_indices)

        # Shape: (batch_size, num_spans, max_span_width)
        span_head_scores = util.batched_index_select(
            head_scores, span_indices, flat_span_indices).squeeze(-1)

        # Shape: (batch_size, num_spans, max_span_width)
        span_head_weights = util.last_dim_softmax(span_head_scores, span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised head score distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_text_embeddings,
                                                     span_head_weights)

        return attended_text_embeddings
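A minimal, self-contained sketch of the same head-attention pattern in plain PyTorch, assuming inclusive end indices and widths computed as end - start; all sizes and names here are illustrative:

import torch
import torch.nn.functional as F

batch_size, doc_len, emb_dim, num_spans, max_width = 1, 7, 3, 2, 4
text_embeddings = torch.randn(batch_size, doc_len, emb_dim)
head_scores = torch.randn(batch_size, doc_len, 1)
span_ends = torch.tensor([[[3], [6]]])    # inclusive end index of each span
span_widths = torch.tensor([[[1], [3]]])  # end - start

# (1, 1, max_width) range vector broadcast against (batch, num_spans, 1) widths
range_indices = torch.arange(max_width).view(1, 1, -1)
span_mask = (range_indices <= span_widths).float()
raw_indices = span_ends - range_indices
span_mask = span_mask * (raw_indices >= 0).float()
span_indices = raw_indices.clamp(min=0)

# gather the word embeddings and head scores for every (span, offset) pair
flat = span_indices.view(batch_size, -1)
span_text = torch.gather(
    text_embeddings, 1, flat.unsqueeze(-1).expand(-1, -1, emb_dim)
).view(batch_size, num_spans, max_width, emb_dim)
span_head = torch.gather(
    head_scores.squeeze(-1), 1, flat).view(batch_size, num_spans, max_width)

# masked softmax over positions inside each span, then a weighted sum
span_head = span_head.masked_fill(span_mask == 0, float("-inf"))
head_weights = F.softmax(span_head, dim=-1)
attended = torch.einsum("bsw,bswe->bse", head_weights, span_text)
assert attended.shape == (batch_size, num_spans, emb_dim)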
Exemplo n.º 25
0
def span_pruner(embeddings,
                scores,
                mask,
                seq_length,
                spans_per_word=1,
                num_keep=None):
    """

        Based on AllenNLP allennlp.modules.Pruner from release 0.84


        Parameters
        ----------

        logits: (batch_size, num_spans, num_tags)
        mask: (batch_size, num_spans)
        num_keep: int OR torch.LongTensor
                If a tensor of shape (batch_size), specifies the
                number of items to keep for each
                individual sentence in minibatch.
                If an int, keep the same number of items for all sentences.


        """

    batch_size, num_items = tuple(scores.shape)

    # Number to keep not provided, so use spans per word
    if num_keep is None:
        num_keep = seq_length * spans_per_word
        num_keep = torch.max(num_keep, torch.ones_like(num_keep))

    # If an int was given for number of items to keep, construct tensor by repeating the value.
    if isinstance(num_keep, int):
        num_keep = num_keep * torch.ones(
            [batch_size], dtype=torch.long, device=mask.device)

    # Maximum number to keep
    max_keep = num_keep.max()

    # Add a trailing dimension so scores have shape (batch_size, num_items, 1)
    scores = scores.unsqueeze(-1)

    # Check scores dimensionality
    if scores.size(-1) != 1 or scores.dim() != 3:
        raise ValueError(
            f"The scorer passed to Pruner must produce a tensor of shape"
            f"(batch_size, num_items, 1), but found shape {scores.size()}")

    # Make sure that we don't select any masked items by setting their scores to a
    # very large negative value (NEG_FILL); these are typically logits, so that is
    # plenty negative.
    mask = mask.unsqueeze(-1).bool()
    scores = util.replace_masked_values(scores, mask, NEG_FILL)

    # Shape: (batch_size, max_num_items_to_keep, 1)
    _, top_indices = scores.topk(max_keep, 1)

    # Mask based on number of items to keep for each sentence.
    # Shape: (batch_size, max_num_items_to_keep)
    top_indices_mask = util.get_mask_from_sequence_lengths(num_keep, max_keep)
    top_indices_mask = top_indices_mask.bool()

    # Shape: (batch_size, max_num_items_to_keep)
    top_indices = top_indices.squeeze(-1)

    # Fill all masked indices with largest "top" index for that sentence, so that all masked
    # indices will be sorted to the end.
    # Shape: (batch_size, 1)
    fill_value, _ = top_indices.max(dim=1)
    fill_value = fill_value.unsqueeze(-1)
    # Shape: (batch_size, max_num_items_to_keep)
    top_indices = torch.where(top_indices_mask, top_indices, fill_value)
    # Now we order the selected indices in increasing order with
    # respect to their indices (and hence, with respect to the
    # order they originally appeared in the ``embeddings`` tensor).
    top_indices, _ = torch.sort(top_indices, 1)

    # Shape: (batch_size * max_num_items_to_keep)
    # torch.index_select only accepts 1D indices, but here
    # we need to select items for each element in the batch.
    flat_indices = util.flatten_and_batch_shift_indices(top_indices, num_items)

    # Combine the masks on spans that are out-of-bounds, and the mask on spans that are outside
    # the top k for each sentence.
    # Shape: (batch_size, max_num_items_to_keep)
    sequence_mask = util.batched_index_select(mask, top_indices, flat_indices)
    sequence_mask = sequence_mask.squeeze(-1).bool()
    top_mask = top_indices_mask & sequence_mask
    top_mask = top_mask.long()

    # Shape: (batch_size, max_num_items_to_keep, 1)
    top_scores = util.batched_index_select(scores, top_indices, flat_indices)
    top_embeddings = util.batched_index_select(embeddings, top_indices,
                                               flat_indices)

    # Shape: (batch_size, max_num_items_to_keep)
    top_scores = top_scores.squeeze(-1)

    return (top_indices, top_embeddings, top_scores, top_mask)
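A hypothetical call to span_pruner above (it still needs the AllenNLP util module and the NEG_FILL constant in scope); the shapes and values are illustrative:

import torch

batch_size, num_spans, emb_dim = 2, 10, 8
embeddings = torch.randn(batch_size, num_spans, emb_dim)
scores = torch.randn(batch_size, num_spans)
mask = torch.ones(batch_size, num_spans, dtype=torch.long)

# keep the 3 best-scoring spans per sentence; seq_length is only consulted
# when num_keep is not given, so a dummy value is fine here
top_indices, top_embeddings, top_scores, top_mask = span_pruner(
    embeddings, scores, mask, seq_length=None, num_keep=3)
# top_embeddings: (batch_size, 3, emb_dim), in original document order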
Exemplo n.º 26
0
    def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        if self._use_gold_mentions:
            if text_embeddings.is_cuda:
                device = torch.device("cuda")
            else:
                device = torch.device("cpu")

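            # Note: only metadata[0]["clusters"] is read here, so gold-mention
            # masking implicitly assumes a batch size of one.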
            s = [
                torch.as_tensor(pair, dtype=torch.long, device=device)
                for cluster in metadata[0]["clusters"] for pair in cluster
            ]
            gm = torch.stack(s, dim=0).unsqueeze(0).unsqueeze(1)

            span_mask = spans.unsqueeze(2) - gm
            span_mask = (span_mask[:, :, :, 0] == 0) + (span_mask[:, :, :, 1]
                                                        == 0)
            span_mask, _ = (span_mask == 2).max(-1)
            num_spans = span_mask.sum().item()
            span_mask = span_mask.float()
        else:
            span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
            num_spans = spans.size(1)
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = self._generate_valid_antecedents(
            num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings,
            top_span_mention_scores,
            candidate_antecedent_mention_scores,
            valid_antecedent_log_mask,
        )

        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": predicted_antecedents,
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            coreference_log_probs = util.last_dim_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
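The gold-mention branch above marks a candidate span as valid only when both of its endpoints coincide with some gold mention (and, as noted, it reads only the first instance's clusters). A self-contained illustration of that endpoint comparison with toy values, plain PyTorch:

import torch

# (1, num_spans, 2) candidate spans; (-1, -1) is padding
spans = torch.tensor([[[0, 1], [2, 4], [5, 5], [-1, -1]]])
# (1, 1, num_gold, 2) gold mention spans gathered from the clusters
gold = torch.tensor([[2, 4], [5, 5]]).unsqueeze(0).unsqueeze(1)

diff = spans.unsqueeze(2) - gold                       # (1, num_spans, num_gold, 2)
both_ends_match = (diff[..., 0] == 0) & (diff[..., 1] == 0)
span_mask = both_ends_match.any(-1).float()            # (1, num_spans)
print(span_mask)  # tensor([[0., 1., 1., 0.]])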
Exemplo n.º 27
0
    def forward(
        self,
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
        get_scores: bool = False,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        num_models = len(self.submodels)
        mention_results = [
            submodel.get_mention_scores(text, spans)
            for submodel in self.submodels
        ]

        # extract return values
        mask = mention_results[0]['mask']
        num_spans_to_keep = mention_results[0]['num_items_to_keep']
        text_mask = mention_results[0]['text_mask']
        all_mention_scores = torch.stack(
            [mention_results[i]['scores'] for i in range(num_models)])
        # average across mention scores
        avg_mention_scores = all_mention_scores.mean(0)
        # ensure we don't select masked items
        avg_mention_scores = util.replace_masked_values(
            avg_mention_scores, mask, -1e20)
        # prune mentions with averaged scores
        _, top_span_indices_ensemble = avg_mention_scores.topk(
            num_spans_to_keep, 1)
        top_span_indices_ensemble, _ = torch.sort(top_span_indices_ensemble, 1)
        top_span_indices_ensemble = top_span_indices_ensemble.squeeze(-1)
        flat_top_span_indices_ensemble = util.flatten_and_batch_shift_indices(
            top_span_indices_ensemble, avg_mention_scores.size(1))
        top_span_mask = util.batched_index_select(
            mask, top_span_indices_ensemble, flat_top_span_indices_ensemble)
        top_span_mention_scores = util.batched_index_select(
            avg_mention_scores, top_span_indices_ensemble,
            flat_top_span_indices_ensemble)

        # feed averaged mention scores and top mentions back into model
        coref_scores_results = [
            submodel.get_coreference_scores(
                spans=spans,
                top_span_mention_scores=top_span_mention_scores,
                num_spans_to_keep=num_spans_to_keep,
                top_span_indices=top_span_indices_ensemble,
                flat_top_span_indices=flat_top_span_indices_ensemble,
                top_span_mask=top_span_mask,
                top_span_embeddings=util.batched_index_select(
                    mention_results[i]['embeddings'],
                    top_span_indices_ensemble,
                    flat_top_span_indices_ensemble,
                ),
                text_mask=text_mask,
                get_scores=True,
            ) for i, submodel in enumerate(self.submodels)
        ]

        # extract return values (should be the same)
        top_spans = coref_scores_results[0]['output_dict']['top_spans']
        valid_antecedent_indices = coref_scores_results[0]['output_dict'][
            'antecedent_indices']
        valid_antecedent_log_mask = coref_scores_results[0]['ant_mask']
        all_coref_scores = torch.stack([
            coref_scores_results[i]['output_dict']['coreference_scores']
            for i in range(num_models)
        ])
        # average across coref scores
        avg_coref_scores = all_coref_scores.mean(0)
        # obtain predictions with averaged scores
        _, ensemble_predicted_antecedents = avg_coref_scores.max(2)
        ensemble_predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": ensemble_predicted_antecedents
        }
        if get_scores:
            output_dict["coreference_scores"] = avg_coref_scores
            output_dict['top_span_indices'] = top_span_indices_ensemble

        # feed the averaged coreference scores back into a submodel; since this step
        # only evaluates against the labels, any submodel gives the same result
        output_dict = self.submodels[0].score_spans_if_labels(
            output_dict=output_dict,
            span_labels=span_labels,
            metadata=metadata,
            top_span_indices=top_span_indices_ensemble,
            flat_top_span_indices=flat_top_span_indices_ensemble,
            top_span_mask=top_span_mask,
            top_spans=top_spans,
            valid_antecedent_indices=valid_antecedent_indices,
            valid_antecedent_log_mask=valid_antecedent_log_mask,
            coreference_scores=avg_coref_scores,
            predicted_antecedents=ensemble_predicted_antecedents,
        )

        self._mention_recall(output_dict['top_spans'], metadata)
        self._conll_coref_scores(output_dict['top_spans'],
                                 output_dict['antecedent_indices'],
                                 output_dict['predicted_antecedents'],
                                 metadata)
        if get_scores:
            output_dict['coreference_scores_models'] = all_coref_scores
        return output_dict
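The ensemble above averages mention and coreference scores across submodels before pruning and prediction. A toy sketch of that averaging-then-pruning step (illustrative shapes, plain PyTorch):

import torch

num_models, batch_size, num_spans, num_spans_to_keep = 3, 1, 5, 2
per_model_scores = torch.randn(num_models, batch_size, num_spans, 1)
mask = torch.tensor([[1., 1., 1., 1., 0.]]).unsqueeze(-1)  # last span is padding

avg_scores = per_model_scores.mean(0)                      # (batch, num_spans, 1)
avg_scores = avg_scores.masked_fill(mask == 0, -1e20)      # never keep padded spans
_, top_indices = avg_scores.topk(num_spans_to_keep, 1)
top_indices, _ = torch.sort(top_indices, 1)                # back to document order
top_indices = top_indices.squeeze(-1)                      # (batch, num_spans_to_keep)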
Exemplo n.º 28
0
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
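The comment block above describes building, once per forward pass, a matrix that maps each kept span to the indices of its allowed (strictly prior) antecedents. A rough illustration of how such a matrix can be built with plain PyTorch, in the spirit of _generate_valid_antecedents but not the model's actual helper:

import torch

num_spans_to_keep, max_antecedents = 5, 3

# target_indices[i] = i for every kept span
target_indices = torch.arange(num_spans_to_keep).unsqueeze(-1)                 # (k, 1)
# offsets 1..max_antecedents back from each span
valid_antecedent_offsets = torch.arange(1, max_antecedents + 1).unsqueeze(0)   # (1, m)
raw_antecedent_indices = target_indices - valid_antecedent_offsets             # (k, m)

# offsets that reach before the start of the document are invalid
valid_antecedent_log_mask = (raw_antecedent_indices >= 0).float().log()        # 0 or -inf
valid_antecedent_indices = raw_antecedent_indices.clamp(min=0)

print(valid_antecedent_indices)
# tensor([[0, 0, 0],
#         [0, 0, 0],
#         [1, 0, 0],
#         [2, 1, 0],
#         [3, 2, 1]])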
Exemplo n.º 29
0
    def forward(self, span_embeddings, span_children, span_children_mask):
        batch, sequence, children_num, _ = span_children.size()
        # (batch, sequence, children_num)
        span_children = span_children.squeeze(-1)

        for t in range(self._tree_prop):

            flat_span_indices = util.flatten_and_batch_shift_indices(span_children, span_embeddings.size(1))
            # (batch, sequence, children_num, span_emb_dim)
            children_span_embeddings = util.batched_index_select(span_embeddings, span_children, flat_span_indices)

            if self._tree_children == 'attention':
                # (batch, sequence, children_num)
                attention_scores = self._global_attention(children_span_embeddings).squeeze(-1)
                # (batch, sequence, children_num)
                attention_scores_softmax = util.masked_softmax(attention_scores, span_children_mask, dim=2)
                # (batch, sequence, span_emb_dim)
                children_span_embeddings_merged = util.weighted_sum(children_span_embeddings, attention_scores_softmax)
            elif self._tree_children == 'pooling':
                children_span_embeddings_merged = util.masked_max(children_span_embeddings, span_children_mask.unsqueeze(-1), dim=2)
            elif self._tree_children == 'conv':
                masked_children_span_embeddings = children_span_embeddings * span_children_mask.unsqueeze(-1)

                masked_children_span_embeddings = masked_children_span_embeddings.view(batch * sequence, children_num, -1).transpose(1, 2)

                conv_children_span_embeddings = torch.nn.functional.relu(self._conv(masked_children_span_embeddings))

                conv_children_span_embeddings = conv_children_span_embeddings.transpose(1, 2).view(batch, sequence, children_num, -1)

                children_span_embeddings_merged = util.masked_max(conv_children_span_embeddings, span_children_mask.unsqueeze(-1), dim=2)
            elif self._tree_children == 'rnn':
                masked_children_span_embeddings = children_span_embeddings * span_children_mask.unsqueeze(-1)
                masked_children_span_embeddings = masked_children_span_embeddings.view(batch * sequence, children_num, -1)
                try:
                    # if no span in this batch has any children, the encoder call can raise an error
                    rnn_children_span_embeddings = self._encoder(masked_children_span_embeddings, span_children_mask.view(batch * sequence, children_num))
                except Exception:
                    rnn_children_span_embeddings = masked_children_span_embeddings

                rnn_children_span_embeddings = rnn_children_span_embeddings.view(batch, sequence, children_num, -1)
                forward_sequence, backward_sequence = rnn_children_span_embeddings.split(int(self._span_emb_dim / 2), dim=-1)
                children_span_embeddings_merged = torch.cat([forward_sequence[:,:,-1,:], backward_sequence[:,:,0,:]], dim=-1)
            else:
                raise RuntimeError(f"unsupported tree_children mode: {self._tree_children}")
            # (batch, sequence, 2*span_emb_dim)
            f_network_input = torch.cat([span_embeddings, children_span_embeddings_merged], dim=-1)
            # (batch, sequence, span_emb_dim)
            f_weights = self._f_network(f_network_input)
            # (batch, sequence, 1): f_weights_mask == 1 means this span has at least one child
            f_weights_mask, _ = span_children_mask.max(dim=-1, keepdim=True)
            # (batch, sequence, span_emb_dim): set f_weights to 1 where f_weights_mask == 0,
            # so spans with no children keep their own embedding
            f_weights = util.replace_masked_values(f_weights, f_weights_mask, 1.0)
            # (batch, sequence, span_emb_dim)
            span_embeddings = f_weights * span_embeddings + (1.0 - f_weights) * children_span_embeddings_merged

        span_embeddings = self._dropout(span_embeddings)

        return span_embeddings
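The update at the end of each propagation step above gates between a span's own embedding and the merged embedding of its children. A minimal sketch of that gated combination, assuming the gate network (the counterpart of self._f_network here, a hypothetical stand-in) ends in a sigmoid so its outputs lie in [0, 1]:

import torch
import torch.nn as nn

span_emb_dim = 6
# hypothetical stand-in for self._f_network
f_network = nn.Sequential(nn.Linear(2 * span_emb_dim, span_emb_dim), nn.Sigmoid())

span_embeddings = torch.randn(2, 3, span_emb_dim)        # (batch, sequence, dim)
children_embeddings = torch.randn(2, 3, span_emb_dim)    # merged children representation

f_weights = f_network(torch.cat([span_embeddings, children_embeddings], dim=-1))
updated = f_weights * span_embeddings + (1.0 - f_weights) * children_embeddings
# spans with no children would have f_weights forced to 1, keeping their own embedding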
Exemplo n.º 30
0
    def forward(
        self,  # type: ignore
        text: TextFieldTensors,
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        """
        # Parameters

        text : `TextFieldTensors`, required.
            The output of a `TextField` representing the text of
            the document.
        spans : `torch.IntTensor`, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
            indices into the text of the document.
        span_labels : `torch.IntTensor`, optional (default = None).
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : `List[Dict[str, Any]]`, optional (default = None).
            A metadata dictionary for each instance in the batch. We use the "original_text" and "clusters" keys
            from this dictionary, which respectively have the original text and the annotated gold coreference
            clusters for that instance.

        # Returns

        An output dictionary consisting of:
        top_spans : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : `torch.IntTensor`
            A tensor of shape `(num_spans_to_keep, max_antecedents)` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        batch_size = spans.size(0)
        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text)

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, num_spans)

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)
        ).squeeze(-1)
        # Shape: (batch_size, num_spans) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep
        )

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = util.batched_index_select(
            span_embeddings, top_span_indices, flat_top_span_indices
        )

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        # index to the indices of its allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        # we can use to make coreference decisions between valid span pairs.

        if self._coarse_to_fine:
            pruned_antecedents = self._coarse_to_fine_pruning(
                top_span_embeddings, top_span_mention_scores, top_span_mask, max_antecedents
            )
        else:
            pruned_antecedents = self._distance_pruning(
                top_span_embeddings, top_span_mention_scores, max_antecedents
            )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 4 tensors
        (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
            top_antecedent_indices,
        ) = pruned_antecedents

        flat_top_antecedent_indices = util.flatten_and_batch_shift_indices(
            top_antecedent_indices, num_spans_to_keep
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        top_antecedent_embeddings = util.batched_index_select(
            top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            top_span_embeddings,
            top_antecedent_embeddings,
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
        )

        for _ in range(self._inference_order - 1):
            dummy_mask = top_antecedent_mask.new_ones(batch_size, num_spans_to_keep, 1)
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            top_antecedent_with_dummy_mask = torch.cat([dummy_mask, top_antecedent_mask], -1)
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            attention_weight = util.masked_softmax(
                coreference_scores, top_antecedent_with_dummy_mask, memory_efficient=True
            )
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents, embedding_size)
            top_antecedent_with_dummy_embeddings = torch.cat(
                [top_span_embeddings.unsqueeze(2), top_antecedent_embeddings], 2
            )
            # Shape: (batch_size, num_spans_to_keep, embedding_size)
            attended_embeddings = util.weighted_sum(
                top_antecedent_with_dummy_embeddings, attention_weight
            )
            # Shape: (batch_size, num_spans_to_keep, embedding_size)
            top_span_embeddings = self._span_updating_gated_sum(
                top_span_embeddings, attended_embeddings
            )

            # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
            top_antecedent_embeddings = util.batched_index_select(
                top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
            )
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            coreference_scores = self._compute_coreference_scores(
                top_span_embeddings,
                top_antecedent_embeddings,
                top_partial_coreference_scores,
                top_antecedent_mask,
                top_antecedent_offsets,
            )

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": top_antecedent_indices,
            "predicted_antecedents": predicted_antecedents,
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
            )

            # Shape: (batch_size, num_spans_to_keep, max_antecedents)
            antecedent_labels = util.batched_index_select(
                pruned_gold_labels, top_antecedent_indices, flat_top_antecedent_indices
            ).squeeze(-1)
            antecedent_labels = util.replace_masked_values(
                antecedent_labels, top_antecedent_mask, -100
            )

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels
            )
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask.unsqueeze(-1)
            )
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(
                top_spans, top_antecedent_indices, predicted_antecedents, metadata
            )

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
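The loss described in the comments above marginalises over all antecedents that sit in the same gold cluster as each span. A tiny numerical sketch of that marginal negative log-likelihood for a single span (the scores and labels are made up):

import torch

# scores over [dummy "no antecedent", antecedent 0, antecedent 1, antecedent 2]
coreference_scores = torch.tensor([[2.0, 0.5, 1.5, -1.0]])
# antecedents 0 and 2 are in the same gold cluster as this span
gold_antecedent_labels = torch.tensor([[0.0, 1.0, 0.0, 1.0]])

log_probs = torch.log_softmax(coreference_scores, dim=-1)
# keep only the log-probabilities of gold-consistent antecedents (others -> -inf)
correct_log_probs = log_probs + gold_antecedent_labels.log()
loss = -torch.logsumexp(correct_log_probs, dim=-1).sum()
print(loss)  # marginal negative log-likelihood over the two valid antecedents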
Exemplo n.º 31
0
    def forward(self,
                sequence_tensor,
                span_indices,
                sequence_mask=None,
                span_indices_mask=None):
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # shape (batch_size, sequence_length, 1)
        global_attention_logits = self._global_attention(sequence_tensor)

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(
            max_batch_span_width,
            util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(
            raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(
            span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor,
                                                    span_indices,
                                                    flat_span_indices)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_logits = util.batched_index_select(
            global_attention_logits, span_indices,
            flat_span_indices).squeeze(-1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        span_attention_weights = util.last_dim_softmax(span_attention_logits,
                                                       span_mask)

        # Do a weighted sum of the embedded spans with
        # respect to the normalised attention distributions.
        # Shape: (batch_size, num_spans, embedding_dim)
        attended_text_embeddings = util.weighted_sum(span_embeddings,
                                                     span_attention_weights)

        if span_indices_mask is not None:
            # Above we were masking the widths of spans with respect to the max
            # span width in the batch. Here we are masking the spans which were
            # originally passed in as padding.
            return attended_text_embeddings * span_indices_mask.unsqueeze(
                -1).float()

        return attended_text_embeddings
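As a small worked check of the index arithmetic above: with inclusive end indices, a span with start 2 and end 4 has width 2, and subtracting the range vector from the end walks backwards over the words inside the span (toy values, plain PyTorch):

import torch

span_start, span_end = 2, 4
span_width = span_end - span_start                  # 2, "off by 1" since ends are inclusive
max_batch_span_width = span_width + 1               # 3 words in the span

range_indices = torch.arange(max_batch_span_width)  # tensor([0, 1, 2])
raw_span_indices = span_end - range_indices         # tensor([4, 3, 2])
span_mask = (range_indices <= span_width) & (raw_span_indices >= 0)
print(raw_span_indices[span_mask])                  # tensor([4, 3, 2]): the span's word indices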