Example #1
    def _compute_antecedent_gold_labels(relation_labels: torch.IntTensor, coref_labels: torch.IntTensor):
        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        source_labels = relation_labels.unsqueeze(1)
        target_labels = relation_labels.unsqueeze(2)
        relation_indicator = (target_labels * source_labels).sum(-1).clamp(0, 1).float()

        source_labels = coref_labels.unsqueeze(1)
        target_labels = coref_labels.unsqueeze(2)
        coref_indicator = (target_labels * source_labels).sum(-1).clamp(0, 1).float()

        label = relation_indicator * (relation_indicator - coref_indicator)
        assert (label < 0).sum() == 0, "pairwise labels should never be negative"

        return label
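
A quick sanity check of the pairwise-indicator trick used above, with made-up multi-hot label tensors (shapes and values are illustrative, not taken from the original model):

import torch

# Toy multi-hot membership labels for 3 spans and 2 clusters:
# spans 0 and 1 share cluster 0; span 2 is only in cluster 1.
labels = torch.tensor([[[1, 0],
                        [1, 0],
                        [0, 1]]]).float()          # (batch=1, num_spans=3, num_clusters=2)

source = labels.unsqueeze(1)                       # (1, 1, 3, 2)
target = labels.unsqueeze(2)                       # (1, 3, 1, 2)
indicator = (target * source).sum(-1).clamp(0, 1)  # (1, 3, 3): "spans i and j share a label"

print(indicator[0])
# tensor([[1., 1., 0.],
#         [1., 1., 0.],
#         [0., 0., 1.]])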
Example #2
    def mask_loc_logits(self, loc_logits, num_cands: torch.IntTensor):
        """
        Mask the padded candidates with a -inf score, so they get a probability of 0 after the softmax.
        Args:
            loc_logits - output scores for each candidate in each sentence, size (batch, max_sents, max_cands)
            num_cands - total number of candidates in each instance of the given batch, size (batch,)
        """
        assert torch.max(num_cands) == loc_logits.size(-1)
        assert loc_logits.size(0) == num_cands.size(0)
        batch_size = loc_logits.size(0)
        max_cands = loc_logits.size(-1)

        # first, we create a mask tensor that masks all positions above the num_cands limit
        range_tensor = torch.arange(start=1, end=max_cands + 1)
        if self.use_cuda:
            range_tensor = range_tensor.cuda()
        range_tensor = range_tensor.unsqueeze(dim=0).expand(
            batch_size, max_cands)
        bool_range = torch.gt(
            range_tensor,
            num_cands.unsqueeze(dim=-1))  # find the off-limit positions
        assert bool_range.size() == (batch_size, max_cands)

        bool_range = bool_range.unsqueeze(dim=-2).expand_as(
            loc_logits)  # use this bool tensor to mask loc_logits
        masked_loc_logits = loc_logits.masked_fill(
            bool_range, value=float('-inf'))  # mask padded positions to -inf
        assert masked_loc_logits.size() == loc_logits.size()

        return masked_loc_logits
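
A minimal sketch of the same masking logic on toy inputs (the tensors below are made up and the masking is re-implemented inline rather than calling the method):

import torch
import torch.nn.functional as F

loc_logits = torch.tensor([[[2.0, 0.5, 1.0, 3.0, 3.0]]])  # (batch=1, max_sents=1, max_cands=5)
num_cands = torch.tensor([3])                             # only the first 3 slots are real candidates

range_tensor = torch.arange(1, loc_logits.size(-1) + 1)   # [1, 2, 3, 4, 5]
bool_range = torch.gt(range_tensor.unsqueeze(0), num_cands.unsqueeze(-1))  # (1, 5): True at padded slots
bool_range = bool_range.unsqueeze(-2).expand_as(loc_logits)

masked = loc_logits.masked_fill(bool_range, float('-inf'))
print(F.softmax(masked, dim=-1))
# padded slots end up with probability 0, roughly tensor([[[0.63, 0.14, 0.23, 0.00, 0.00]]])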
Example #3
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor,
                new_xyz: torch.Tensor,
                fps_idx: torch.IntTensor) -> torch.Tensor:
        r"""

        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()

        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz,
                                     xyz, fps_idx, idx)

        return torch.cat([fps_idx.unsqueeze(2), idx], dim=2)
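
For reference, a pure-PyTorch approximation of the ball query (the helper name is ours and this is not the pointnet2 CUDA kernel; padding of under-filled balls is only approximated, and nsample <= N is assumed):

import torch

def ball_query_reference(radius: float, nsample: int,
                         xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
    # xyz: (B, N, 3), new_xyz: (B, npoint, 3)
    dist = torch.cdist(new_xyz, xyz)                     # (B, npoint, N) pairwise distances
    within = dist <= radius                              # boolean ball membership
    # Stable sort so in-ball indices come first and keep their original order.
    order = torch.sort(within.int(), dim=-1, descending=True, stable=True).indices
    idx = order[..., :nsample].clone()                   # (B, npoint, nsample)
    # Pad balls with fewer than nsample members by repeating their first index.
    counts = within.sum(-1, keepdim=True).clamp(min=1)
    pad = torch.arange(nsample, device=xyz.device).view(1, 1, -1) >= counts
    idx = torch.where(pad, idx[..., :1].expand_as(idx), idx)
    return idx.int()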
Example #4
def compute_antecedent_gold_labels(top_span_labels: torch.IntTensor,
                                   antecedent_labels: torch.IntTensor):
    """
    Generates a binary indicator for every pair of spans. This label is one if and
    only if the pair of spans belong to the same cluster. The labels are augmented
    with a dummy antecedent at the zeroth position, which represents the prediction
    that a span does not have any antecedent.

    Parameters
    ----------
    top_span_labels : ``torch.IntTensor``, required.
        The cluster id label for every span. The id is arbitrary,
        as we just care about the clustering. Has shape (batch_size, num_spans_to_keep).
    antecedent_labels : ``torch.IntTensor``, required.
        The cluster id label for every antecedent span. The id is arbitrary,
        as we just care about the clustering. Has shape
        (batch_size, num_spans_to_keep, max_antecedents).

    Returns
    -------
    pairwise_labels_with_dummy_label : ``torch.FloatTensor``
        A binary tensor representing whether a given pair of spans belong to
        the same cluster in the gold clustering.
        Has shape (batch_size, num_spans_to_keep, max_antecedents + 1).

    """
    # Shape: (batch_size, num_spans_to_keep, max_antecedents)
    top_span_labels = top_span_labels.unsqueeze(0)
    antecedent_labels = antecedent_labels.unsqueeze(0)
    target_labels = top_span_labels.expand_as(antecedent_labels)
    same_cluster_indicator = (target_labels == antecedent_labels).float()
    non_dummy_indicator = (target_labels >= 0).float()
    pairwise_labels = same_cluster_indicator * non_dummy_indicator

    # Shape: (batch_size, num_spans_to_keep, 1)
    dummy_labels = (1 - pairwise_labels).prod(-1, keepdim=True)

    # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
    pairwise_labels_with_dummy_label = torch.cat(
        [dummy_labels, pairwise_labels], -1)
    return pairwise_labels_with_dummy_label.squeeze(0)
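
To make the dummy-label bookkeeping concrete, here is a toy run of the same logic (cluster ids and shapes are invented for illustration):

import torch

# 3 kept spans, at most 2 candidate antecedents each.
# Spans 0 and 2 belong to cluster 0; span 1 is unlabelled (-1).
top_span_labels = torch.tensor([[0], [-1], [0]])      # (num_spans_to_keep, 1)
antecedent_labels = torch.tensor([[-1, -1],           # span 0: no real antecedents
                                  [ 0, -1],           # span 1: candidate span 0
                                  [-1,  0]])          # span 2: candidates span 1, span 0

target = top_span_labels.unsqueeze(0).expand_as(antecedent_labels.unsqueeze(0))
pairwise = ((target == antecedent_labels.unsqueeze(0)).float()
            * (target >= 0).float())
dummy = (1 - pairwise).prod(-1, keepdim=True)
print(torch.cat([dummy, pairwise], -1).squeeze(0))
# tensor([[1., 0., 0.],   span 0: only the dummy antecedent is correct
#         [1., 0., 0.],   span 1: unlabelled, so the dummy again
#         [0., 0., 1.]])  span 2: its second candidate (span 0) is the gold antecedent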
Example #5
    def forward(
            self,  # type: ignore
            spans: torch.IntTensor,
            span_mask: torch.IntTensor,
            span_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            ner_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        """

        # Shape: (Batch size, Number of Spans, Span Embedding Size)
        # span_embeddings

        self._active_namespace = f"{metadata.dataset}__ner_labels"
        if self._active_namespace not in self._ner_scorers:
            return {"loss": 0}

        scorer = self._ner_scorers[self._active_namespace]

        ner_scores = scorer(span_embeddings)
        # Give large negative scores to masked-out elements.
        mask = span_mask.unsqueeze(-1)
        ner_scores = util.replace_masked_values(ner_scores, mask.bool(), -1e20)
        # The dummy_scores are the score for the null label.
        dummy_dims = [ner_scores.size(0), ner_scores.size(1), 1]
        dummy_scores = ner_scores.new_zeros(*dummy_dims)
        ner_scores = torch.cat((dummy_scores, ner_scores), -1)

        _, predicted_ner = ner_scores.max(2)

        predictions = self.predict(ner_scores.detach().cpu(),
                                   spans.detach().cpu(),
                                   span_mask.detach().cpu(), metadata)
        output_dict = {"predictions": predictions}

        if ner_labels is not None:
            metrics = self._ner_metrics[self._active_namespace]
            metrics(predicted_ner, ner_labels, span_mask)
            ner_scores_flat = ner_scores.view(
                -1, self._n_labels[self._active_namespace])
            ner_labels_flat = ner_labels.view(-1)
            mask_flat = span_mask.view(-1).bool()

            loss = self._loss(ner_scores_flat[mask_flat],
                              ner_labels_flat[mask_flat])

            output_dict["loss"] = loss

        return output_dict
Example #6
    def forward(
            self,  # type: ignore
            text: Dict[str, Any],
            text_mask: torch.IntTensor,
            token_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            token_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        """

        seq_scores = self._seq_scorer(token_embeddings)
        # Give large negative scores to masked-out elements.
        mask = text_mask.unsqueeze(-1)
        seq_scores = util.replace_masked_values(seq_scores, mask, -1e20)
        seq_scores[:, :, 0] *= text_mask

        _, predicted_seq = seq_scores.max(2)

        if self._label_scheme == 'flat':
            pred_spans = self._seq_metrics._decode_flat(
                predicted_seq, text_mask)
        elif self._label_scheme == 'stacked':
            pred_spans = self._seq_metrics._decode_stacked(
                predicted_seq, text_mask)
        else:
            raise RuntimeError("invalid label_scheme {}".format(
                self._label_scheme))

        output_dict = {
            "predicted_seq": predicted_seq,
            "predicted_seq_span": pred_spans
        }

        if token_labels is not None:
            self._seq_metrics(predicted_seq, token_labels, text_mask,
                              self.training)
            seq_scores_flat = seq_scores.view(-1, self._n_labels)
            seq_labels_flat = token_labels.view(-1)
            mask_flat = text_mask.view(-1).bool()

            loss = self._loss(seq_scores_flat[mask_flat],
                              seq_labels_flat[mask_flat])
            output_dict["loss"] = loss

        return output_dict
Example #7
    def forward(self, input_word_index: torch.IntTensor,
                h_state: torch.FloatTensor, c_state: torch.FloatTensor,
                enc_outputs: torch.FloatTensor, mask: torch.BoolTensor):
        """
        Pass inputs through the model.

        Args:
            input_word_index: torch.IntTensor[batch_size,]
            h_state: torch.FloatTensor[n_layers, batch_size, hidden_size]
            c_state: torch.FloatTensor[n_layers, batch_size, hidden_size]
            enc_outputs: torch.FloatTensor[seq_len, batch_size, hidden_size]
            mask: torch.BoolTensor[seq_len, batch_size, 1]

        Returns:
            logit: torch.FloatTensor[batch_size, vocab_size]
            h_state: torch.FloatTensor[n_layers, batch_size, hidden_size]
            c_state: torch.FloatTensor[n_layers, batch_size, hidden_size]
            attention_weights: torch.FloatTensor[seq_len, batch_size]
        """
        embedded = self.embedding(input_word_index.unsqueeze(0))
        embedded = F.dropout(embedded, p=self.embedding_dropout)
        output, (h_state, c_state) = self.lstm(embedded, (h_state, c_state))
        # output: [seq_len=1, batch_size, hidden_size]
        # h_state: [n_layers, batch_size, hidden_size]
        # c_state: [n_layers, batch_size, hidden_size]

        # Compute attention weights
        attention_weights = self.attention_layer(
            h_state=h_state, enc_outputs=enc_outputs,
            mask=mask)  # attention_weights: [seq_len, batch_size, 1]
        # Compute the context vector
        context_vector = torch.bmm(
            enc_outputs.permute(1, 2, 0),  # [batch_size, hidden_size, seq_len]
            attention_weights.permute(1, 0, 2),  # [batch_size, seq_len, 1]
        ).permute(2, 0, 1)  # [1, batch_size, hidden_size]

        # New input: concatenate context_vector with hidden_states
        new_input = torch.cat((context_vector, output),
                              dim=2)  # [1, batch_size, hidden_size * 2]

        # Get logit
        x = self.fc1(new_input.squeeze(0))  # [batch_size, hidden_size]
        x = F.leaky_relu(x)
        x = F.dropout(x, p=self.dropout)
        logit = self.fc2(x)  # [batch_size, vocab_size]

        return logit, (h_state, c_state, attention_weights.squeeze(2))
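
The bmm-based context-vector computation above is mostly shape juggling; a standalone sketch with random tensors (dimensions are arbitrary):

import torch

seq_len, batch_size, hidden_size = 4, 2, 8
enc_outputs = torch.randn(seq_len, batch_size, hidden_size)
attention_weights = torch.softmax(torch.randn(seq_len, batch_size, 1), dim=0)

context_vector = torch.bmm(
    enc_outputs.permute(1, 2, 0),        # (batch_size, hidden_size, seq_len)
    attention_weights.permute(1, 0, 2),  # (batch_size, seq_len, 1)
).permute(2, 0, 1)                       # (1, batch_size, hidden_size)

print(context_vector.shape)  # torch.Size([1, 2, 8])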
Example #8
    def forward(
            self,  # type: ignore
            spans: torch.IntTensor,
            span_mask: torch.IntTensor,
            span_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            ner_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        """

        # Shape: (Batch size, Number of Spans, Span Embedding Size)
        # span_embeddings
        ner_scores = self._ner_scorer(span_embeddings)
        # Give large negative scores to masked-out elements.
        mask = span_mask.unsqueeze(-1)
        ner_scores = util.replace_masked_values(ner_scores, mask, -1e20)
        # The dummy_scores are the score for the null label.
        dummy_dims = [ner_scores.size(0), ner_scores.size(1), 1]
        dummy_scores = ner_scores.new_zeros(*dummy_dims)
        ner_scores = torch.cat((dummy_scores, ner_scores), -1)

        _, predicted_ner = ner_scores.max(2)

        output_dict = {
            "spans": spans,
            "span_mask": span_mask,
            "ner_scores": ner_scores,
            "predicted_ner": predicted_ner
        }

        if ner_labels is not None:
            self._ner_metrics(predicted_ner, ner_labels, span_mask)
            ner_scores_flat = ner_scores.view(-1, self._n_labels)
            ner_labels_flat = ner_labels.view(-1)
            mask_flat = span_mask.view(-1).bool()

            loss = self._loss(ner_scores_flat[mask_flat],
                              ner_labels_flat[mask_flat])
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["document"] = [x["sentence"] for x in metadata]

        return output_dict
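
The dummy-score trick (prepending a zero column as the null-label score) can be illustrated in isolation with hand-written scores:

import torch

# Hypothetical scores for 2 spans over 3 entity labels, before adding the null class.
ner_scores = torch.tensor([[[ 0.2, -1.3,  0.8],
                            [-2.0, -0.5, -0.7]]])          # (batch=1, num_spans=2, n_labels=3)

dummy_scores = ner_scores.new_zeros(ner_scores.size(0), ner_scores.size(1), 1)
scores = torch.cat((dummy_scores, ner_scores), dim=-1)     # (1, 2, 4)

print(scores.argmax(-1))
# tensor([[3, 0]]) -> span 0 is predicted as an entity (label index 3),
#                     span 1 falls back to the null label (index 0)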
Example #9
    def forward(
            self,  # type: ignore
            spans: torch.IntTensor,
            span_mask: torch.IntTensor,
            span_embeddings: torch.IntTensor,
            sentence_lengths: torch.Tensor,
            span_labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        """

        # Shape: (Batch size, Number of Spans, Span Embedding Size)
        # span_embeddings
        span_scores = self._span_scorer(span_embeddings)
        # Give large negative scores to masked-out elements.
        mask = span_mask.unsqueeze(-1)
        span_scores = util.replace_masked_values(span_scores, mask, -1e20)
        span_scores[:, :, 0] *= span_mask

        _, predicted_span = span_scores.max(2)

        output_dict = {
            "spans": spans,
            "span_mask": span_mask,
            "span_scores": span_scores,
            "predicted_span": predicted_span
        }

        if span_labels is not None:
            self._span_metrics(predicted_span, span_labels, span_mask)
            span_scores_flat = span_scores.view(-1, self._n_labels)
            span_labels_flat = span_labels.view(-1)
            mask_flat = span_mask.view(-1).bool()

            loss = self._loss(span_scores_flat[mask_flat],
                              span_labels_flat[mask_flat])
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["document"] = [x["sentence"] for x in metadata]

        return output_dict
Example #10
    def forward(self, input_word_index: torch.IntTensor,
                h_state: torch.FloatTensor, c_state: torch.FloatTensor):
        """
        Pass inputs through the model.

        Args:
            input_word_index: torch.IntTensor[batch_size,]
            h_state: torch.FloatTensor[n_layer, batch_size, hidden_size]
            c_state: torch.FloatTensor[n_layer, batch_size, hidden_size]

        Returns:
            logit: torch.FloatTensor[batch_size, vocab_size]
            h_state: torch.FloatTensor[n_layer, batch_size, hidden_size]
            c_state: torch.FloatTensor[n_layer, batch_size, hidden_size]
        """
        embedded = self.embedding(input_word_index.unsqueeze(0))
        embedded = F.dropout(embedded, p=self.embedding_dropout)
        output, (h_state, c_state) = self.lstm(embedded, (h_state, c_state))
        logit = self.fc(output.squeeze(0))
        return logit, (h_state, c_state)
Example #11
    def forward(
            ctx, e1, e2, e3: float, nsample: int, xyz: torch.Tensor,
            new_xyz: torch.Tensor, fps_idx: torch.IntTensor
    ) -> torch.Tensor:
        r"""

        Parameters
        ----------
        e1, e2, e3 : float
            e1, e2, e3 of the ellipsoid
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
        ingroup_pts_cnt = torch.cuda.IntTensor(B, npoint).zero_()
        ingroup_out = torch.cuda.FloatTensor(B, npoint, nsample, 3).zero_()
        ingroup_cva = torch.cuda.FloatTensor(B, npoint, 3*3).zero_()
        v = torch.cuda.FloatTensor(B, npoint, 3*3).zero_()
        d = torch.cuda.FloatTensor(B, npoint, 3).zero_()

        pointnet2.ellipsoid_query_wrapper(
            B, N, npoint, e1, e2, e3, nsample, new_xyz, xyz, fps_idx, idx, ingroup_pts_cnt, ingroup_out, ingroup_cva, v, d
        )
        
        return torch.cat([fps_idx.unsqueeze(2), idx], dim=2), d
Example #12
def batched_gather(x: torch.Tensor, indices: torch.IntTensor, dim: int):
    """
    Similar to the gather method of :class:`torch.Tensor`.

    Args:
        x: the tensor to select.
        indices: the indices to choose.
        dim: the dimension to choose.

    Returns:
        A selected tensor
    """

    if indices.dim() == 1:
        return x[indices]
    elif indices.dim() == 2:
        if x.dim() > indices.dim():
            indices = indices.unsqueeze(-1).repeat_interleave(x.shape[-1],
                                                              dim=-1)

        return x.gather(dim, indices)

    raise NotImplementedError(
        "Currently do not support more batch dimensions than 1!")
Example #13
def rnnt_loss(log_probs: torch.FloatTensor,
              labels: torch.IntTensor,
              frames_lengths: torch.IntTensor,
              labels_lengths: torch.IntTensor,
              average_frames: bool = False,
              reduction: Optional[AnyStr] = None,
              blank: int = 0,
              gather: bool = False) -> torch.Tensor:

    """The CUDA-Warp RNN-Transducer loss.

    Args:
        log_probs (torch.FloatTensor): Input tensor with shape (N, T, U, V)
            where N is the minibatch size, T is the maximum number of
            input frames, U is the maximum number of output labels and V is
            the vocabulary of labels (including the blank).
        labels (torch.IntTensor): Tensor with shape (N, U-1) representing the
            reference labels for all samples in the minibatch.
        frames_lengths (torch.IntTensor): Tensor with shape (N,) representing the
            number of frames for each sample in the minibatch.
        labels_lengths (torch.IntTensor): Tensor with shape (N,) representing the
            length of the transcription for each sample in the minibatch.
        average_frames (bool, optional): Specifies whether the loss of each
            sample should be divided by its number of frames.
            Default: False.
        reduction (string, optional): Specifies the type of reduction.
            Default: None.
        blank (int, optional): label used to represent the blank symbol.
            Default: 0.
        gather (bool, optional): Reduce memory consumption.
            Default: False.
    """

    assert average_frames is None or isinstance(average_frames, bool)
    assert reduction is None or reduction in ("none", "mean", "sum")
    assert isinstance(blank, int)
    assert isinstance(gather, bool)

    assert not labels.requires_grad, "labels does not require gradients"
    assert not frames_lengths.requires_grad, "frames_lengths does not require gradients"
    assert not labels_lengths.requires_grad, "labels_lengths does not require gradients"

    if gather:

        N, T, U, V = log_probs.size()

        index = torch.full([N, T, U, 2], blank, device=labels.device, dtype=torch.long)

        index[:, :, :U-1, 1] = labels.unsqueeze(dim=1)

        log_probs = log_probs.gather(dim=3, index=index)

        blank = -1

    costs = RNNTLoss.apply(log_probs, labels, frames_lengths, labels_lengths, blank)

    if average_frames:
        costs = costs / frames_lengths.to(log_probs)

    if reduction == "sum":
        return costs.sum()
    elif reduction == "mean":
        return costs.mean()
    return costs
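
The gather branch packs, for every (t, u) position, just the blank score and the score of the next reference label; a toy illustration of the index construction (the loss call itself needs the underlying CUDA extension, so only the gather step is shown):

import torch

N, T, U, V, blank = 1, 3, 3, 5, 0
log_probs = torch.randn(N, T, U, V).log_softmax(-1)
labels = torch.tensor([[2, 4]])                      # (N, U-1) reference labels

index = torch.full([N, T, U, 2], blank, dtype=torch.long)
index[:, :, :U - 1, 1] = labels.unsqueeze(dim=1)     # slot 0 -> blank, slot 1 -> next label
gathered = log_probs.gather(dim=3, index=index)

print(gathered.shape)  # torch.Size([1, 3, 3, 2]): only the probabilities the loss needs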
Example #14
    def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        metadata: List[Dict[str, Any]],
        doc_span_offsets: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        doc_truth_spans: torch.IntTensor = None,
        doc_spans_in_truth: torch.IntTensor = None,
        doc_relation_labels: torch.Tensor = None,
        truth_spans: List[Set[Tuple[int, int]]] = None,
        doc_relations=None,
        doc_ner_labels: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:  # add matrix from datareader
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        doc_ner_labels : ``torch.IntTensor``.
            A tensor of shape # TODO,
            ...
        doc_span_offsets : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
            ...
        doc_truth_spans : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_truth_spans, 1),
            ...
        doc_spans_in_truth : ``torch.IntTensor``.
            A tensor of shape (batch_size, max_sentences, max_spans_per_sentence, 1),
            ...
        doc_relation_labels : ``torch.Tensor``.
            A tensor of shape (batch_size, max_sentences, max_truth_spans, max_truth_spans),
            ...

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        batch_size = len(spans)
        document_length = text_embeddings.size(1)
        max_sentence_length = max(
            len(sentence) for document in metadata
            for sentence in document['doc_tokens'])
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # TODO features dropout
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_relex_spans_to_keep = int(
            math.floor(self._relex_spans_per_word * max_sentence_length))

        # Shapes:
        # (batch_size, num_spans_to_keep, span_dim),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep),
        # (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # Shape: (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings, top_span_mention_scores,
            candidate_antecedent_mention_scores, valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = dict()

        output_dict["top_spans"] = top_spans
        output_dict["antecedent_indices"] = valid_antecedent_indices
        output_dict["predicted_antecedents"] = predicted_antecedents

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]

        # Shape: (,)
        loss = 0

        # Shape: (batch_size, max_sentences, max_spans)
        doc_span_mask = (doc_span_offsets[:, :, :, 0] >= 0).float()
        # Shape: (batch_size, max_sentences, num_spans, span_dim)
        doc_span_embeddings = util.batched_index_select(
            span_embeddings,
            doc_span_offsets.squeeze(-1).long().clamp(min=0))

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep),
        # (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        pruned = self._relex_mention_pruner(
            doc_span_embeddings,
            doc_span_mask,
            num_items_to_keep=num_relex_spans_to_keep,
            pass_through=['num_items_to_keep'])
        (top_relex_span_embeddings, top_relex_span_mask,
         top_relex_span_indices, top_relex_span_mention_scores) = pruned

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1)
        top_relex_span_mask = top_relex_span_mask.unsqueeze(-1)

        # Shape: (batch_size, max_sentences, max_spans_per_sentence, 2)  # TODO do we need for a mask?
        doc_spans = util.batched_index_select(
            spans,
            doc_span_offsets.clamp(0).squeeze(-1))

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 2)
        top_relex_spans = nd_batched_index_select(doc_spans,
                                                  top_relex_span_indices)

        # Shapes:
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, 3 * span_dim),
        # (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep).
        (relex_span_pair_embeddings,
         relex_span_pair_mask) = self._compute_relex_span_pair_embeddings(
             top_relex_span_embeddings, top_relex_span_mask.squeeze(-1))

        # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep, num_relation_labels)
        relex_scores = self._compute_relex_scores(
            relex_span_pair_embeddings, top_relex_span_mention_scores)
        output_dict['relex_scores'] = relex_scores
        output_dict['top_relex_spans'] = top_relex_spans

        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)
            antecedent_labels_ = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels = antecedent_labels_ + valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs)
            negative_marginal_log_likelihood *= top_span_mask.squeeze(
                -1).float()
            negative_marginal_log_likelihood = negative_marginal_log_likelihood.sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            coref_loss = negative_marginal_log_likelihood
            output_dict['coref_loss'] = coref_loss
            loss += self._loss_coref_weight * coref_loss

        if doc_relations is not None:

            # The adjacency matrix for relation extraction is very sparse.
            # As it is not just sparse, but row/column sparse (only few
            # rows and columns are non-zero and in that case these rows/columns
            # are not sparse), we implemented our own matrix for the case.
            # Here we have indices of truth spans and mapping, using which
            # we map prediction matrix on truth matrix.
            # TODO Add teacher forcing support.

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep),
            relative_indices = top_relex_span_indices
            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, 1),
            compressed_indices = nd_batched_padded_index_select(
                doc_spans_in_truth, relative_indices)

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, max_truth_spans)
            gold_pruned_rows = nd_batched_padded_index_select(
                doc_relation_labels,
                compressed_indices.squeeze(-1),
                padding_value=0)
            gold_pruned_rows = gold_pruned_rows.permute(0, 1, 3,
                                                        2).contiguous()

            # Shape: (batch_size, max_sentences, num_relex_spans_to_keep, num_relex_spans_to_keep)
            gold_pruned_matrices = nd_batched_padded_index_select(
                gold_pruned_rows,
                compressed_indices.squeeze(-1),
                padding_value=0)  # pad with epsilon
            gold_pruned_matrices = gold_pruned_matrices.permute(
                0, 1, 3, 2).contiguous()

            # TODO log_mask relex score before passing
            relex_loss = nd_cross_entropy_with_logits(relex_scores,
                                                      gold_pruned_matrices,
                                                      relex_span_pair_mask)
            output_dict['relex_loss'] = relex_loss

            self._relex_mention_recall(top_relex_spans.view(batch_size, -1, 2),
                                       truth_spans)
            self._compute_relex_metrics(output_dict, doc_relations)

            loss += self._loss_relex_weight * relex_loss

        if doc_ner_labels is not None:
            # Shape: (batch_size, max_sentences, num_spans, num_ner_classes)
            ner_scores = self._ner_scorer(doc_span_embeddings)
            output_dict['ner_scores'] = ner_scores

            ner_loss = nd_cross_entropy_with_logits(ner_scores, doc_ner_labels,
                                                    doc_span_mask)
            output_dict['ner_loss'] = ner_loss
            loss += self._loss_ner_weight * ner_loss

        if not isinstance(loss, int):  # If loss is not yet modified
            output_dict["loss"] = loss

        return output_dict
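
The coreference part of the loss is a marginal log-likelihood over all gold-consistent antecedents; a minimal sketch with hand-written scores (plain log_softmax is used here instead of util.masked_log_softmax, i.e. all kept spans are assumed unmasked):

import torch

# 1 document, 2 kept spans, 2 candidate antecedents plus the dummy in column 0.
coreference_scores = torch.tensor([[[0.0, -1e20, -1e20],    # span 0: only the dummy is possible
                                    [0.5,  1.2,  -0.3]]])   # span 1: two real candidates
gold_antecedent_labels = torch.tensor([[[1., 0., 0.],
                                        [0., 1., 0.]]])     # cf. the gold-label construction above

coreference_log_probs = coreference_scores.log_softmax(-1)
correct_log_probs = coreference_log_probs + gold_antecedent_labels.log()  # -inf where label == 0
loss = -torch.logsumexp(correct_log_probs, dim=-1).sum()
print(loss)  # ~tensor(0.5422): only span 1 contributes, span 0's dummy prediction is certain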
Example #15
    def compute_representations(
        self,  # type: ignore
        span_embeddings,  # (1, Ns, E)
        coref_labels: torch.IntTensor,  # (1, Ns, C)
        type_to_cluster_ids: Dict[str, List[int]],
        relation_to_cluster_ids: Dict[int, List[int]] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        if coref_labels.sum() == 0:
            return {"loss": 0.0, "metadata": metadata}

        cluster_type_embeddings = self.map_cluster_to_type_embeddings(
            type_to_cluster_ids)  # (1, C, E)

        sum_embeddings = (span_embeddings.unsqueeze(2) *
                          coref_labels.float().unsqueeze(-1)).sum(1)
        length_embeddings = (coref_labels.unsqueeze(-1).sum(1) + 1e-5)

        cluster_span_embeddings = sum_embeddings / length_embeddings

        paragraph_cluster_mask = (coref_labels.sum(1) > 0).float().unsqueeze(
            -1)  # (P, C, 1)

        paragraph_cluster_embeddings = cluster_span_embeddings * paragraph_cluster_mask + cluster_type_embeddings * (
            1 - paragraph_cluster_mask)  # (P, C, E)

        assert (paragraph_cluster_embeddings.shape[1] == coref_labels.shape[2]
                and paragraph_cluster_embeddings.shape[2]
                == span_embeddings.shape[-1])

        paragraph_cluster_embeddings = torch.cat(
            [
                paragraph_cluster_embeddings,
                self._bias_vectors.expand(
                    paragraph_cluster_embeddings.shape[0], -1, -1)
            ],
            dim=1,
        )  # (P, C+4, E)
        n_true_clusters = coref_labels.shape[-1]

        candidate_relations, candidate_relations_labels, candidate_relations_types = self.generate_product(
            type_to_clusters_map=type_to_cluster_ids,
            relation_to_clusters_map=relation_to_cluster_ids,
            n_true_clusters=n_true_clusters,
        )

        candidate_relations_tensor = torch.LongTensor(candidate_relations).to(
            span_embeddings.device)  # (R, 4)
        candidate_relations_labels_tensor = torch.LongTensor(
            candidate_relations_labels).to(span_embeddings.device)  # (R, )

        if len(candidate_relations) == 0:
            return {"loss": 0.0, "metadata": metadata}

        all_relation_embeddings = util.batched_index_select(
            paragraph_cluster_embeddings,
            candidate_relations_tensor.unsqueeze(0).expand(
                paragraph_cluster_embeddings.shape[0], -1, -1),
        )  # (P, R', n, E)

        relation_scores, relation_logits = self.get_relation_scores(
            all_relation_embeddings)  # (1, R')
        output_dict = {}
        output_dict["relations_candidates_list"] = candidate_relations
        output_dict["relation_labels"] = candidate_relations_labels
        output_dict["relation_types"] = candidate_relations_types
        output_dict["doc_id"] = metadata[0]["doc_id"]
        output_dict["metadata"] = metadata
        output_dict["relation_scores"] = relation_scores
        output_dict["relation_logits"] = relation_logits

        if relation_to_cluster_ids is not None:
            output_dict = self.predict_labels(
                relation_scores, relation_logits,
                candidate_relations_labels_tensor, output_dict)

        return output_dict
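
The cluster embeddings above are just membership-weighted averages of the span embeddings; a toy version of that step (all values invented):

import torch

span_embeddings = torch.arange(12.).reshape(1, 4, 3)   # (1, Ns=4, E=3)
coref_labels = torch.tensor([[[1, 0],                   # span 0 -> cluster 0
                              [1, 0],                   # span 1 -> cluster 0
                              [0, 1],                   # span 2 -> cluster 1
                              [0, 0]]]).float()         # span 3 -> no cluster

sum_embeddings = (span_embeddings.unsqueeze(2) * coref_labels.unsqueeze(-1)).sum(1)  # (1, C, E)
counts = coref_labels.unsqueeze(-1).sum(1) + 1e-5                                    # (1, C, 1)
print(sum_embeddings / counts)
# cluster 0 ~ mean of spans 0 and 1, cluster 1 ~ the embedding of span 2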
Example #16
    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            target_word: torch.IntTensor,
            gold_label: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """

        Parameters
        ----------
        tokens:

        target_word:
            (batch_size, 2)
        gold_label:
            (batch_size)
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:

        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """

        # Shape: (batch_size, sentence_length, embedding_size)
        tokens_embeddings = self._lexical_dropout(
            self._text_field_embedder(tokens))

        # Shape: (batch_size, sentence_length)
        tokens_mask = util.get_text_field_mask(tokens).float()

        # Shape: (batch_size, sentence_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            tokens_embeddings, tokens_mask)

        # Shape: (batch_size, 2 * encoding_dim)
        target_word_embeddings = self._target_word_extractor(
            contextualized_embeddings, target_word)

        # Shape: (batch_size, 1)
        complex_word_logits = self._complex_word_scorer(target_word_embeddings)

        complex_word_predictions = complex_word_logits > 0.5

        output_dict = {
            "logits": complex_word_logits,
            "predictions": complex_word_predictions
        }

        if gold_label is not None:
            output_dict["loss"] = self._loss(complex_word_logits,
                                             gold_label.unsqueeze(-1).float())

            macro_F1 = metrics.f1_score(gold_label,
                                        complex_word_predictions,
                                        average='macro')

            self._metric(complex_word_predictions, gold_label)

        return output_dict
Example #17
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)
        # span_embeddings = self._span_extractor(text_embeddings, spans, span_indices_mask=span_mask)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, span_embeddings.shape[1])

        # Shape:    (batch_size, num_spans_to_keep, embedding_size + 2 * encoding_dim + feature_size)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Shape: (batch_size, num_spans_to_keep, class_num + 1)
        ne_scores = self._compute_named_entity_scores(top_span_embeddings)

        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_named_entities = ne_scores.max(2)

        output_dict = {
            "top_spans": top_spans,
            "predicted_named_entities": predicted_named_entities
        }
        if labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices).squeeze(-1)
            negative_log_likelihood = F.cross_entropy(
                ne_scores.reshape(-1, self.class_num),
                pruned_gold_labels.reshape(-1))

            pruner_loss = F.binary_cross_entropy_with_logits(
                top_span_mention_scores.reshape(-1),
                (pruned_gold_labels.reshape(-1) != 0).float())
            loss = negative_log_likelihood + pruner_loss
            output_dict["loss"] = loss
            output_dict["pruner_loss"] = pruner_loss
            batch_size, _ = labels.shape
            all_scores = ne_scores.new_zeros(
                [batch_size * num_spans, self.class_num])
            all_scores[:, 0] = 1
            all_scores[flat_top_span_indices] = ne_scores.reshape(
                -1, self.class_num)
            all_scores = all_scores.reshape(
                [batch_size, num_spans, self.class_num])
            self._metric_all(all_scores, labels)
            self._metric_avg(all_scores, labels)
        return output_dict
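
The all_scores construction scatters the scores of the kept spans back onto the full span list, with everything else defaulting to the null class; a toy sketch (sizes and scores are made up):

import torch

batch_size, num_spans, class_num = 1, 4, 3
flat_top_span_indices = torch.tensor([0, 2])    # the two spans that survived pruning
ne_scores = torch.tensor([[0.1, 0.7, 0.2],
                          [0.6, 0.3, 0.1]])     # scores for the kept spans only

all_scores = ne_scores.new_zeros([batch_size * num_spans, class_num])
all_scores[:, 0] = 1                            # default: every span is the null class
all_scores[flat_top_span_indices] = ne_scores   # overwrite the rows of the kept spans
print(all_scores.reshape(batch_size, num_spans, class_num))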
Example #18
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.last_dim_log_softmax(coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood
        return output_dict
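
A minimal, self-contained sketch of the marginal log-likelihood computed above, with hypothetical toy scores and gold labels (plain PyTorch in place of the allennlp util helpers):

import torch

# Toy setup: 1 document, 3 kept spans, 2 candidate antecedents plus the dummy column 0.
coreference_scores = torch.tensor([[[0.5, -1.0, -2.0],
                                    [0.1,  2.0, -0.5],
                                    [0.3,  1.0,  1.5]]])
# gold_antecedent_labels[b, i, j] is 1 if column j is a gold-consistent choice for span i.
gold_antecedent_labels = torch.tensor([[[1., 0., 0.],
                                        [0., 1., 0.],
                                        [0., 1., 1.]]])

log_probs = torch.log_softmax(coreference_scores, dim=-1)
# log(0) = -inf knocks inconsistent antecedents out of the logsumexp.
correct_log_probs = log_probs + gold_antecedent_labels.log()
negative_marginal_log_likelihood = -torch.logsumexp(correct_log_probs, dim=-1).sum()
print(negative_marginal_log_likelihood)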
Exemplo n.º 19
0
    def forward(
            self,  # type: ignore
            para_id: int,
            participant_strings: List[str],
            paragraph: Dict[str, torch.LongTensor],
            sentences: Dict[str, torch.LongTensor],
            paragraph_sentence_indicators: torch.IntTensor,
            participants: Dict[str, torch.LongTensor],
            participant_indicators: torch.IntTensor,
            paragraph_participant_indicators: torch.IntTensor,
            verbs: torch.IntTensor,
            paragraph_verbs: torch.IntTensor,
            actions: torch.IntTensor = None,
            before_locations: torch.IntTensor = None,
            after_locations: torch.IntTensor = None,
            filename: List[str] = [],
            score: List[float] = 1.0  # instance_score
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        para_id: int
            The id of the paragraph
        participant_strings: List[str]
            The participants in the paragraph
        paragraph: Dict[str, torch.LongTensor]
            The token indices for the paragraph
        sentences: Dict[str, torch.LongTensor]
            The token indices batched by sentence.
        paragraph_sentence_indicators: torch.LongTensor
            Indicates before / inside / after for each sentence
        participants: Dict[str, torch.LongTensor]
            The token indices for the participant names
        participant_indicators: torch.IntTensor
            Indicates each participant in each sentence
        paragraph_participant_indicators: torch.IntTensor
            Indicates each participant in the paragraph
        verbs: torch.IntTensor
            Indicates the positions of verbs in the sentences
        paragraph_verbs: torch.IntTensor
            Indicates the positions of verbs in the paragraph
        actions: torch.IntTensor, optional (default = None)
            Indicates the actions taken per participant
            per sentence.
        before_locations: torch.IntTensor, optional (default = None)
            Indicates the span for the before location
            per participant per sentence
        after_locations: torch.IntTensor, optional (default = None)
            Indicates the span for the after location
            per participant per sentence
        filename: List[str], optional (default = [])
            The files from which the instances were read
        score: List[float], optional (default = 1.0)
            The score for each instance

        Returns
        -------
        An output dictionary consisting of:
        action_probs : torch.FloatTensor
            A tensor of shape ``(batch_size, num_sentences, num_participants, num_action_types)`` representing
            a distribution of state change types per sentence, participant in each datapoint (paragraph).
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        self.filename = filename
        self.instance_score = score

        # original shape (batch_size, num_participants, num_sentences, sentence_length)
        participant_indicators = participant_indicators.transpose(1, 2)
        # new shape (batch_size, num_sentences, num_participants, sentence_length)

        batch_size, num_sentences, num_participants, sentence_length = participant_indicators.size()

        # (batch_size, num_sentences, sentence_length, embedding_size)
        embedded_sentences = self.text_field_embedder(sentences)
        # (batch_size, num_participants, description_length, embedding_size)
        embedded_participants = self.text_field_embedder(participants)

        batch_size, num_sentences, sentence_length, embedding_size = embedded_sentences.size()
        self.num_sentences = num_sentences

        # ===========================================================================================================
        # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
        # (batch_size, num_sentences, num_participants, sentence_length, embedding_size)
        embedded_sentence_participant_pairs = embedded_sentences.unsqueeze(2).expand(batch_size, num_sentences, \
                                                                                     num_participants, sentence_length,
                                                                                     embedding_size)

        # (batch_size, num_sentences, sentence_length) -> (batch_size, num_sentences, num_participants, sentence_length)
        mask = get_text_field_mask(sentences, num_wrapping_dims=1). \
            unsqueeze(2).expand(batch_size, num_sentences, num_participants, sentence_length).float()

        # (batch_size, num_participants, num_sentences * sentence_length)
        participant_view = participant_indicators.transpose(1, 2). \
            view(batch_size, num_participants, num_sentences * sentence_length)

        # participant_mask is used to mask out invalid sentence, participant pairs
        # (batch_size, num_sentences, num_participants, sentence_length)
        sent_participant_pair_mask = (participant_view.sum(dim=2) > 0). \
            unsqueeze(-1).expand(batch_size, num_participants, num_sentences). \
            unsqueeze(-1).expand(batch_size, num_participants, num_sentences, sentence_length). \
            transpose(1, 2).float()

        # whether the sentence is masked or not (sent does not exist in paragraph).
        # this is either (batch_size, num_sentences, num_participants)
        # or if only one participant (batch_size, num_sentences)
        # TODO(joelgrus) why is there a squeeze here
        sentence_mask = (mask.sum(3) > 0).squeeze(-1).float()

        # (batch_size, num_sentences, num_participants, sentence_length)
        mask = mask * sent_participant_pair_mask

        # (batch_size, num_participants, num_sentences * sentence_length)
        # -> (batch_size, num_participants)
        # -> (batch_size, num_participants, num_sentences)
        # -> (batch_size, num_sentences, num_participants)
        participant_mask = (participant_view.sum(dim=2) > 0). \
            unsqueeze(-1).expand(batch_size, num_participants, num_sentences). \
            transpose(1, 2).float()

        # Example: 0.0 where action is -1 (padded)
        # action:  [[[1, 0, 1], [3, 2, 3]], [[0, -1, -1], [-1, -1, -1]]]
        # action_mask:  [[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
        # (batch_size, num_sentences, num_participants)
        action_mask = participant_mask * sentence_mask

        # (batch_size, num_sentences, num_participants, sentence_length)
        verb_indicators = verbs.unsqueeze(2).expand(batch_size, num_sentences,
                                                    num_participants,
                                                    sentence_length).float()

        # ===========================================================================================================
        # Layer 2: Concatenate sentence embedding with verb and participant indicator bits
        # espp: (batch_size, num_sentences, num_participants, sentence_length, embedding_size)
        # vi:   (batch_size, num_sentences, num_participants, sentence_length)
        # pi:   (batch_size, num_sentences, num_participants, sentence_length)
        #
        # result: (batch_size, num_sentences, num_participants, sentence_length, embedding_size + 2)
        embedded_sentence_verb_entity = \
            torch.cat([embedded_sentence_participant_pairs, verb_indicators.unsqueeze(-1).float(),
                       participant_indicators.unsqueeze(-1).float()], dim=-1)

        # ===========================================================================================================
        # Layer 3 = Contextual embedding layer using Bi-LSTM over the sentence

        if self.use_attention:
            # contextual_seq_embedding:
            #     (batch_size, num_sentences, num_participants, sentence_length, 2 * seq2seq_output_size)
            contextual_seq_embedding = self.time_distributed_seq2seq_encoder(
                embedded_sentence_verb_entity, mask)

            # Layer 3.5: Attention (Contextual embedding, BOW(verb span))
            verb_weight_matrix = verb_indicators.float() / (
                verb_indicators.float().sum(-1).unsqueeze(-1) + 1e-13)
            # (batch_size, num_sentences, num_participants, embedding_size)
            verb_vector = weighted_sum(
                contextual_seq_embedding *
                verb_indicators.float().unsqueeze(-1), verb_weight_matrix)

            # (batch_size, num_sentences, num_participants, sentence_length)
            participant_weight_matrix = participant_indicators.float() / (
                participant_indicators.float().sum(-1).unsqueeze(-1) + 1e-13)

            # (batch_size, num_sentences, num_participants, embedding_size)
            participant_vector = weighted_sum(
                contextual_seq_embedding *
                participant_indicators.float().unsqueeze(-1),
                participant_weight_matrix)

            # (batch_size, num_sentences, num_participants, 2 * embedding_size)
            verb_participant_vector = torch.cat(
                [verb_vector, participant_vector], -1)
            batch_size, num_sentences, num_participants, sentence_length, verb_ind_size = \
                verb_indicators.float().unsqueeze(-1).size()

            # attention weights for type prediction
            # (batch_size, num_sentences, num_participants)
            attention_weights_actions = self.time_distributed_attention_layer(
                verb_participant_vector, contextual_seq_embedding, mask)
            contextual_vec_embedding = weighted_sum(contextual_seq_embedding,
                                                    attention_weights_actions)

        else:
            # batch_size * num_sentences * num_participants * sentence_length * embedding_size
            contextual_vec_embedding = self.time_distributed_seq2vec_encoder(
                embedded_sentence_verb_entity, mask)

        # (batch_size, num_participants, num_sentences, 1) -> (batch_size, num_sentences, num_participants, 1)
        if actions is not None:
            actions = actions.transpose(1, 2)

        # # ===========================================================================================================
        # # Layer 4 = Aggregate FeedForward to choose an action label per sentence, participant pair
        # (batch_size, num_sentences, num_participants, num_actions)
        action_logits = self.aggregate_feedforward(contextual_vec_embedding)

        action_probs = torch.nn.functional.softmax(action_logits, dim=-1)
        # (batch_size * num_sentences * num_participants, num_actions)
        action_probs_decode = action_probs.view(
            (batch_size * num_sentences * num_participants), self.num_actions)

        output_dict = {}
        if self.use_decoder_trainer:
            # (batch_size, num_participants, description_length, embedding_size)
            participants_list = embedded_participants.data.cpu().numpy()

            output_dict.update(
                DecoderTrainerHelper.pass_on_info_to_decoder_trainer(
                    selfie=self,
                    para_id_list=para_id,
                    actions=actions,
                    target_mask=action_mask,
                    participants_list=participants_list,
                    participant_strings=participant_strings,
                    participant_indicators=participant_indicators.transpose(
                        1, 2),
                    logit_tensor=action_logits))

            # Compute type_accuracy based on best_final_states and actions
            best_decoded_state = output_dict['best_final_states'][0][0][0]
            best_decoded_action_seq = []
            if best_decoded_state.action_history:
                for cur_step_action in best_decoded_state.action_history[0]:
                    step_predictions = []
                    for step_action in list(cur_step_action):
                        step_predictions.append(step_action)
                    best_decoded_action_seq.append(step_predictions)
                best_decoded_tensor = torch.LongTensor(
                    best_decoded_action_seq).unsqueeze(0)

                if actions is not None:
                    flattened_gold = actions.long().contiguous().view(-1)
                    self._type_accuracy(
                        best_decoded_tensor.long().contiguous().view(-1),
                        flattened_gold)
            output_dict['best_decoded_action_seq'] = [best_decoded_action_seq]
        else:
            # Create output dictionary for the trainer
            # Compute loss and epoch metrics
            output_dict["action_probs"] = action_probs
            output_dict["action_probs_decode"] = action_probs_decode

            action_loss = 0.0
            location_loss = 0.0
            if actions is not None:
                # (batch_size * num_sentences * num_participants, num_actions)
                flattened_predictions = action_logits.view(
                    (batch_size * num_sentences * num_participants),
                    self.num_actions)
                # Flattened_gold: contains the gold action index (Action enum in propara_dataset_reader)
                # Note: tensor is not a single block of memory, but a block with holes.
                # view can be only used with contiguous tensors, so if you need to use it here, just call .contiguous() before.
                # (batch_size * num_sentences * num_participants)
                flattened_gold = actions.long().contiguous().view(-1)
                action_loss = self._loss(flattened_predictions, flattened_gold)
                flattened_probs = action_probs.view(
                    (batch_size * num_sentences * num_participants),
                    self.num_actions)
                evaluation_mask = (flattened_gold != -1)

                self._type_accuracy(flattened_probs,
                                    flattened_gold,
                                    mask=evaluation_mask)
                output_dict["loss"] = action_loss

        best_span_after, span_start_logits_after, span_end_logits_after = \
            self.compute_location_spans(contextual_seq_embedding=contextual_seq_embedding,
                                        embedded_sentence_verb_entity=embedded_sentence_verb_entity,
                                        mask=mask)
        output_dict["location_span_after"] = [best_span_after]

        not_in_test = (self.training or 'test' not in self.filename)

        if not_in_test and (before_locations is not None
                            and after_locations is not None):
            after_locations = after_locations.transpose(1, 2)

            (bs, ns, np, sl) = span_start_logits_after.size()
            #print("after_locations[:,:,:,[0]]:", after_locations[:,:,:,[0]])

            location_mask = (after_locations[:, :, :, 0] >=
                             0).float().unsqueeze(-1).expand(bs, ns, np, sl)

            #print("location_mask:", location_mask)

            start_after_log_predicted = util.masked_log_softmax(
                span_start_logits_after, location_mask)
            start_after_log_predicted_transpose = start_after_log_predicted.transpose(
                2, 3).transpose(1, 2)
            start_after_gold = torch.clamp(after_locations[:, :, :,
                                                           [0]].squeeze(-1),
                                           min=-1)
            #print("start_after_log_predicted_transpose: ", start_after_log_predicted_transpose)
            #print("start_after_gold: ", start_after_gold)
            location_loss = nll_loss(input=start_after_log_predicted_transpose,
                                     target=start_after_gold,
                                     ignore_index=-1)

            end_after_log_predicted = util.masked_log_softmax(
                span_end_logits_after, location_mask)
            end_after_log_predicted_transpose = end_after_log_predicted.transpose(
                2, 3).transpose(1, 2)
            end_after_gold = torch.clamp(after_locations[:, :, :,
                                                         [1]].squeeze(-1),
                                         min=-1)
            #print("end_after_log_predicted_transpose: ", end_after_log_predicted_transpose)
            #print("end_after_gold: ", end_after_gold)
            location_loss += nll_loss(input=end_after_log_predicted_transpose,
                                      target=end_after_gold,
                                      ignore_index=-1)
            output_dict["loss"] += location_loss
            # output_dict = {"loss" : 0.0}

        output_dict['action_probs_decode'] = action_probs_decode
        output_dict['action_logits'] = action_logits
        return output_dict
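
A minimal sketch of the location-span loss pattern above: a log-softmax over token positions followed by nll_loss with ignore_index=-1 so padded gold indices contribute nothing (toy shapes are hypothetical, and plain log_softmax stands in for util.masked_log_softmax):

import torch
import torch.nn.functional as F

batch, sents, parts, seq_len = 1, 2, 2, 5
span_start_logits = torch.randn(batch, sents, parts, seq_len)
# Gold start index per (sentence, participant); -1 marks padding.
start_gold = torch.tensor([[[3, -1], [0, 2]]])

log_predicted = F.log_softmax(span_start_logits, dim=-1)
# nll_loss expects (N, C, d1, d2, ...), so move the position dimension after the batch.
log_predicted = log_predicted.permute(0, 3, 1, 2)
loss = F.nll_loss(log_predicted, start_gold, ignore_index=-1)
print(loss)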
Exemplo n.º 20
0
def create_attended_span_representations(
        max_span_width: int, head_scores: torch.FloatTensor,
        encoded_text: torch.FloatTensor, span_ends: torch.IntTensor,
        span_widths: torch.IntTensor) -> torch.FloatTensor:
    """
    Given a tensor of unnormalized attention scores for each word in the document, compute
    distributions over every span with respect to these scores by normalising the headedness
    scores for words inside the span.

    Given these headedness distributions over every span, weight the corresponding vector
    representations of the words in the span by this distribution, returning a weighted
    representation of each span.

    Parameters
    ----------
    head_scores : ``torch.FloatTensor``, required.
        Unnormalized headedness scores for every word. This score is shared for every
        candidate. The only way in which the headedness scores differ over different
        spans is in the set of words over which they are normalized.
    encoded_text : ``torch.FloatTensor``, required.
        The encoded text, with shape (batch_size, document_length, embedding_size),
        over which we are computing a weighted sum.
    span_ends: ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans), representing the end indices
        of each span.
    span_widths : ``torch.IntTensor``, required.
        A tensor of shape (batch_size, num_spans) representing the width of each
        span candidate.
    Returns
    -------
    attended_text_embeddings : ``torch.FloatTensor``
        A tensor of shape (batch_size, num_spans, embedding_dim) - the result of
        applying attention over all words within each candidate span.
    """
    # Shape: (1, 1, max_span_width)
    max_span_range_indices = util.get_range_vector(max_span_width,
                                                   encoded_text.is_cuda).view(
                                                       1, 1, -1)

    # Shape: (batch_size, num_spans, max_span_width)
    # This is a broadcasted comparison - for each span we are considering,
    # we are creating a range vector of size max_span_width, but masking values
    # which are greater than the actual length of the span.

    span_ends = span_ends.unsqueeze(-1)
    span_widths = span_widths.unsqueeze(-1)
    span_mask = (max_span_range_indices <= span_widths).float()
    raw_span_indices = span_ends - max_span_range_indices
    # We also don't want to include span indices which are less than zero,
    # which happens because some spans near the beginning of the document
    # are of a smaller width than max_span_width, so we add this to the mask here.
    span_mask = span_mask * (raw_span_indices >= 0).float()
    # Clamp negative indices to zero; the mask above already zeroes out those positions.
    span_indices = F.relu(raw_span_indices.float()).long()

    # Shape: (batch_size * num_spans * max_span_width)
    flat_span_indices = util.flatten_and_batch_shift_indices(
        span_indices, encoded_text.size(1))

    # Shape: (batch_size, num_spans, max_span_width, embedding_dim)
    span_text_embeddings = util.batched_index_select(encoded_text,
                                                     span_indices,
                                                     flat_span_indices)

    # Shape: (batch_size, num_spans, max_span_width)
    span_head_scores = util.batched_index_select(head_scores, span_indices,
                                                 flat_span_indices).squeeze(-1)

    # Shape: (batch_size, num_spans, max_span_width)
    span_head_weights = util.last_dim_softmax(span_head_scores, span_mask)

    # Do a weighted sum of the embedded spans with
    # respect to the normalised head score distributions.
    # Shape: (batch_size, num_spans, embedding_dim)
    attended_text_embeddings = util.weighted_sum(span_text_embeddings,
                                                 span_head_weights)

    return attended_text_embeddings
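
The broadcasted range comparison above is the heart of the span masking. A small sketch with hypothetical toy spans showing how the mask and clamped indices come out (plain PyTorch, no allennlp helpers):

import torch

max_span_width = 4
span_ends = torch.tensor([[5, 2]]).unsqueeze(-1)     # (1, num_spans, 1)
span_widths = torch.tensor([[3, 1]]).unsqueeze(-1)   # (1, num_spans, 1)

range_indices = torch.arange(max_span_width).view(1, 1, -1)   # (1, 1, max_span_width)
span_mask = (range_indices <= span_widths).float()
raw_indices = span_ends - range_indices
span_mask = span_mask * (raw_indices >= 0).float()
span_indices = raw_indices.clamp(min=0)
print(span_indices)   # token indices inside each span, padded with 0
print(span_mask)      # 1.0 for real positions, 0.0 for padding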
Exemplo n.º 21
0
    def forward(
        self,  # type: ignore
        spans: torch.IntTensor,
        span_mask: torch.IntTensor,
        span_embeddings: torch.IntTensor,
        sentence_lengths: torch.Tensor,
        ner_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
        previous_step_output: Dict[str,
                                   Any] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        """

        # Shape: (Batch size, Number of Spans, Span Embedding Size)
        # span_embeddings
        ner_scores = self._ner_scorer(span_embeddings)
        # Give large negative scores to masked-out elements.
        mask = span_mask.unsqueeze(-1)
        ner_scores = util.replace_masked_values(ner_scores, mask, -1e20)
        dummy_dims = [ner_scores.size(0), ner_scores.size(1), 1]
        dummy_scores = ner_scores.new_zeros(*dummy_dims)
        if previous_step_output is not None and "predicted_span" in previous_step_output and not self.training:
            dummy_scores.masked_fill_(
                previous_step_output["predicted_span"].bool().unsqueeze(-1),
                -1e20)
            dummy_scores.masked_fill_(
                (1 -
                 previous_step_output["predicted_span"]).bool().unsqueeze(-1),
                1e20)

        ner_scores = torch.cat((dummy_scores, ner_scores), -1)

        if previous_step_output is not None and "predicted_seq_span" in previous_step_output and not self.training:
            for row_idx, all_spans in enumerate(spans):
                pred_spans = previous_step_output["predicted_seq_span"][
                    row_idx]
                pred_spans = all_spans.new_tensor(pred_spans)
                for col_idx, span in enumerate(all_spans):
                    if span_mask[row_idx][col_idx] == 0:
                        continue
                    found = any(
                        span[0] == pred_span[0] and span[1] == pred_span[1]
                        for pred_span in pred_spans)
                    if found:
                        # The span was predicted in the previous step: give the dummy
                        # ("no entity") class a large negative score so a real label wins.
                        ner_scores[row_idx, col_idx, 0] = -1e20
                    else:
                        # The span was not predicted: give the dummy class a large
                        # positive score so "no entity" wins.
                        ner_scores[row_idx, col_idx, 0] = 1e20

        _, predicted_ner = ner_scores.max(2)

        output_dict = {
            "spans": spans,
            "span_mask": span_mask,
            "ner_scores": ner_scores,
            "predicted_ner": predicted_ner
        }

        if ner_labels is not None:
            self._ner_metrics(predicted_ner, ner_labels, span_mask)
            ner_scores_flat = ner_scores.view(-1, self._n_labels)
            ner_labels_flat = ner_labels.view(-1)
            mask_flat = span_mask.view(-1).bool()

            loss = self._loss(ner_scores_flat[mask_flat],
                              ner_labels_flat[mask_flat])
            output_dict["loss"] = loss

        if metadata is not None:
            output_dict["document"] = [x["sentence"] for x in metadata]

        return output_dict
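
A minimal sketch of the dummy-score trick above, where a previous step's span predictions force or forbid the "no entity" column (the predicted_span indicator and sizes here are hypothetical):

import torch

# (batch, num_spans, num_labels) real-label scores plus one dummy column.
ner_scores = torch.randn(1, 3, 4)
dummy_scores = ner_scores.new_zeros(1, 3, 1)

# 1 where an earlier step predicted that the span is an entity, 0 otherwise.
predicted_span = torch.tensor([[1, 0, 1]])

# Predicted spans: push the dummy score down so some real label must win.
dummy_scores.masked_fill_(predicted_span.bool().unsqueeze(-1), -1e20)
# Non-predicted spans: push the dummy score up so "no entity" wins.
dummy_scores.masked_fill_((1 - predicted_span).bool().unsqueeze(-1), 1e20)

scores = torch.cat((dummy_scores, ner_scores), dim=-1)
print(scores.argmax(-1))   # 0 exactly where predicted_span == 0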
Exemplo n.º 22
0
    def forward(
        self,  # type: ignore
        sentences: torch.LongTensor,
        labels: torch.IntTensor = None,
        confidences: torch.Tensor = None,
        additional_features: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        TODO: add description

        Returns
        -------
        An output dictionary consisting of:
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # ===========================================================================================================
        # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
        # Input: sentences
        # Output: embedded_sentences

        # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
        embedded_sentences = self.text_field_embedder(sentences)
        mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
        batch_size, num_sentences, _, _ = embedded_sentences.size()

        if self.use_sep:
            # The following code collects vectors of the SEP tokens from all the examples in the batch,
            # and arrange them in one list. It does the same for the labels and confidences.
            # TODO: replace 103 with '[SEP]'
            sentences_mask = sentences['bert'] == 103  # mask for all the SEP tokens in the batch
            # given batch_size x num_sentences_per_example x sent_len x vector_len,
            # returns num_sentences_per_batch x vector_len
            embedded_sentences = embedded_sentences[sentences_mask]
            assert embedded_sentences.dim() == 2
            num_sentences = embedded_sentences.shape[0]
            # for the rest of the code in this model to work, think of the data we have as one example
            # with so many sentences and a batch of size 1
            batch_size = 1
            embedded_sentences = embedded_sentences.unsqueeze(dim=0)
            embedded_sentences = self.dropout(embedded_sentences)

            if labels is not None:
                if self.labels_are_scores:
                    labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
                else:
                    labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

                labels = labels[labels_mask]  # given batch_size x num_sentences_per_example, returns num_sentences_per_batch
                assert labels.dim() == 1
                if confidences is not None:
                    confidences = confidences[labels_mask]
                    assert confidences.dim() == 1
                if additional_features is not None:
                    additional_features = additional_features[labels_mask]
                    assert additional_features.dim() == 2

                num_labels = labels.shape[0]
                if num_labels != num_sentences:  # bert truncates long sentences, so some of the SEP tokens might be gone
                    assert num_labels > num_sentences  # truncation can only drop sentences, so `num_labels` must exceed `num_sentences`
                    logger.warning(
                        f'Found {num_labels} labels but {num_sentences} sentences'
                    )
                    labels = labels[:
                                    num_sentences]  # Ignore some labels. This is ok for training but bad for testing.
                    # We are ignoring this problem for now.
                    # TODO: fix, at least for testing

                # do the same for `confidences`
                if confidences is not None:
                    num_confidences = confidences.shape[0]
                    if num_confidences != num_sentences:
                        assert num_confidences > num_sentences
                        confidences = confidences[:num_sentences]

                # and for `additional_features`
                if additional_features is not None:
                    num_additional_features = additional_features.shape[0]
                    if num_additional_features != num_sentences:
                        assert num_additional_features > num_sentences
                        additional_features = additional_features[:
                                                                  num_sentences]

                # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
                labels = labels.unsqueeze(dim=0)
                if confidences is not None:
                    confidences = confidences.unsqueeze(dim=0)
                if additional_features is not None:
                    additional_features = additional_features.unsqueeze(dim=0)
        else:
            # ['CLS'] token
            embedded_sentences = embedded_sentences[:, :, 0, :]
            embedded_sentences = self.dropout(embedded_sentences)
            batch_size, num_sentences, _ = embedded_sentences.size()
            sent_mask = (mask.sum(dim=2) != 0)
            embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

        if additional_features is not None:
            embedded_sentences = torch.cat(
                (embedded_sentences, additional_features), dim=-1)

        label_logits = self.time_distributed_aggregate_feedforward(
            embedded_sentences)
        # label_logits: batch_size, num_sentences, num_labels

        if self.labels_are_scores:
            label_probs = label_logits
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {"action_probs": label_probs}

        # =====================================================================

        if self.with_crf:
            # Layer 4 = CRF layer across labels of sentences in an abstract
            mask_sentences = (labels != -1)
            best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
            # Just get the tags and ignore the score.
            predicted_labels = [x for x, y in best_paths]

            label_loss = 0.0
        if labels is not None:
            # Compute cross entropy loss
            flattened_logits = label_logits.view((batch_size * num_sentences),
                                                 self.num_labels)
            flattened_gold = labels.contiguous().view(-1)

            if not self.with_crf:
                label_loss = self.loss(flattened_logits.squeeze(),
                                       flattened_gold)
                if confidences is not None:
                    label_loss = label_loss * confidences.type_as(
                        label_loss).view(-1)
                label_loss = label_loss.mean()
                flattened_probs = torch.softmax(flattened_logits, dim=-1)
            else:
                clamped_labels = torch.clamp(labels, min=0)
                log_likelihood = self.crf(label_logits, clamped_labels,
                                          mask_sentences)
                label_loss = -log_likelihood
                # compute categorical accuracy
                crf_label_probs = label_logits * 0.
                for i, instance_labels in enumerate(predicted_labels):
                    for j, label_id in enumerate(instance_labels):
                        crf_label_probs[i, j, label_id] = 1
                flattened_probs = crf_label_probs.view(
                    (batch_size * num_sentences), self.num_labels)

            if not self.labels_are_scores:
                evaluation_mask = (flattened_gold != -1)
                self.label_accuracy(flattened_probs.float().contiguous(),
                                    flattened_gold.squeeze(-1),
                                    mask=evaluation_mask)

                self.all_f1_metrics(flattened_probs,
                                    flattened_gold,
                                    mask=evaluation_mask)

                # compute F1 per label
                for label_index in range(self.num_labels):
                    label_name = self.vocab.get_token_from_index(
                        namespace='labels', index=label_index)
                    metric = self.label_f1_metrics[label_name]
                    metric(flattened_probs,
                           flattened_gold,
                           mask=evaluation_mask)

        if labels is not None:
            output_dict["loss"] = label_loss
        output_dict['action_logits'] = label_logits
        return output_dict
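
A minimal sketch of the [SEP]-gathering step above: collect the vectors at SEP positions across the whole batch and re-batch them as a single example (token ids and sizes are hypothetical; 103 is assumed to be the [SEP] id, as in the code above):

import torch

batch_size, num_sents, sent_len, dim = 2, 2, 4, 8
token_ids = torch.randint(0, 100, (batch_size, num_sents, sent_len))
token_ids[:, :, -1] = 103                      # pretend every sentence ends with [SEP]
embedded = torch.randn(batch_size, num_sents, sent_len, dim)

sep_mask = token_ids == 103                    # (batch, num_sents, sent_len)
sep_vectors = embedded[sep_mask]               # (total_num_sentences, dim)
assert sep_vectors.dim() == 2

# Treat all sentences in the batch as one "document" with batch_size = 1.
sep_vectors = sep_vectors.unsqueeze(0)         # (1, total_num_sentences, dim)
print(sep_vectors.shape)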
Exemplo n.º 23
0
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                spans: torch.IntTensor,
                span_labels: torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``, required.
            The output of a ``TextField`` representing the text of
            the document.
        spans : ``torch.IntTensor``, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a ``ListField[SpanField]`` of
            indices into the text of the document.
        span_labels : ``torch.IntTensor``, optional (default = None)
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.

        Returns
        -------
        An output dictionary consisting of:
        top_spans : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep, 2)`` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : ``torch.IntTensor``
            A tensor of shape ``(num_spans_to_keep, max_antecedents)`` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : ``torch.IntTensor``
            A tensor of shape ``(batch_size, num_spans_to_keep)`` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask,
         top_span_indices, top_span_mention_scores) = self._mention_pruner(span_embeddings,
                                                                           span_mask,
                                                                           num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans,
                                              top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        #  index to the indices of its allowed antecedents. Note that this is independent
        #  of the batch dimension - it's just a function of the span's position in
        # top_spans. The spans are in document order, so we can just use the relative
        # index of the spans to know which other spans are allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        #  like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        #  we can use to make coreference decisions between valid span pairs.

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = \
            self._generate_valid_antecedents(num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(top_span_embeddings,
                                                                      valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(top_span_mention_scores,
                                                                          valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(top_span_embeddings,
                                                                  candidate_antecedent_embeddings,
                                                                  valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(span_pair_embeddings,
                                                              top_span_mention_scores,
                                                              candidate_antecedent_mention_scores,
                                                              valid_antecedent_log_mask)

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {"top_spans": top_spans,
                       "antecedent_indices": valid_antecedent_indices,
                       "predicted_antecedents": predicted_antecedents}
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(span_labels.unsqueeze(-1),
                                                           top_span_indices,
                                                           flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(pruned_gold_labels,
                                                            valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(pruned_gold_labels,
                                                                          antecedent_labels)
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices, predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
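
A small sketch of what flatten_and_batch_shift_indices plus batched_index_select accomplish above, written without the allennlp helpers (toy sizes are hypothetical): per-batch indices are offset by batch_index * num_spans so that a single 1D index_select works on the flattened tensor.

import torch

batch_size, num_spans, dim = 2, 5, 3
span_embeddings = torch.randn(batch_size, num_spans, dim)
top_span_indices = torch.tensor([[4, 1], [0, 3]])      # (batch, num_spans_to_keep)

offsets = torch.arange(batch_size).unsqueeze(-1) * num_spans
flat_indices = (top_span_indices + offsets).view(-1)   # (batch * num_spans_to_keep,)

flat_embeddings = span_embeddings.view(batch_size * num_spans, dim)
top_span_embeddings = flat_embeddings.index_select(0, flat_indices)
top_span_embeddings = top_span_embeddings.view(batch_size, -1, dim)
print(top_span_embeddings.shape)   # torch.Size([2, 2, 3])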
Exemplo n.º 24
0
    def forward(
        self,  # type: ignore
        text: Dict[str, torch.LongTensor],
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ

        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        if self._use_gold_mentions:
            if text_embeddings.is_cuda:
                device = torch.device("cuda")
            else:
                device = torch.device("cpu")

            s = [
                torch.as_tensor(pair, dtype=torch.long, device=device)
                for cluster in metadata[0]["clusters"] for pair in cluster
            ]
            gm = torch.stack(s, dim=0).unsqueeze(0).unsqueeze(1)

            span_mask = spans.unsqueeze(2) - gm
            span_mask = (span_mask[:, :, :, 0] == 0) + (span_mask[:, :, :, 1]
                                                        == 0)
            span_mask, _ = (span_mask == 2).max(-1)
            num_spans = span_mask.sum().item()
            span_mask = span_mask.float()
        else:
            span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
            num_spans = spans.size(1)
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))

        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Shapes:
        # (num_spans_to_keep, max_antecedents),
        # (1, max_antecedents),
        # (1, num_spans_to_keep, max_antecedents)
        valid_antecedent_indices, valid_antecedent_offsets, valid_antecedent_log_mask = self._generate_valid_antecedents(
            num_spans_to_keep, max_antecedents, util.get_device_of(text_mask))
        # Select tensors relating to the antecedent spans.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        candidate_antecedent_embeddings = util.flattened_index_select(
            top_span_embeddings, valid_antecedent_indices)

        # Shape: (batch_size, num_spans_to_keep, max_antecedents)
        candidate_antecedent_mention_scores = util.flattened_index_select(
            top_span_mention_scores, valid_antecedent_indices).squeeze(-1)
        # Compute antecedent scores.
        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        span_pair_embeddings = self._compute_span_pair_embeddings(
            top_span_embeddings, candidate_antecedent_embeddings,
            valid_antecedent_offsets)
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            span_pair_embeddings,
            top_span_mention_scores,
            candidate_antecedent_mention_scores,
            valid_antecedent_log_mask,
        )

        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": valid_antecedent_indices,
            "predicted_antecedents": predicted_antecedents,
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices)

            antecedent_labels = util.flattened_index_select(
                pruned_gold_labels, valid_antecedent_indices).squeeze(-1)
            antecedent_labels += valid_antecedent_log_mask.long()

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels)
            coreference_log_probs = util.last_dim_log_softmax(
                coreference_scores, top_span_mask)
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(
                correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(top_spans, valid_antecedent_indices,
                                     predicted_antecedents, metadata)

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
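
A minimal sketch of the gold-mention masking used above when _use_gold_mentions is set: each candidate span is compared against every gold cluster pair and only exact matches survive (toy spans are hypothetical; plain PyTorch):

import torch

# Candidate spans: (batch, num_spans, 2) inclusive (start, end) pairs.
spans = torch.tensor([[[0, 1], [2, 4], [5, 5]]])
# Gold mention spans collected from the clusters in metadata.
gold = torch.tensor([[2, 4], [5, 5]])
gm = gold.unsqueeze(0).unsqueeze(1)            # (1, 1, num_gold, 2)

diff = spans.unsqueeze(2) - gm                 # (1, num_spans, num_gold, 2)
matches = (diff[..., 0] == 0) & (diff[..., 1] == 0)
span_mask = matches.any(-1)                    # True where a candidate equals some gold span
print(span_mask.float())                       # tensor([[0., 1., 1.]])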
Exemplo n.º 25
0
    def forward(
        self,  # type: ignore
        text: TextFieldTensors,
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        """
        # Parameters

        text : `TextFieldTensors`, required.
            The output of a `TextField` representing the text of
            the document.
        spans : `torch.IntTensor`, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
            indices into the text of the document.
        span_labels : `torch.IntTensor`, optional (default = None).
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : `List[Dict[str, Any]]`, optional (default = None).
            A metadata dictionary for each instance in the batch. We use the "original_text" and "clusters" keys
            from this dictionary, which respectively have the original text and the annotated gold coreference
            clusters for that instance.

        # Returns

        An output dictionary consisting of:
        top_spans : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : `torch.IntTensor`
            A tensor of shape `(num_spans_to_keep, max_antecedents)` representing for each top span
            the index (with respect to top_spans) of the possible antecedents the model considered.
        predicted_antecedents : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        batch_size = spans.size(0)
        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text)

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, num_spans)

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)
        ).squeeze(-1)
        # Shape: (batch_size, num_spans_to_keep) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep
        )

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = util.batched_index_select(
            span_embeddings, top_span_indices, flat_top_span_indices
        )

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        # index to the indices of its allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        # we can use to make coreference decisions between valid span pairs.

        if self._coarse_to_fine:
            pruned_antecedents = self._coarse_to_fine_pruning(
                top_span_embeddings, top_span_mention_scores, top_span_mask, max_antecedents
            )
        else:
            pruned_antecedents = self._distance_pruning(
                top_span_embeddings, top_span_mention_scores, max_antecedents
            )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 4 tensors
        (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
            top_antecedent_indices,
        ) = pruned_antecedents

        flat_top_antecedent_indices = util.flatten_and_batch_shift_indices(
            top_antecedent_indices, num_spans_to_keep
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        top_antecedent_embeddings = util.batched_index_select(
            top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
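        # Column 0 scores the dummy "no antecedent" choice; the remaining max_antecedents
        # columns line up with the candidates in top_antecedent_indices.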
        coreference_scores = self._compute_coreference_scores(
            top_span_embeddings,
            top_antecedent_embeddings,
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
        )

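        # Higher-order inference: for inference_order - 1 additional passes, each kept span's
        # representation is refined with an attention-weighted mix of itself (the dummy option)
        # and its candidate antecedents before the coreference scores are recomputed.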
        for _ in range(self._inference_order - 1):
            dummy_mask = top_antecedent_mask.new_ones(batch_size, num_spans_to_keep, 1)
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents,)
            top_antecedent_with_dummy_mask = torch.cat([dummy_mask, top_antecedent_mask], -1)
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            attention_weight = util.masked_softmax(
                coreference_scores, top_antecedent_with_dummy_mask, memory_efficient=True
            )
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents, embedding_size)
            top_antecedent_with_dummy_embeddings = torch.cat(
                [top_span_embeddings.unsqueeze(2), top_antecedent_embeddings], 2
            )
            # Shape: (batch_size, num_spans_to_keep, embedding_size)
            attended_embeddings = util.weighted_sum(
                top_antecedent_with_dummy_embeddings, attention_weight
            )
            # Shape: (batch_size, num_spans_to_keep, embedding_size)
            top_span_embeddings = self._span_updating_gated_sum(
                top_span_embeddings, attended_embeddings
            )

            # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
            top_antecedent_embeddings = util.batched_index_select(
                top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
            )
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            coreference_scores = self._compute_coreference_scores(
                top_span_embeddings,
                top_antecedent_embeddings,
                top_partial_coreference_scores,
                top_antecedent_mask,
                top_antecedent_offsets,
            )

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": top_antecedent_indices,
            "predicted_antecedents": predicted_antecedents,
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
            )

            # Shape: (batch_size, num_spans_to_keep, max_antecedents)
            antecedent_labels = util.batched_index_select(
                pruned_gold_labels, top_antecedent_indices, flat_top_antecedent_indices
            ).squeeze(-1)
            antecedent_labels = util.replace_masked_values(
                antecedent_labels, top_antecedent_mask, -100
            )
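            # The -100 sentinel ensures that padded antecedent slots never match a real
            # cluster id (>= 0) or the -1 "not in any cluster" label when the gold
            # antecedent indicators are computed below.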

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels
            )
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            #  the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask.unsqueeze(-1)
            )
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(
                top_spans, top_antecedent_indices, predicted_antecedents, metadata
            )

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
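The comments around util.flatten_and_batch_shift_indices and util.batched_index_select above describe an indexing trick that the sketch below reproduces with plain PyTorch (toy shapes, not the AllenNLP implementation): per-batch indices are shifted by batch_index * num_spans so that a single 1-D index_select over the flattened tensor gathers spans for every element of the batch at once.

import torch

spans = torch.arange(2 * 4 * 2).view(2, 4, 2)       # (batch=2, num_spans=4, 2)
top_span_indices = torch.tensor([[0, 2], [1, 3]])    # (batch=2, num_spans_to_keep=2)

# Shift the indices of batch b by b * num_spans, then flatten to 1-D.
offsets = torch.arange(2).unsqueeze(1) * 4
flat_top_span_indices = (top_span_indices + offsets).view(-1)

# A single index_select over the flattened (batch * num_spans, 2) tensor
# selects the kept spans for every batch element in one call.
top_spans = spans.view(2 * 4, 2).index_select(0, flat_top_span_indices).view(2, 2, 2)
print(top_spans)

Precomputing the flattened indices once lets the same index tensor be reused for every batched_index_select call in the forward pass.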
    def forward(
        self,  # type: ignore
        sentences: torch.LongTensor,
        labels: torch.IntTensor = None,
        confidences: torch.Tensor = None,
        additional_features: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        TODO: add description

        Returns
        -------
        An output dictionary consisting of:
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # ===========================================================================================================
        # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
        # Input: sentences
        # Output: embedded_sentences
        print(sentences)
        sentences_conv = {}
        for key, val in sentences.items():
            sentences_conv[key] = val.cpu().data.numpy().tolist()
        self.track_embedding["Transformation_0"] = {
            "sentences": sentences_conv
        }
        # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
        embedded_sentences = self.text_field_embedder(sentences)
        self.track_embedding["Transformation_1"] = {
            "size": list(embedded_sentences.size()),
            "dim": embedded_sentences.dim()
        }

        # Kacper: Basically a padding mask for bert
        mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
        batch_size, num_sentences, _, _ = list(embedded_sentences.size())

        if self.use_sep:
            # The following code collects vectors of the SEP tokens from all the examples in the batch,
            # and arrange them in one list. It does the same for the labels and confidences.
            # TODO: replace 103 with '[SEP]'
            # Kacper: This is an important step where we get SEP tokens to later do sentence classification
            # Kacper: We take a location of SEP tokens from the sentences to get a mask
            sentences_mask = sentences['bert'] == 103  # mask for all the SEP tokens in the batch
            # Kacper: We use this mask to get the respective embeddings from the output layer of bert
            # given batch_size x num_sentences_per_example x sent_len x vector_len,
            # this returns num_sentences_per_batch x vector_len
            embedded_sentences = embedded_sentences[sentences_mask]
            self.track_embedding["Transformation_2"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Kacper: it becomes 2-D instead of 4-D because boolean-mask indexing flattens
            # the selected positions, leaving only (num_SEP_tokens_in_batch, vector_len)
            assert embedded_sentences.dim() == 2
            num_sentences = embedded_sentences.shape[0]
            # Kacper: comment below is vague
            # Kacper: I think we batch in one array because we just need to compute a mean loss from all of them
            # for the rest of the code in this model to work, think of the data we have as one example
            # with so many sentences and a batch of size 1
            batch_size = 1
            # Kacper: We batch all sentences in one array
            embedded_sentences = embedded_sentences.unsqueeze(dim=0)
            self.track_embedding["Transformation_3"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Kacper: Dropout layer is between filtered embeddings and linear layer
            embedded_sentences = self.dropout(embedded_sentences)
            self.track_embedding["Transformation_4"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Kacper: we provide the labels for training (for each sentence)
            if labels is not None:
                if self.labels_are_scores:
                    labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
                else:
                    labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

                # given batch_size x num_sentences_per_example, return num_sentences_per_batch
                labels = labels[labels_mask]
                assert labels.dim() == 1
                if confidences is not None:
                    confidences = confidences[labels_mask]
                    assert confidences.dim() == 1
                if additional_features is not None:
                    additional_features = additional_features[labels_mask]
                    assert additional_features.dim() == 2

                num_labels = labels.shape[0]
                # Kacper: this might be useful to consider in my code as well
                if num_labels != num_sentences:  # bert truncates long sentences, so some of the SEP tokens might be gone
                    assert num_labels > num_sentences  # in that case `num_labels` must be greater than `num_sentences`
                    logger.warning(
                        f'Found {num_labels} labels but {num_sentences} sentences'
                    )
                    # Ignore some labels. This is ok for training but bad for testing.
                    labels = labels[:num_sentences]
                    # We are ignoring this problem for now.
                    # TODO: fix, at least for testing

                # do the same for `confidences`
                if confidences is not None:
                    num_confidences = confidences.shape[0]
                    if num_confidences != num_sentences:
                        assert num_confidences > num_sentences
                        confidences = confidences[:num_sentences]

                # and for `additional_features`
                if additional_features is not None:
                    num_additional_features = additional_features.shape[0]
                    if num_additional_features != num_sentences:
                        assert num_additional_features > num_sentences
                        additional_features = additional_features[:num_sentences]

                # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
                labels = labels.unsqueeze(dim=0)
                if confidences is not None:
                    confidences = confidences.unsqueeze(dim=0)
                if additional_features is not None:
                    additional_features = additional_features.unsqueeze(dim=0)
        else:
            # ['CLS'] token
            # Kacper: this shouldnt be the case for our project
            embedded_sentences = embedded_sentences[:, :, 0, :]
            embedded_sentences = self.dropout(embedded_sentences)
            batch_size, num_sentences, _ = list(embedded_sentences.size())
            sent_mask = (mask.sum(dim=2) != 0)
            embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

        if additional_features is not None:
            embedded_sentences = torch.cat(
                (embedded_sentences, additional_features), dim=-1)

        # Kacper: we unwrap the time dimension of a tensor into the 1st dimension (batch),
        # Kacper: apply a linear layer and wrap the time dimension back
        # Kacper: I would suspect it is happening only for embeddings related to the [SEP] tokens
        label_logits = self.time_distributed_aggregate_feedforward(
            embedded_sentences)
        # label_logits: batch_size, num_sentences, num_labels
        self.track_embedding["logits"] = {
            "size": list(label_logits.size()),
            "dim": label_logits.dim()
        }
        #print(self.track_embedding)
        self.track_embedding_list.append(deepcopy(self.track_embedding))
        with open(path_json, 'w') as json_out:
            json.dump(self.track_embedding_list, json_out)

        if self.labels_are_scores:
            label_probs = label_logits
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {"action_probs": label_probs}

        # =====================================================================

        if self.with_crf:
            # Layer 4 = CRF layer across labels of sentences in an abstract
            mask_sentences = (labels != -1)
            best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
            # Just get the tags and ignore the score.
            predicted_labels = [x for x, y in best_paths]
            # print(f"len(predicted_labels):{len(predicted_labels)}, (predicted_labels):{predicted_labels}")

            label_loss = 0.0
        if labels is not None:
            # Compute cross entropy loss
            # Kacper: reshape logits to be of the following shape in view()
            flattened_logits = label_logits.view((batch_size * num_sentences),
                                                 self.num_labels)
            # Make labels contiguous in memory and reshape them into one dimension
            flattened_gold = labels.contiguous().view(-1)  # Kacper: True labels

            if not self.with_crf:
                # Kacper: We are only interested in this part of the code since we don't use crf
                # Kacper: Get a loss (MSE if sci_sum is True or Crossentropy)
                label_loss = self.loss(flattened_logits.squeeze(),
                                       flattened_gold)
                if confidences is not None:
                    label_loss = label_loss * confidences.type_as(
                        label_loss).view(-1)
                label_loss = label_loss.mean()  # Kacper: Get a mean loss
                # Kacper: Get a probabilities from the logits
                flattened_probs = torch.softmax(flattened_logits, dim=-1)
            else:
                # Kacper: We are not interested in this if statement branch (for our project)
                clamped_labels = torch.clamp(labels, min=0)
                log_likelihood = self.crf(label_logits, clamped_labels,
                                          mask_sentences)
                label_loss = -log_likelihood
                # compute categorical accuracy
                crf_label_probs = label_logits * 0.
                for i, instance_labels in enumerate(predicted_labels):
                    for j, label_id in enumerate(instance_labels):
                        crf_label_probs[i, j, label_id] = 1
                flattened_probs = crf_label_probs.view(
                    (batch_size * num_sentences), self.num_labels)

            if not self.labels_are_scores:
                # Kacper: this will be a case for us as well because labels are numerical for Pubmed data
                evaluation_mask = (flattened_gold != -1)
                # Kacper: CategoricalAccuracy is computed in this case
                self.label_accuracy(flattened_probs.float().contiguous(),
                                    flattened_gold.squeeze(-1),
                                    mask=evaluation_mask)

                # compute F1 per label
                for label_index in range(self.num_labels):
                    label_name = self.vocab.get_token_from_index(
                        namespace='labels', index=label_index)
                    metric = self.label_f1_metrics[label_name]
                    metric(flattened_probs,
                           flattened_gold,
                           mask=evaluation_mask)

        if labels is not None:
            output_dict["loss"] = label_loss
        output_dict['action_logits'] = label_logits
        return output_dict
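Finally, a small sketch (made-up shapes, with the 103 id taken from the snippet above) of the SEP-token pooling this forward pass relies on: indexing a 4-D embedding tensor with a 3-D boolean mask merges the batch, sentence and token dimensions, which is why the asserted dim() drops to 2 and the result is later re-batched with unsqueeze(dim=0).

import torch

embedded = torch.randn(2, 3, 5, 8)                # (batch, num_sentences, sent_len, dim)
token_ids = torch.randint(0, 100, (2, 3, 5))      # fake wordpiece ids
token_ids[:, :, -1] = 103                         # pretend the last token of every sentence is [SEP]

sep_mask = token_ids == 103                       # (batch, num_sentences, sent_len) boolean mask
sep_embeddings = embedded[sep_mask]               # -> (num_SEP_tokens_in_batch, dim)

print(sep_embeddings.shape)                       # torch.Size([6, 8]); dim() == 2
sep_embeddings = sep_embeddings.unsqueeze(dim=0)  # re-batch as (1, num_SEP_tokens_in_batch, dim)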