def forward(
        self,  # type: ignore
        text: TextFieldTensors,
        spans: torch.IntTensor,
        span_labels: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Embed the document, prune candidate spans by mention score, score each kept span
        against its candidate antecedents and, when gold labels are given, compute the
        negative marginal log-likelihood loss and update the metrics.

        # Parameters

        text : `TextFieldTensors`, required.
            The output of a `TextField` representing the text of
            the document.
        spans : `torch.IntTensor`, required.
            A tensor of shape (batch_size, num_spans, 2), representing the inclusive start and end
            indices of candidate spans for mentions. Comes from a `ListField[SpanField]` of
            indices into the text of the document. Padding spans are marked by a negative start index.
        span_labels : `torch.IntTensor`, optional (default = None).
            A tensor of shape (batch_size, num_spans), representing the cluster ids
            of each span, or -1 for those which do not appear in any clusters.
        metadata : `List[Dict[str, Any]]`, optional (default = None).
            A metadata dictionary for each instance in the batch. The "original_text" key is read
            here to populate the "document" output. When `span_labels` is given the metadata is
            also forwarded to the mention-recall and CoNLL metrics (which presumably read the gold
            "clusters" key -- confirm against the metric implementations).

        # Returns

        An output dictionary consisting of:
        top_spans : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep, 2)` representing
            the start and end word indices of the top spans that survived the pruning stage.
        antecedent_indices : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep, max_antecedents)` representing for
            each top span the index (with respect to top_spans) of the possible antecedents the
            model considered.
        predicted_antecedents : `torch.IntTensor`
            A tensor of shape `(batch_size, num_spans_to_keep)` representing, for each top span, the
            index (with respect to antecedent_indices) of the most likely antecedent. -1 means there
            was no predicted link.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised. Only present when `span_labels` is given.
        document : `List`, optional
            The "original_text" entry of each metadata dictionary. Only present when `metadata`
            is given.
        """
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(self._text_field_embedder(text))

        batch_size = spans.size(0)
        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text)

        # Shape: (batch_size, num_spans)
        # Indexing with `[:, :, 0]` already yields a 2-D tensor. It must NOT be squeezed:
        # squeezing the last axis would drop the span dimension whenever num_spans == 1
        # and the mask would no longer line up with the mention scores below.
        span_mask = spans[:, :, 0] >= 0
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(text_embeddings, text_mask)
        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, num_spans)

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)
        ).squeeze(-1)
        # Shape: (batch_size, num_spans_to_keep) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep
        )

        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
        # Shape: (batch_size, num_spans_to_keep, embedding_size)
        top_span_embeddings = util.batched_index_select(
            span_embeddings, top_span_indices, flat_top_span_indices
        )

        # Compute indices for antecedent spans to consider.
        max_antecedents = min(self._max_antecedents, num_spans_to_keep)

        # Now that we have our variables in terms of num_spans_to_keep, we need to
        # compare span pairs to decide each span's antecedent. Each span can only
        # have prior spans as antecedents, and we only consider up to max_antecedents
        # prior spans. So the first thing we do is construct a matrix mapping a span's
        # index to the indices of its allowed antecedents.

        # Once we have this matrix, we reformat our variables again to get embeddings
        # for all valid antecedents for each span. This gives us variables with shapes
        # like (batch_size, num_spans_to_keep, max_antecedents, embedding_size), which
        # we can use to make coreference decisions between valid span pairs.

        if self._coarse_to_fine:
            pruned_antecedents = self._coarse_to_fine_pruning(
                top_span_embeddings, top_span_mention_scores, top_span_mask, max_antecedents
            )
        else:
            pruned_antecedents = self._distance_pruning(
                top_span_embeddings, top_span_mention_scores, max_antecedents
            )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents) for all 4 tensors
        (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
            top_antecedent_indices,
        ) = pruned_antecedents

        flat_top_antecedent_indices = util.flatten_and_batch_shift_indices(
            top_antecedent_indices, num_spans_to_keep
        )

        # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
        top_antecedent_embeddings = util.batched_index_select(
            top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
        )
        # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
        coreference_scores = self._compute_coreference_scores(
            top_span_embeddings,
            top_antecedent_embeddings,
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
        )

        # Higher-order inference: each extra round refines the span embeddings with an
        # attention-weighted sum over their antecedents and then re-scores the pairs.
        for _ in range(self._inference_order - 1):
            dummy_mask = top_antecedent_mask.new_ones(batch_size, num_spans_to_keep, 1)
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents,)
            top_antecedent_with_dummy_mask = torch.cat([dummy_mask, top_antecedent_mask], -1)
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            attention_weight = util.masked_softmax(
                coreference_scores, top_antecedent_with_dummy_mask, memory_efficient=True
            )
            # The span itself stands in for the "no antecedent" (dummy) slot.
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents, embedding_size)
            top_antecedent_with_dummy_embeddings = torch.cat(
                [top_span_embeddings.unsqueeze(2), top_antecedent_embeddings], 2
            )
            # Shape: (batch_size, num_spans_to_keep, embedding_size)
            attended_embeddings = util.weighted_sum(
                top_antecedent_with_dummy_embeddings, attention_weight
            )
            # Shape: (batch_size, num_spans_to_keep, embedding_size)
            top_span_embeddings = self._span_updating_gated_sum(
                top_span_embeddings, attended_embeddings
            )

            # Shape: (batch_size, num_spans_to_keep, max_antecedents, embedding_size)
            top_antecedent_embeddings = util.batched_index_select(
                top_span_embeddings, top_antecedent_indices, flat_top_antecedent_indices
            )
            # Shape: (batch_size, num_spans_to_keep, 1 + max_antecedents)
            coreference_scores = self._compute_coreference_scores(
                top_span_embeddings,
                top_antecedent_embeddings,
                top_partial_coreference_scores,
                top_antecedent_mask,
                top_antecedent_offsets,
            )

        # We now have, for each span which survived the pruning stage,
        # a predicted antecedent. This implies a clustering if we group
        # mentions which refer to each other in a chain.
        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_antecedents = coreference_scores.max(2)
        # Subtract one here because index 0 is the "no antecedent" class,
        # so this makes the indices line up with actual spans if the prediction
        # is greater than -1.
        predicted_antecedents -= 1

        output_dict = {
            "top_spans": top_spans,
            "antecedent_indices": top_antecedent_indices,
            "predicted_antecedents": predicted_antecedents,
        }
        if span_labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
            )

            # Shape: (batch_size, num_spans_to_keep, max_antecedents)
            antecedent_labels = util.batched_index_select(
                pruned_gold_labels, top_antecedent_indices, flat_top_antecedent_indices
            ).squeeze(-1)
            antecedent_labels = util.replace_masked_values(
                antecedent_labels, top_antecedent_mask, -100
            )

            # Compute labels.
            # Shape: (batch_size, num_spans_to_keep, max_antecedents + 1)
            gold_antecedent_labels = self._compute_antecedent_gold_labels(
                pruned_gold_labels, antecedent_labels
            )
            # Now, compute the loss using the negative marginal log-likelihood.
            # This is equal to the log of the sum of the probabilities of all antecedent predictions
            # that would be consistent with the data, in the sense that we are minimising, for a
            # given span, the negative marginal log likelihood of all antecedents which are in the
            # same gold cluster as the span we are currently considering. Each span i predicts a
            # single antecedent j, but there might be several prior mentions k in the same
            # coreference cluster that would be valid antecedents. Our loss is the sum of the
            # probability assigned to all valid antecedents. This is a valid objective for
            # clustering as we don't mind which antecedent is predicted, so long as they are in
            # the same coreference cluster.
            coreference_log_probs = util.masked_log_softmax(
                coreference_scores, top_span_mask.unsqueeze(-1)
            )
            # Adding log(label) keeps the gold antecedents (log 1 = 0) and sends every
            # other entry to -inf so that it vanishes inside the logsumexp.
            correct_antecedent_log_probs = coreference_log_probs + gold_antecedent_labels.log()
            negative_marginal_log_likelihood = -util.logsumexp(correct_antecedent_log_probs).sum()

            self._mention_recall(top_spans, metadata)
            self._conll_coref_scores(
                top_spans, top_antecedent_indices, predicted_antecedents, metadata
            )

            output_dict["loss"] = negative_marginal_log_likelihood

        if metadata is not None:
            output_dict["document"] = [x["original_text"] for x in metadata]
        return output_dict
    def _coarse_to_fine_pruning(
        self,
        top_span_embeddings: torch.FloatTensor,
        top_span_mention_scores: torch.FloatTensor,
        top_span_mask: torch.BoolTensor,
        max_antecedents: int,
    ) -> Tuple[torch.FloatTensor, torch.BoolTensor, torch.LongTensor, torch.LongTensor]:
        """
        Select, for every kept span, its `max_antecedents` best candidate antecedents
        according to a cheap pairwise score (two mention scores plus a bilinear term).

        # Parameters

        top_span_embeddings: torch.FloatTensor, required.
            Embeddings of the spans that survived mention pruning.
            (batch_size, num_spans_to_keep, embedding_size).
        top_span_mention_scores: torch.FloatTensor, required.
            Mention scores of those spans.
            (batch_size, num_spans_to_keep).
        top_span_mask: torch.BoolTensor, required.
            Mask over those spans.
            (batch_size, num_spans_to_keep).
        max_antecedents: int, required.
            How many antecedents to retain per span.

        # Returns

        top_partial_coreference_scores: torch.FloatTensor
            The coarse score of each retained (span, antecedent) pair: the span's mention
            score, the antecedent's mention score and their bilinear interaction. It is
            "partial" because the full coreference score additionally includes the term
            w * FFNN([g_i, g_j, g_i * g_j, features]), which is not computed here.
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_mask: torch.BoolTensor
            Marks which retained antecedent slots are valid; spans differ in how many
            valid antecedents they have (e.g. the first span has none).
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_offsets: torch.LongTensor
            Span index minus antecedent index, i.e. the distance counted in kept spans
            rather than in words.
            (batch_size, num_spans_to_keep, max_antecedents)
        top_antecedent_indices: torch.LongTensor
            Index of each retained antecedent with respect to the kept spans.
            (batch_size, num_spans_to_keep, max_antecedents)
        """
        batch_size = top_span_embeddings.size(0)
        num_kept = top_span_embeddings.size(1)
        device = util.get_device_of(top_span_embeddings)

        # Only the validity mask of the generated antecedents is needed here.
        # Shape: (1, num_kept, num_kept)
        _unused_indices, _unused_offsets, antecedent_validity = self._generate_valid_antecedents(
            num_kept, num_kept, device
        )

        # Bilinear interaction between every pair of kept spans.
        projected_spans = self._coarse2fine_scorer(top_span_embeddings)
        pairwise_bilinear = torch.matmul(top_span_embeddings, projected_spans.transpose(1, 2))

        # Shape: (batch_size, num_kept, num_kept); broadcast op
        coarse_pair_scores = (
            top_span_mention_scores.unsqueeze(1) + top_span_mention_scores.unsqueeze(2)
        ) + pairwise_bilinear

        # A pair is scored only if the span itself is unmasked and the antecedent is valid.
        # Shape: (batch_size, num_kept, num_kept); broadcast op
        pair_mask = top_span_mask.unsqueeze(-1) & antecedent_validity

        # Shape: (batch_size, num_kept, max_antecedents) for all 3 tensors
        top_partial_coreference_scores, top_antecedent_mask, top_antecedent_indices = (
            util.masked_topk(coarse_pair_scores, pair_mask, max_antecedents)
        )

        # Table of (row index - column index) for every pair of kept spans.
        # Shape: (num_kept, num_kept); broadcast op
        positions = util.get_range_vector(num_kept, device)
        pairwise_offsets = positions.unsqueeze(-1) - positions.unsqueeze(0)

        # `batched_index_select` indexes along dim 1 of a 3-D target, so every
        # (batch, span) row of the offset table is folded into the batch axis first.
        tiled_offsets = pairwise_offsets.unsqueeze(0).expand(batch_size, num_kept, num_kept)
        flat_offsets = tiled_offsets.reshape(batch_size * num_kept, num_kept, 1)
        flat_antecedent_indices = top_antecedent_indices.view(-1, max_antecedents)
        selected_offsets = util.batched_index_select(flat_offsets, flat_antecedent_indices)
        top_antecedent_offsets = selected_offsets.reshape(batch_size, num_kept, max_antecedents)

        return (
            top_partial_coreference_scores,
            top_antecedent_mask,
            top_antecedent_offsets,
            top_antecedent_indices,
        )
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        sentence_end: torch.LongTensor,
        spans: torch.LongTensor,
        span_labels: torch.LongTensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
    ):
        """
        Embed the sentence with BERT, prune candidate spans by mention score, classify each
        kept span conditioned on the predicate embedding and, when `tags` is given, compute
        a masked cross-entropy loss over the kept spans.

        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate (the predicate
            embedding is then replaced by zeros).
        sentence_end : `torch.LongTensor`, required.
            Accepted for interface compatibility; it is not read by this method.
        spans : `torch.LongTensor`, required.
            Candidate spans, indexed as `spans[:, :, 0]` for the start position, so a 3-D tensor
            (presumably (batch_size, num_spans, 2) with inclusive start/end indices -- confirm
            against the dataset reader). Padding spans are marked by a negative start index.
        span_labels : `torch.LongTensor`, required.
            Gold label id of every candidate span, shape (batch_size, num_spans).
        metadata : `List[Dict[str, Any]]`, required.
            Per-instance metadata. The 'words', 'verb' and 'offsets' keys are always read; the
            'verb_index' and 'gold_tags' keys are additionally read when the span metric is
            updated.
        tags : `torch.LongTensor`, optional (default = `None`)
            Gold tag sequence. Its values are not read here; it only switches on the loss
            (computed from `span_labels`) and the metric update.

        # Returns

        An output dictionary consisting of:
        mask : `torch.Tensor`
            The token mask derived from `tokens`.
        words : `List`
            The 'words' entry of each metadata dictionary.
        verb : `List`
            The 'verb' entry of each metadata dictionary.
        wordpiece_offsets : `List`
            The 'offsets' entry of each metadata dictionary.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised. Only present when `tags` is given.
        """
        mask = get_text_field_mask(tokens)
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            # token_type_ids=verb_indicator,
            attention_mask=mask,
        )

        num_spans = spans.shape[1]
        # Shape: (batch_size, num_spans)
        # Indexing with `[:, :, 0]` already yields a 2-D tensor. It must NOT be squeezed:
        # squeezing the last axis would drop the span dimension whenever num_spans == 1
        # and the mask would no longer line up with the mention scores below.
        span_mask = spans[:, :, 0] >= 0
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        sequence_length = embedded_text_input.size(1)

        if self._context_layer is not None:
            contextualized_embeddings = self._context_layer(embedded_text_input, mask)
            # Only the endpoint representation is used when a context layer is configured.
            # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
            span_embeddings = self._endpoint_span_extractor(contextualized_embeddings, spans)
        else:
            # The attentive representation is only needed on this path, so it is computed
            # here rather than unconditionally.
            # Shape: (batch_size, num_spans, embedding_size)
            span_embeddings = self._attentive_span_extractor(bert_embeddings, spans)

        # Prune based on mention scores.
        num_spans_to_keep = int(math.floor(self._spans_per_word * sequence_length))
        num_spans_to_keep = min(num_spans_to_keep, num_spans)

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)
        ).squeeze(-1)
        # Shape: (batch_size, num_spans_to_keep) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep
        )

        # Pick the embedding at the arg-max position of the verb indicator.
        # Shape: (batch_size, 1, embedding_size)
        verb_index = (
            verb_indicator.argmax(1)
            .unsqueeze(1)
            .unsqueeze(2)
            .repeat(1, 1, embedded_text_input.shape[-1])
        )
        verb_embeddings = torch.gather(embedded_text_input, 1, verb_index)
        assert len(verb_embeddings.shape) == 3 and verb_embeddings.shape[1] == 1
        # Shape: (batch_size, embedding_size)
        verb_embeddings = verb_embeddings.squeeze(1)
        # An all-zero indicator means "no predicate": argmax would then point at
        # position 0, so that embedding is replaced by zeros.
        has_verb = (verb_indicator.sum(1, keepdim=True) > 0).repeat(
            1, verb_embeddings.shape[-1]
        )
        verb_embeddings = torch.where(
            has_verb, verb_embeddings, torch.zeros_like(verb_embeddings)
        )

        # Gather everything that belongs to the spans that survived pruning.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(top_span_indices, num_spans)
        # Shape: (batch_size, num_spans_to_keep, span_embedding_size)
        span_embeddings = util.batched_index_select(
            span_embeddings, top_span_indices, flat_top_span_indices
        )
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices, flat_top_span_indices)
        # Shape: (batch_size, num_spans_to_keep)
        top_span_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices, flat_top_span_indices
        ).squeeze(-1)

        # Pair every kept span with the (broadcast) predicate embedding.
        concatenated_span_embeddings = torch.cat(
            (
                span_embeddings,
                verb_embeddings.unsqueeze(1).repeat(1, span_embeddings.shape[1], 1),
            ),
            dim=2,
        )
        hidden = self.hidden_layer(concatenated_span_embeddings)
        predictions = self.output_layer(hidden)
        # Prepend a constant zero logit as class index 0, so the output layer only has to
        # produce scores for the remaining classes.
        # NOTE(review): class 0 is presumably the null / "O" span label -- confirm against
        # the `span_labels` vocabulary namespace.
        predictions = torch.cat(
            (torch.zeros_like(predictions[:, :, :1]), predictions), dim=-1
        )

        output_dict = {}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when decoding in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(
            *[(x["words"], x["verb"], x["offsets"]) for x in metadata]
        )
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            # Masked mean of the per-span loss over the valid kept spans.
            # NOTE(review): this assumes `self._ce_loss` is unreduced (one value per
            # span) -- confirm where it is constructed.
            flat_span_mask = top_span_mask.float().view(-1)
            per_span_loss = self._ce_loss(
                predictions.view(-1, predictions.shape[-1]), top_span_labels.view(-1)
            )
            # The mask sum is a count, so clamping to >= 1 only changes the case where no
            # span is valid: the loss is then 0 instead of 0 / 0 = NaN.
            loss = (per_span_loss * flat_span_mask).sum() / flat_span_mask.sum().clamp(min=1.0)

            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"] for example_metadata in metadata
                ]
                batch_sentences = [example_metadata["words"] for example_metadata in metadata]
                # Decode BIO tags for the kept spans.
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.get_tags(
                    top_spans, predictions, mask.shape[1], top_span_mask, output_dict
                )
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format,
                )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(bio_tags)
                    for bio_tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"] for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(bio_tags)
                    for bio_tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
        return output_dict