Example #1
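# Assumed module-level imports for this test method (it belongs to a test
# class, so they are listed here as comments):
#     import numpy
#     import torch
#     from allennlp.modules.span_extractors import EndpointSpanExtractor
#     from allennlp.nn.util import batched_index_select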
    def test_correct_sequence_elements_are_embedded(self):
        sequence_tensor = torch.randn([2, 5, 7])
        # Concatenate start and end points together to form our representation.
        extractor = EndpointSpanExtractor(7, "x,y")

        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])
        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 2, 14]
        assert extractor.get_output_dim() == 14
        assert extractor.get_input_dim() == 7

        start_indices, end_indices = indices.split(1, -1)
        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        start_embeddings, end_embeddings = span_representations.split(7, -1)

        correct_start_embeddings = batched_index_select(
            sequence_tensor, start_indices.squeeze())
        correct_end_embeddings = batched_index_select(sequence_tensor,
                                                      end_indices.squeeze())
        numpy.testing.assert_array_equal(start_embeddings.data.numpy(),
                                         correct_start_embeddings.data.numpy())
        numpy.testing.assert_array_equal(end_embeddings.data.numpy(),
                                         correct_end_embeddings.data.numpy())
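
# A minimal sketch (not part of the original test): `combination` accepts other
# comma-separated terms, each contributing one input_dim-sized block, so
# "x,y,x*y,x-y" produces an output dimension of 4 * input_dim.
extractor = EndpointSpanExtractor(7, "x,y,x*y,x-y")
assert extractor.get_output_dim() == 28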
Example #2
class DyGIE(Model):
    """
    TODO(dwadden) document me.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``text`` ``TextField`` we get as input to the model.
    context_layer : ``Seq2SeqEncoder``
        This layer incorporates contextual information for each word in the document.
    feature_size: ``int``
        The embedding size for all the embedded features, such as distances or span widths.
    submodule_params: ``TODO(dwadden)``
        A nested dictionary specifying parameters to be passed on to initialize submodules.
    max_span_width: ``int``
        The maximum width of candidate spans.
    target_task: ``str``:
        The task used to make early stopping decisions.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    module_initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the individual modules.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    display_metrics: ``List[str]``. A list of the metrics that should be printed out during model
        training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 modules,  # TODO(dwadden) Add type.
                 feature_size: int,
                 max_span_width: int,
                 target_task: str,
                 feedforward_params: Dict[str, Union[int, float]],
                 loss_weights: Dict[str, float],
                 initializer: InitializerApplicator = InitializerApplicator(),
                 module_initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 display_metrics: List[str] = None) -> None:
        super(DyGIE, self).__init__(vocab, regularizer)

        ####################

        # Create span extractor.
        self._endpoint_span_extractor = EndpointSpanExtractor(
            embedder.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)

        ####################

        # Set parameters.
        self._embedder = embedder
        self._loss_weights = loss_weights
        self._max_span_width = max_span_width
        self._display_metrics = self._get_display_metrics(target_task)
        token_emb_dim = self._embedder.get_output_dim()
        span_emb_dim = self._endpoint_span_extractor.get_output_dim()

        ####################

        # Create submodules.

        modules = Params(modules)

        # Helper function to create feedforward networks.
        def make_feedforward(input_dim):
            return FeedForward(input_dim=input_dim,
                               num_layers=feedforward_params["num_layers"],
                               hidden_dims=feedforward_params["hidden_dims"],
                               activations=torch.nn.ReLU(),
                               dropout=feedforward_params["dropout"])

        # Submodules

        self._ner = NERTagger.from_params(vocab=vocab,
                                          make_feedforward=make_feedforward,
                                          span_emb_dim=span_emb_dim,
                                          feature_size=feature_size,
                                          params=modules.pop("ner"))

        self._coref = CorefResolver.from_params(vocab=vocab,
                                                make_feedforward=make_feedforward,
                                                span_emb_dim=span_emb_dim,
                                                feature_size=feature_size,
                                                params=modules.pop("coref"))

        self._relation = RelationExtractor.from_params(vocab=vocab,
                                                       make_feedforward=make_feedforward,
                                                       span_emb_dim=span_emb_dim,
                                                       feature_size=feature_size,
                                                       params=modules.pop("relation"))

        self._events = EventExtractor.from_params(vocab=vocab,
                                                  make_feedforward=make_feedforward,
                                                  token_emb_dim=token_emb_dim,
                                                  span_emb_dim=span_emb_dim,
                                                  feature_size=feature_size,
                                                  params=modules.pop("events"))

        ####################

        # Initialize text embedder and all submodules
        for module in [self._ner, self._coref, self._relation, self._events]:
            module_initializer(module)

        initializer(self)

    @staticmethod
    def _get_display_metrics(target_task):
        """
        The `target` is the name of the task used to make early stopping decisions. Show metrics
        related to this task.
        """
        lookup = {
            "ner": [f"MEAN__{name}" for name in
                    ["ner_precision", "ner_recall", "ner_f1"]],
            "relation": [f"MEAN__{name}" for name in
                         ["relation_precision", "relation_recall", "relation_f1"]],
            "coref": ["coref_precision", "coref_recall", "coref_f1", "coref_mention_recall"],
            "events": [f"MEAN__{name}" for name in
                       ["trig_class_f1", "arg_class_f1"]]}
        if target_task not in lookup:
            raise ValueError(f"Invalied value {target_task} has been given as the target task.")
        return lookup[target_task]

    @staticmethod
    def _debatch(x):
        # TODO(dwadden) Get rid of this when I find a better way to do it.
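        # e.g. a (1, n_sents, max_n_spans) tensor becomes (n_sents, max_n_spans);
        # `None` inputs (absent label tensors) pass through unchanged.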
        return x if x is None else x.squeeze(0)

    @overrides
    def forward(self,
                text,
                spans,
                metadata,
                ner_labels=None,
                coref_labels=None,
                relation_labels=None,
                trigger_labels=None,
                argument_labels=None):
        """
        TODO(dwadden) change this.
        """
        # In AllenNLP, AdjacencyFields are passed in as floats. This fixes it.
        if relation_labels is not None:
            relation_labels = relation_labels.long()
        if argument_labels is not None:
            argument_labels = argument_labels.long()

        # TODO(dwadden) Multi-document minibatching isn't supported yet. For now, get rid of the
        # extra dimension in the input tensors. Will return to this once the model runs.
        if len(metadata) > 1:
            raise NotImplementedError("Multi-document minibatching not supported.")

        metadata = metadata[0]
        spans = self._debatch(spans)  # (n_sents, max_n_spans, 2)
        ner_labels = self._debatch(ner_labels)  # (n_sents, max_n_spans)
        coref_labels = self._debatch(coref_labels)  #  (n_sents, max_n_spans)
        relation_labels = self._debatch(relation_labels)  # (n_sents, max_n_spans, max_n_spans)
        trigger_labels = self._debatch(trigger_labels)  # TODO(dwadden)
        argument_labels = self._debatch(argument_labels)  # TODO(dwadden)

        # Encode using BERT, then debatch.
        # Since the data are batched, we use `num_wrapping_dims=1` to unwrap the document dimension.
        # (1, n_sents, max_sentence_length, embedding_dim)

        # TODO(dwadden) Deal with the case where the input is longer than 512.
        text_embeddings = self._embedder(text, num_wrapping_dims=1)
        # (n_sents, max_n_wordpieces, embedding_dim)
        text_embeddings = self._debatch(text_embeddings)

        # (n_sents, max_sentence_length)
        text_mask = self._debatch(util.get_text_field_mask(text, num_wrapping_dims=1).float())
        sentence_lengths = text_mask.sum(dim=1).long()  # (n_sents)

        span_mask = (spans[:, :, 0] >= 0).float()  # (n_sents, max_n_spans)
        # SpanFields return -1 when they are used as padding. As we do some comparisons based on
        # span widths when we attend over the span representations that we generate from these
        # indices, we need them to be >= 0. This is only relevant in edge cases where the number of
        # spans we consider after the pruning stage is >= the total number of spans, because in this
        # case, it is possible we might consider a masked span.
        spans = F.relu(spans.float()).long()  # (n_sents, max_n_spans, 2)

        # Shape: (n_sents, max_n_spans, 2 * embedding_dim + feature_size)
        span_embeddings = self._endpoint_span_extractor(text_embeddings, spans)

        # Make calls out to the modules to get results.
        output_coref = {'loss': 0}
        output_ner = {'loss': 0}
        output_relation = {'loss': 0}
        output_events = {'loss': 0}

        # Prune and compute span representations for coreference module
        if self._loss_weights["coref"] > 0 or self._coref.coref_prop > 0:
            output_coref, coref_indices = self._coref.compute_representations(
                spans, span_mask, span_embeddings, sentence_lengths, coref_labels, metadata)

        # Propagation of global information to enhance the span embeddings
        if self._coref.coref_prop > 0:
            output_coref = self._coref.coref_propagation(output_coref)
            span_embeddings = self._coref.update_spans(
                output_coref, span_embeddings, coref_indices)

        # Make predictions and compute losses for each module
        if self._loss_weights['ner'] > 0:
            output_ner = self._ner(
                spans, span_mask, span_embeddings, sentence_lengths, ner_labels, metadata)

        if self._loss_weights['coref'] > 0:
            output_coref = self._coref.predict_labels(output_coref, metadata)

        if self._loss_weights['relation'] > 0:
            output_relation = self._relation(
                spans, span_mask, span_embeddings, sentence_lengths, relation_labels, metadata)

        if self._loss_weights['events'] > 0:
            # The `text_embeddings` serve as representations for event triggers.
            output_events = self._events(
                text_mask, text_embeddings, spans, span_mask, span_embeddings,
                sentence_lengths, trigger_labels, argument_labels,
                ner_labels, metadata)

        # Use `get` since there are some cases where the output dict won't have a loss - for
        # instance, when doing prediction.
        loss = (self._loss_weights['coref'] * output_coref.get("loss", 0) +
                self._loss_weights['ner'] * output_ner.get("loss", 0) +
                self._loss_weights['relation'] * output_relation.get("loss", 0) +
                self._loss_weights['events'] * output_events.get("loss", 0))

        # Multiply the loss by the weight multiplier for this document.
        weight = metadata.weight if metadata.weight is not None else 1.0
        loss *= torch.tensor(weight)

        output_dict = dict(coref=output_coref,
                           relation=output_relation,
                           ner=output_ner,
                           events=output_events)
        output_dict['loss'] = loss

        output_dict["metadata"] = metadata

        return output_dict

    def update_span_embeddings(self, span_embeddings, span_mask, top_span_embeddings,
                               top_span_mask, top_span_indices):
        # TODO(Ulme) Speed this up by tensorizing

        new_span_embeddings = span_embeddings.clone()
        for sample_nr in range(len(top_span_mask)):
            for top_span_nr, span_nr in enumerate(top_span_indices[sample_nr]):
                if top_span_mask[sample_nr, top_span_nr] == 0 or span_mask[sample_nr, span_nr] == 0:
                    break
                new_span_embeddings[sample_nr,
                                    span_nr] = top_span_embeddings[sample_nr, top_span_nr]
        return new_span_embeddings
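
    # A hedged sketch of the tensorized version requested in the TODO above;
    # it assumes 2-D masks of 0/1 values and, unlike the loop, does not rely
    # on masked entries coming last.
    def update_span_embeddings_vectorized(self, span_embeddings, span_mask,
                                          top_span_embeddings, top_span_mask,
                                          top_span_indices):
        # Keep a top span only if both it and the underlying span are unmasked.
        keep = (top_span_mask.bool() &
                span_mask.gather(1, top_span_indices).bool())
        # (batch, n_top, emb_dim) indices into the span dimension.
        index = top_span_indices.unsqueeze(-1).expand_as(top_span_embeddings)
        # Where masked, scatter the original embedding back (a no-op).
        updates = torch.where(keep.unsqueeze(-1), top_span_embeddings,
                              span_embeddings.gather(1, index))
        new_span_embeddings = span_embeddings.clone()
        new_span_embeddings.scatter_(1, index, updates)
        return new_span_embeddings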

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]):
        """
        Converts the list of spans and predicted antecedent indices into clusters
        of spans for each element in the batch.

        Parameters
        ----------
        output_dict : ``Dict[str, torch.Tensor]``, required.
            The result of calling :func:`forward` on an instance or batch of instances.

        Returns
        -------
        The same output dictionary, but with an additional ``clusters`` key:

        clusters : ``List[List[List[Tuple[int, int]]]]``
            A nested list, representing, for each instance in the batch, the list of clusters,
            which are in turn comprised of a list of (start, end) inclusive spans into the
            original document.
        """

        doc = copy.deepcopy(output_dict["metadata"])

        if self._loss_weights["coref"] > 0:
            # TODO(dwadden) Will need to get rid of the [0] when batch training is enabled.
            decoded_coref = self._coref.make_output_human_readable(output_dict["coref"])["predicted_clusters"][0]
            sentences = doc.sentences
            sentence_starts = [sent.sentence_start for sent in sentences]
            predicted_clusters = [document.Cluster(entry, i, sentences, sentence_starts)
                                  for i, entry in enumerate(decoded_coref)]
            doc.predicted_clusters = predicted_clusters
            # TODO(dwadden) update the sentences with cluster information.

        if self._loss_weights["ner"] > 0:
            for predictions, sentence in zip(output_dict["ner"]["predictions"], doc):
                sentence.predicted_ner = predictions

        if self._loss_weights["relation"] > 0:
            for predictions, sentence in zip(output_dict["relation"]["predictions"], doc):
                sentence.predicted_relations = predictions

        if self._loss_weights["events"] > 0:
            for predictions, sentence in zip(output_dict["events"]["predictions"], doc):
                sentence.predicted_events = predictions

        return doc

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Get all metrics from all modules. For the ones that shouldn't be displayed, prefix their
        keys with an underscore.
        """
        metrics_coref = self._coref.get_metrics(reset=reset)
        metrics_ner = self._ner.get_metrics(reset=reset)
        metrics_relation = self._relation.get_metrics(reset=reset)
        metrics_events = self._events.get_metrics(reset=reset)

        # Make sure that there aren't any conflicting names.
        metric_names = (list(metrics_coref.keys()) + list(metrics_ner.keys()) +
                        list(metrics_relation.keys()) + list(metrics_events.keys()))
        assert len(set(metric_names)) == len(metric_names)
        all_metrics = dict(list(metrics_coref.items()) +
                           list(metrics_ner.items()) +
                           list(metrics_relation.items()) +
                           list(metrics_events.items()))

        # If no list of desired metrics given, display them all.
        if self._display_metrics is None:
            return all_metrics
        # Otherwise only display the selected ones.
        res = {}
        for k, v in all_metrics.items():
            if k in self._display_metrics:
                res[k] = v
            else:
                new_k = "_" + k
                res[new_k] = v
        return res
Example #3
class TweetJointly(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model_name: str = "bert-base-uncased",
        feedforward: Optional[FeedForward] = None,
        smoothing: bool = False,
        smooth_alpha: float = 0.7,
        sentiment_task: bool = False,
        sentiment_task_weight: float = 1.0,
        sentiment_classification_with_label: bool = True,
        sentiment_seq2vec: Optional[Seq2VecEncoder] = None,
        candidate_span_task: bool = False,
        candidate_span_task_weight: float = 1.0,
        candidate_delay: int = 30000,
        candidate_span_num: int = 5,
        candidate_classification_layer_units: int = 128,
        candidate_span_extractor: Optional[SpanExtractor] = None,
        candidate_span_with_logits: bool = False,
        dropout: Optional[float] = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        if "BERTweet" not in transformer_model_name:
            self._text_field_embedder = BasicTextFieldEmbedder({
                "tokens":
                PretrainedTransformerEmbedder(transformer_model_name)
            })
        else:
            self._text_field_embedder = BasicTextFieldEmbedder(
                {"tokens": TweetBertEmbedder(transformer_model_name)})
        # span start & end task
        if feedforward is None:
            self._linear_layer = nn.Sequential(
                nn.Linear(self._text_field_embedder.get_output_dim(), 128),
                nn.ReLU(),
                nn.Linear(128, 2),
            )
        else:
            self._linear_layer = feedforward
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._jaccard = Jaccard()
        self._candidate_delay = candidate_delay
        self._delay = 0

        self._smoothing = smoothing
        self._smooth_alpha = smooth_alpha
        if smoothing:
            self._loss = nn.KLDivLoss(reduction="batchmean")
        else:
            self._loss = nn.CrossEntropyLoss()

        # sentiment task
        self._sentiment_task = sentiment_task
        if self._sentiment_task:
            self._sentiment_classification_accuracy = CategoricalAccuracy()
            self._sentiment_loss_log = LossLog()
            self.register_buffer("sentiment_task_weight",
                                 torch.tensor(sentiment_task_weight))
            self._sentiment_classification_with_label = (
                sentiment_classification_with_label)
            if sentiment_seq2vec is None:
                raise ConfigurationError(
                    "sentiment_task is True, so a sentiment seq2vec encoder must be provided"
                )
            else:
                self._sentiment_encoder = sentiment_seq2vec
                self._sentiment_linear = nn.Linear(
                    self._sentiment_encoder.get_output_dim(),
                    vocab.get_vocab_size("labels"),
                )

        # candidate span task
        self._candidate_span_task = candidate_span_task
        if candidate_span_task:
            assert candidate_span_num > 0
            assert candidate_span_task_weight > 0
            assert candidate_classification_layer_units > 0
            self._candidate_span_num = candidate_span_num
            self.register_buffer("candidate_span_task_weight",
                                 torch.tensor(candidate_span_task_weight))
            self._candidate_classification_layer_units = (
                candidate_classification_layer_units)
            self._span_classification_accuracy = CategoricalAccuracy()
            self._candidate_loss_log = LossLog()
            self._candidate_span_linear = nn.Linear(
                self._text_field_embedder.get_output_dim(),
                self._candidate_classification_layer_units,
            )

            if candidate_span_extractor is None:
                self._candidate_span_extractor = EndpointSpanExtractor(
                    input_dim=self._candidate_classification_layer_units)
            else:
                self._candidate_span_extractor = candidate_span_extractor

            if candidate_span_with_logits:
                self._candidate_with_logits = True
                self._candidate_span_vec_linear = nn.Linear(
                    self._candidate_span_extractor.get_output_dim() + 1, 1)
            else:
                self._candidate_with_logits = False
                self._candidate_span_vec_linear = nn.Linear(
                    self._candidate_span_extractor.get_output_dim(), 1)

            self._candidate_jaccard = Jaccard()

        if sentiment_task or candidate_span_task:
            self._base_loss_log = LossLog()
        else:
            self._base_loss_log = None

        if dropout is not None:
            self._dropout = nn.Dropout(dropout)
        else:
            self._dropout = None

    def forward(  # type: ignore
        self,
        text: Dict[str, Dict[str, torch.LongTensor]],
        sentiment: torch.IntTensor,
        text_with_sentiment: Dict[str, Dict[str, torch.LongTensor]],
        text_span: torch.IntTensor,
        selected_text_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        # batch_size * text_length * hidden_dims
        embedded_question = self._text_field_embedder(text_with_sentiment)
        if self._dropout is not None:
            embedded_question = self._dropout(embedded_question)
        self._delay += int(embedded_question.size(0))
        # span start & span end task
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)

        possible_answer_mask = torch.zeros_like(
            util.get_token_ids_from_text_field_tensors(
                text_with_sentiment)).bool()
        for i, (start, end) in enumerate(text_span):
            possible_answer_mask[i, start:end + 1] = True

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       possible_answer_mask,
                                                       -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     possible_answer_mask,
                                                     -1e32)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }

        loss = torch.tensor(0.0).to(embedded_question.device)
        # sentiment task
        if self._sentiment_task:
            if self._sentiment_classification_with_label:
                global_context_vec = self._sentiment_encoder(embedded_question)
            else:
                embedded_only_text = self._text_field_embedder(text)
                if self._dropout is not None:
                    embedded_only_text = self._dropout(embedded_only_text)
                global_context_vec = self._sentiment_encoder(
                    embedded_only_text)
            sentiment_logits = self._sentiment_linear(global_context_vec)
            sentiment_probs = torch.softmax(sentiment_logits, dim=-1)

            self._sentiment_classification_accuracy(sentiment_probs, sentiment)
            sentiment_loss = cross_entropy(sentiment_logits, sentiment)
            self._sentiment_loss_log(sentiment_loss)
            loss.add_(self.sentiment_task_weight * sentiment_loss)

            predict_sentiment_idx = sentiment_probs.argmax(dim=-1)
            sentiment_predicts = []
            for i in predict_sentiment_idx.tolist():
                sentiment_predicts.append(
                    self.vocab.get_token_from_index(i, "labels"))
            output_dict["sentiment_logits"] = sentiment_logits
            output_dict["sentiment_probs"] = sentiment_probs
            output_dict["sentiment_predicts"] = sentiment_predicts

        # span classification
        if self._candidate_span_task and (self._delay >=
                                          self._candidate_delay):
            # shape: (batch_size, passage_length, embedding_dim)
            text_features_for_candidate = self._candidate_span_linear(
                embedded_question)
            text_features_for_candidate = torch.relu(
                text_features_for_candidate)
            with torch.no_grad():
                # batch_size * candidate_num * 2
                candidate_span = get_candidate_span(span_start_probs,
                                                    span_end_probs,
                                                    self._candidate_span_num)
                candidate_span_list = candidate_span.tolist()
                output_dict["candidate_spans"] = candidate_span_list
            if selected_text_span is not None:
                candidate_span, candidate_span_label = self.candidate_span_with_labels(
                    candidate_span, selected_text_span)
            else:
                candidate_span_label = None
            # shape: (batch_size, candidate_num, span_extractor_output_dim)
            span_feature_vec = self._candidate_span_extractor(
                text_features_for_candidate, candidate_span)

            if self._candidate_with_logits:
                candidate_span_start_logits = torch.gather(
                    span_start_logits, 1, candidate_span[:, :, 0])
                candidate_span_end_logits = torch.gather(
                    span_end_logits, 1, candidate_span[:, :, 1])
                candidate_span_sum_logits = (candidate_span_start_logits +
                                             candidate_span_end_logits)
                span_feature_vec = torch.cat(
                    (span_feature_vec, candidate_span_sum_logits.unsqueeze(2)),
                    -1)
            # batch_size * candidate_num
            # Use squeeze(-1) rather than bare squeeze() so a batch of size 1
            # keeps its batch dimension.
            span_classification_logits = self._candidate_span_vec_linear(
                span_feature_vec).squeeze(-1)
            span_classification_probs = torch.softmax(
                span_classification_logits, -1)
            output_dict["span_classification_probs"] = span_classification_probs
            candidate_best_span_idx = span_classification_probs.argmax(dim=-1)
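            # Index arithmetic below: after view(-1, 2), row b's k-th candidate
            # sits at flat index b * candidate_span_num + k, so adding a batch
            # offset to each argmax selects each example's best span.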
            view_idx = (
                candidate_best_span_idx +
                torch.arange(0, end=candidate_best_span_idx.shape[0]).to(
                    candidate_best_span_idx.device) * self._candidate_span_num)
            candidate_span_view = candidate_span.view(-1, 2)
            candidate_best_spans = candidate_span_view.index_select(
                0, view_idx)
            output_dict["candidate_best_spans"] = candidate_best_spans.tolist()

            if selected_text_span is not None:
                self._span_classification_accuracy(span_classification_probs,
                                                   candidate_span_label)
                candidate_span_loss = cross_entropy(span_classification_logits,
                                                    candidate_span_label)
                self._candidate_loss_log(candidate_span_loss)
                weighted_loss = self.candidate_span_task_weight * candidate_span_loss
                if candidate_span_loss > 1e2:
                    print(f"candidate loss: {candidate_span_loss}")
                    print(
                        f"span_classification_logits: {span_classification_logits}"
                    )
                    print(f"candidate_span_label: {candidate_span_label}")
                loss.add_(weighted_loss)

            candidate_best_spans = candidate_best_spans.detach().cpu().numpy()
            output_dict["best_candidate_span_str"] = []
            for metadata_entry, best_span in zip(metadata,
                                                 candidate_best_spans):
                text_with_sentiment_tokens = metadata_entry[
                    "text_with_sentiment_tokens"]
                predicted_start, predicted_end = tuple(best_span)
                if predicted_end >= len(text_with_sentiment_tokens):
                    predicted_end = len(text_with_sentiment_tokens) - 1
                best_span_string = self.span_tokens_to_text(
                    metadata_entry["text"],
                    text_with_sentiment_tokens,
                    predicted_start,
                    predicted_end,
                )
                output_dict["best_candidate_span_str"].append(best_span_string)
                answers = metadata_entry.get("selected_text", "")
                if len(answers) > 0:
                    self._candidate_jaccard(best_span_string, answers)

        # Compute the loss for training.
        if selected_text_span is not None:
            span_start = selected_text_span[:, 0]
            span_end = selected_text_span[:, 1]
            span_mask = span_start != -1
            self._span_accuracy(
                best_spans,
                selected_text_span,
                span_mask.unsqueeze(-1).expand_as(best_spans),
            )
            if not self._smoothing:
                start_loss = cross_entropy(span_start_logits,
                                           span_start,
                                           ignore_index=-1)
                if torch.any(start_loss > 1e9):
                    logger.critical("Start loss too high (%r)", start_loss)
                    logger.critical("span_start_logits: %r", span_start_logits)
                    logger.critical("span_start: %r", span_start)
                    logger.critical("text_with_sentiment: %r",
                                    text_with_sentiment)
                    assert False

                end_loss = cross_entropy(span_end_logits,
                                         span_end,
                                         ignore_index=-1)
                if torch.any(end_loss > 1e9):
                    logger.critical("End loss too high (%r)", end_loss)
                    logger.critical("span_end_logits: %r", span_end_logits)
                    logger.critical("span_end: %r", span_end)
                    assert False
            else:
                sequence_length = span_start_logits.size(1)
                device = span_start.device
                start_distance = get_sequence_distance_from_span_endpoint(
                    sequence_length, span_start)
                start_smooth_probs = torch.exp(
                    start_distance *
                    torch.log(torch.tensor(self._smooth_alpha).to(device)))
                start_smooth_probs = start_smooth_probs * possible_answer_mask
                start_smooth_probs = start_smooth_probs / start_smooth_probs.sum(
                    -1, keepdim=True)
                # Manual log-softmax: log p = logits - log(sum(exp(logits))).
                span_start_log_probs = span_start_logits - torch.log(
                    torch.exp(span_start_logits).sum(-1)).unsqueeze(-1)
                end_distance = get_sequence_distance_from_span_endpoint(
                    sequence_length, span_end)
                end_smooth_probs = torch.exp(
                    end_distance *
                    torch.log(torch.tensor(self._smooth_alpha).to(device)))
                end_smooth_probs = end_smooth_probs * possible_answer_mask
                end_smooth_probs = end_smooth_probs / end_smooth_probs.sum(
                    -1, keepdim=True)
                span_end_log_probs = span_end_logits - torch.log(
                    torch.exp(span_end_logits).sum(-1)).unsqueeze(-1)
                start_loss = self._loss(span_start_log_probs,
                                        start_smooth_probs)
                end_loss = self._loss(span_end_log_probs, end_smooth_probs)
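                # Toy illustration of the smoothing targets (assuming
                # get_sequence_distance_from_span_endpoint returns each
                # position's distance to the endpoint): with smooth_alpha = 0.7,
                # sequence_length = 5 and span_start = 2, the unnormalized
                # weights are 0.7 ** [2, 1, 0, 1, 2] = [0.49, 0.7, 1.0, 0.7, 0.49],
                # masked and renormalized before the KLDivLoss.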

            span_start_end_loss = (start_loss + end_loss) / 2
            if self._base_loss_log is not None:
                self._base_loss_log(span_start_end_loss)
            loss.add_(span_start_end_loss)
            self._span_start_accuracy(span_start_logits, span_start, span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)

            output_dict["loss"] = loss

        # compute best span jaccard
        best_spans = best_spans.detach().cpu().numpy()
        output_dict["best_span_str"] = []

        for metadata_entry, best_span in zip(metadata, best_spans):
            text_with_sentiment_tokens = metadata_entry[
                "text_with_sentiment_tokens"]

            predicted_start, predicted_end = tuple(best_span)
            best_span_string = self.span_tokens_to_text(
                metadata_entry["text"],
                text_with_sentiment_tokens,
                predicted_start,
                predicted_end,
            )
            output_dict["best_span_str"].append(best_span_string)

            answers = metadata_entry.get("selected_text", "")
            if len(answers) > 0:
                self._jaccard(best_span_string, answers)

        return output_dict


    @staticmethod
    def candidate_span_with_labels(
            candidate_span: torch.Tensor, selected_text_span: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        candidate_span_label = batch_span_jaccard(
            candidate_span, selected_text_span).max(-1).indices
        return candidate_span, candidate_span_label

    @staticmethod
    def get_candidate_span_mask(candidate_span: torch.Tensor,
                                passage_length: int) -> torch.Tensor:
        device = candidate_span.device
        batch_size, candidate_num = candidate_span.size()[:-1]
        candidate_span_mask = torch.zeros(batch_size, candidate_num,
                                          passage_length).to(device)
        for i in range(batch_size):
            for j in range(candidate_num):
                span_start, span_end = candidate_span[i][j]
                candidate_span_mask[i][j][span_start:span_end + 1] = 1
        return candidate_span_mask
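
    # A vectorized sketch of the same mask construction (assumes inclusive
    # span ends, as in the loop above):
    #     positions = torch.arange(passage_length, device=device)  # (length,)
    #     starts = candidate_span[..., 0].unsqueeze(-1)            # (B, K, 1)
    #     ends = candidate_span[..., 1].unsqueeze(-1)              # (B, K, 1)
    #     candidate_span_mask = ((positions >= starts)
    #                            & (positions <= ends)).float()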

    @staticmethod
    def span_tokens_to_text(source_text, tokens, span_start, span_end):
        text_with_sentiment_tokens = tokens
        predicted_start = span_start
        predicted_end = span_end
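
        # `Token.idx` holds the token's character offset in the original text;
        # special tokens and some wordpieces have `idx is None`, so we walk to
        # the nearest token that maps back to the source.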

        while (predicted_start >= 0
               and text_with_sentiment_tokens[predicted_start].idx is None):
            predicted_start -= 1
        if predicted_start < 0:
            logger.warning(
                f"Could not map the token '{text_with_sentiment_tokens[span_start].text}' at index "
                f"'{span_start}' to an offset in the original text.")
            character_start = 0
        else:
            character_start = text_with_sentiment_tokens[predicted_start].idx

        while (predicted_end < len(text_with_sentiment_tokens)
               and text_with_sentiment_tokens[predicted_end].idx is None):
            predicted_end -= 1

        if predicted_end >= len(text_with_sentiment_tokens):
            print(text_with_sentiment_tokens)
            print(len(text_with_sentiment_tokens))
            print(span_end)
            print(predicted_end)
            logger.warning(
                f"Could not map the token '{text_with_sentiment_tokens[span_end].text}' at index "
                f"'{span_end}' to an offset in the original text.")
            character_end = len(source_text)
        else:
            end_token = text_with_sentiment_tokens[predicted_end]
            if end_token.idx == 0:
                character_end = (end_token.idx +
                                 len(sanitize_wordpiece(end_token.text)) + 1)
            else:
                character_end = end_token.idx + len(
                    sanitize_wordpiece(end_token.text))

        best_span_string = source_text[character_start:character_end].strip()
        return best_span_string

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        jaccard = self._jaccard.get_metric(reset)
        metrics = {
            "start_acc": self._span_start_accuracy.get_metric(reset),
            "end_acc": self._span_end_accuracy.get_metric(reset),
            "span_acc": self._span_accuracy.get_metric(reset),
            "jaccard": jaccard,
        }
        if self._candidate_span_task:
            metrics["candidate_span_acc"] = (
                self._span_classification_accuracy.get_metric(reset))
            metrics["candidate_jaccard"] = self._candidate_jaccard.get_metric(reset)
            metrics["candidate_loss"] = self._candidate_loss_log.get_metric(reset)
        if self._sentiment_task:
            metrics["sentiment_acc"] = (
                self._sentiment_classification_accuracy.get_metric(reset))
            metrics["sentiment_loss"] = self._sentiment_loss_log.get_metric(reset)
        if self._base_loss_log is not None:
            metrics["base_loss"] = self._base_loss_log.get_metric(reset)
        return metrics
Example #4
class SCIIE(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``text`` ``TextField`` we get as input to the model.
    context_layer : ``Seq2SeqEncoder``
        This layer incorporates contextual information for each word in the document.
    mention_feedforward : ``FeedForward``
        This feedforward network is applied to the span representations which is then scored
        by a linear layer.
    antecedent_feedforward: ``FeedForward``
        This feedforward network is applied to pairs of span representation, along with any
        pairwise features, which is then scored by a linear layer.
    feature_size: ``int``
        The embedding size for all the embedded features, such as distances or span widths.
    max_span_width: ``int``
        The maximum width of candidate spans.
    spans_per_word: float, required.
        A multiplier between zero and one which controls what percentage of candidate mention
        spans we retain with respect to the number of words in the document.
    max_antecedents: int, required.
        For each mention which survives the pruning stage, we consider this many antecedents.
    lexical_dropout: ``int``
        The probability of dropping out dimensions of the embedded text.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 embedding_dim: int,
                 feature_size: int,
                 max_span_width: int,
                 spans_per_word: float,
                 lexical_dropout: float = 0.2,
                 mlp_dropout: float = 0.4,
                 embedder_type=None,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(SCIIE, self).__init__(vocab, regularizer)
        self.class_num = self.vocab.get_vocab_size('labels')
        word_embeddings = get_embeddings(embedder_type, self.vocab,
                                         embedding_dim, True)
        embedding_dim = word_embeddings.get_output_dim()
        self._text_field_embedder = word_embeddings

        context_layer = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(embedding_dim,
                          feature_size,
                          batch_first=True,
                          bidirectional=True))
        self._context_layer = context_layer

        endpoint_span_extractor_input_dim = context_layer.get_output_dim()
        attentive_span_extractor_input_dim = word_embeddings.get_output_dim()

        self._endpoint_span_extractor = EndpointSpanExtractor(
            endpoint_span_extractor_input_dim,
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=attentive_span_extractor_input_dim)

        # self._span_extractor = PoolingSpanExtractor(embedding_dim,
        #                                             num_width_embeddings=max_span_width,
        #                                             span_width_embedding_dim=feature_size,
        #                                             bucket_widths=False)

        entity_feedforward = FeedForward(
            self._endpoint_span_extractor.get_output_dim() +
            self._attentive_span_extractor.get_output_dim(), 2, 150, F.relu,
            mlp_dropout)
        # entity_feedforward = FeedForward(self._span_extractor.get_output_dim(), 2, 150,
        #                                  F.relu, mlp_dropout)

        feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(), 1)))
        self._mention_pruner = Pruner(feedforward_scorer)

        self._entity_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(),
                                self.class_num - 1)))

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x

        self._metric_all = FBetaMeasure()
        self._metric_avg = NERF1Metric()

    @overrides
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)
        # span_embeddings = self._span_extractor(text_embeddings, spans, span_indices_mask=span_mask)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, span_embeddings.shape[1])
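        # e.g. spans_per_word = 0.4 on a 100-token document keeps
        # floor(0.4 * 100) = 40 spans, capped at the number of enumerated spans.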

        # Shape:    (batch_size, num_spans_to_keep, embedding_size + 2 * encoding_dim + feature_size)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Shape: (batch_size, num_spans_to_keep, class_num)
        ne_scores = self._compute_named_entity_scores(top_span_embeddings)

        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_named_entities = ne_scores.max(2)

        output_dict = {
            "top_spans": top_spans,
            "predicted_named_entities": predicted_named_entities
        }
        if labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices).squeeze(-1)
            negative_log_likelihood = F.cross_entropy(
                ne_scores.reshape(-1, self.class_num),
                pruned_gold_labels.reshape(-1))

            pruner_loss = F.binary_cross_entropy_with_logits(
                top_span_mention_scores.reshape(-1),
                (pruned_gold_labels.reshape(-1) != 0).float())
            loss = negative_log_likelihood + pruner_loss
            output_dict["loss"] = loss
            output_dict["pruner_loss"] = pruner_loss
            batch_size, _ = labels.shape
            all_scores = ne_scores.new_zeros(
                [batch_size * num_spans, self.class_num])
            all_scores[:, 0] = 1
            all_scores[flat_top_span_indices] = ne_scores.reshape(
                -1, self.class_num)
            all_scores = all_scores.reshape(
                [batch_size, num_spans, self.class_num])
            self._metric_all(all_scores, labels)
            self._metric_avg(all_scores, labels)
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False, prefix=""):
        metric = self._metric_all.get_metric(reset)
        metric2 = self._metric_avg.get_metric(reset)
        metric.update(metric2)
        return metric

    def _compute_named_entity_scores(
            self, span_embeddings: torch.FloatTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        span_embeddings: ``torch.FloatTensor``, required.
            Embedding representations of spans. Has shape
            (batch_size, num_spans_to_keep, encoding_dim)
        """
        # Shape: (batch_size, num_spans_to_keep, class_num - 1)
        scores = self._entity_scorer(span_embeddings)
        # Shape: (batch_size, num_spans_to_keep, 1)
        shape = [scores.size(0), scores.size(1), 1]
        dummy_scores = scores.new_full(shape, 0)
        ne_scores = torch.cat([dummy_scores, scores], -1)
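        # Illustration with hypothetical numbers: if the scorer emits
        # [2.3, -1.0] for the real classes, prepending the fixed dummy 0 gives
        # [0.0, 2.3, -1.0]; cross-entropy then treats the null label (index 0)
        # as a class whose logit is pinned at zero, so a span is predicted as
        # an entity only when a real class outscores that baseline.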
        return ne_scores
Example #5
class SrlE2e(Model):
    """

    # Parameters

    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    model : `Union[str, BertModel]`, required.
        A string describing the BERT model to load or an already constructed BertModel.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    label_smoothing : `float`, optional (default = `0.0`)
        Whether or not to use label smoothing on the labels when computing cross entropy loss.
    ignore_span_metric : `bool`, optional (default = `False`)
        Whether to calculate span loss, which is irrelevant when predicting BIO for Open Information Extraction.
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
        which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not used.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        mention_feedforward: FeedForward,
        context_layer: Seq2SeqEncoder = None,
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        max_span_width: int = 30,
        feature_size: int = 10,
        spans_per_word: float = 100,
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("span_labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric

        self._mention_feedforward = TimeDistributed(mention_feedforward)
        self._mention_scorer = TimeDistributed(
            torch.nn.Linear(mention_feedforward.get_output_dim(), 1))

        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=self.bert_model.config.hidden_size)
        self.span_representation_dim = self._attentive_span_extractor.get_output_dim()
        self._context_layer = context_layer
        if context_layer is not None:
            self._endpoint_span_extractor = EndpointSpanExtractor(
                context_layer.get_output_dim(),
                combination="x,y",
                num_width_embeddings=max_span_width,
                span_width_embedding_dim=feature_size,
                bucket_widths=False,
            )
            self.span_representation_dim = self._endpoint_span_extractor.get_output_dim()
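            # Note: when a context layer is supplied, the endpoint extractor's
            # output dimension replaces (rather than augments) the attentive
            # extractor's, matching forward(), which uses only one of the two.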

        self.hidden_layer = torch.nn.Sequential(
            torch.nn.Linear(self.span_representation_dim +
                            self.bert_model.config.hidden_size,
                            self.span_representation_dim,
                            bias=False), torch.nn.ReLU())
        self.output_layer = torch.nn.Linear(self.span_representation_dim,
                                            self.num_classes - 1,
                                            bias=False)

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self._bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        initializer(self)

    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        sentence_end: torch.LongTensor,
        spans: torch.LongTensor,
        span_labels: torch.LongTensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
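        # Tuple unpacking below assumes the underlying transformers BertModel
        # returns tuples (return_dict=False): (sequence_output, pooled_output).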
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            # token_type_ids=verb_indicator,
            attention_mask=mask,
        )

        # Shape: (batch_size, num_spans)
        span_mask = spans[:, :, 0] >= 0
        # SpanFields return -1 when they are used as padding. As we do some
        # comparisons based on span widths when we attend over the span
        # representations generated from these indices, we clamp the padding
        # values to 0 with a ReLU so that every index is valid. This only
        # matters in edge cases where the number of spans kept after pruning
        # is >= the total number of spans, since only then might a masked
        # span be selected.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()
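        # e.g. padded spans [[-1, -1], [1, 3]] become [[0, 0], [1, 3]].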

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        batch_size, sequence_length, _ = embedded_text_input.size()
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            bert_embeddings, spans)

        if self._context_layer is not None:
            contextualized_embeddings = self._context_layer(
                embedded_text_input, mask)
            # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
            endpoint_span_embeddings = self._endpoint_span_extractor(
                contextualized_embeddings, spans)

            # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
            # The attended embeddings could also be concatenated in:
            # span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)
            span_embeddings = endpoint_span_embeddings
        else:
            span_embeddings = attended_span_embeddings

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * sequence_length))
        num_spans = spans.shape[1]
        num_spans_to_keep = min(num_spans_to_keep, num_spans)
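        # e.g. with spans_per_word = 0.4 (an illustrative value) and a
        # 50-wordpiece input, at most floor(0.4 * 50) = 20 spans are kept.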

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)).squeeze(-1)
        # Shape: (batch_size, num_spans_to_keep) for all three tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep)
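        # util.masked_topk selects the num_spans_to_keep highest-scoring
        # spans per example, ignoring masked (padding) spans.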
        # Gather the embedding of the verb token in each sentence.
        verb_index = verb_indicator.argmax(1).unsqueeze(1).unsqueeze(2).repeat(
            1, 1, embedded_text_input.shape[-1])
        verb_embeddings = torch.gather(embedded_text_input, 1, verb_index)
        assert len(verb_embeddings.shape) == 3 and verb_embeddings.shape[1] == 1
        verb_embeddings = verb_embeddings.squeeze(1)
        # For sentences without a verbal predicate (all-zero verb_indicator),
        # argmax defaults to position 0, so zero out the gathered embedding.
        verb_embeddings = torch.where(
            (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                1, verb_embeddings.shape[-1]), verb_embeddings,
            torch.zeros_like(verb_embeddings))
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, spans.shape[1])
        span_embeddings = util.batched_index_select(span_embeddings,
                                                    top_span_indices,
                                                    flat_top_span_indices)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)
        top_span_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices).squeeze(-1)
        concatenated_span_embeddings = torch.cat(
            (span_embeddings, verb_embeddings.unsqueeze(1).repeat(
                1, span_embeddings.shape[1], 1)),
            dim=2)
        hidden = self.hidden_layer(concatenated_span_embeddings)
        predictions = self.output_layer(hidden)
        predictions = torch.cat(
            (torch.zeros_like(predictions[:, :, :1]), predictions), dim=-1)
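        # Prepend a frozen zero logit for the null label (assuming "O" sits at
        # index 0 of the span_labels namespace): real labels must outscore the
        # fixed 0 logit to be predicted.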

        output_dict = {}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
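            # Span-level cross-entropy, averaged over the real (unmasked)
            # kept spans only.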
            loss = (self._ce_loss(predictions.view(-1, predictions.shape[-1]),
                                  top_span_labels.view(-1)) *
                    top_span_mask.float().view(-1)
                    ).sum() / top_span_mask.float().sum()
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.get_tags(
                    top_spans, predictions, mask.shape[1], top_span_mask,
                    output_dict)
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                # print('G', batch_bio_gold_tags)
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
        return output_dict

    def get_tags(self, spans, logits, sequence_length, span_mask, output_dict):
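        # Greedy span-to-BIO decoding: a kept span is written out only if every
        # position it covers is still untagged ("O"). e.g. a span (2, 4)
        # labelled "ARG0" yields B-ARG0 at position 2 and I-ARG0 at 3 and 4.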
        predicted_tag_ids = logits.argmax(2)
        predicted_tags = []
        for i in range(spans.shape[0]):
            sequence = ["O" for _ in range(sequence_length)]
            for j in range(spans.shape[1]):
                if span_mask[i, j].item() == 0:
                    continue
                tag = predicted_tag_ids[i, j].item()
                if tag != self.vocab.get_token_index("O",
                                                     namespace="span_labels"):
                    start = spans[i, j, 0].item()
                    end = spans[i, j, 1].item()
                    if all(el == "O" for el in sequence[start:end + 1]):
                        tag = self.vocab.get_token_from_index(
                            tag, namespace="span_labels")
                        sequence[start] = "B-" + tag
                        for index in range(start + 1, end + 1):
                            sequence[index] = "I-" + tag
            predicted_tags.append(
                [sequence[ind] for ind in output_dict["wordpiece_offsets"][i]])
        return predicted_tags

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        `"tags"` key to the dictionary with the result.

        NOTE: First, we decode a BIO sequence on top of the wordpieces. This is important: Viterbi
        decoding produces low-quality output if you decode on top of word representations directly,
        because the model gets confused by the 'missing' positions (which is sensible, as it is
        trained to perform tagging on wordpieces, not words).

        Second, it is important that the indices we use to recover words from the wordpieces are the
        start_offsets (i.e. offsets which correspond to using the first wordpiece of words which are
        tokenized into multiple wordpieces), as otherwise we might get an ill-formed BIO sequence
        when we select out the word tags from the wordpiece tags. This happens when a word is split
        into multiple wordpieces and we take the last tag of the word, which might correspond to,
        e.g., I-V, which is not allowed as it is not preceded by a B tag.
        """
        all_predictions = output_dict["class_probabilities"]
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        wordpiece_tags = []
        word_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        # We add in the offsets here so we can compute the un-wordpieced tags.
        for predictions, length, offsets in zip(
                predictions_list, sequence_lengths,
                output_dict["wordpiece_offsets"]):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length],
                transition_matrix,
                allowed_start_transitions=start_transitions)
            tags = [
                self.vocab.get_token_from_index(x, namespace="labels")
                for x in max_likelihood_sequence
            ]

            wordpiece_tags.append(tags)
            word_tags.append([tags[i] for i in offsets])
        output_dict["wordpiece_tags"] = wordpiece_tags
        output_dict["tags"] = word_tags
        return output_dict

    def get_metrics(self, reset: bool = False):
        if self.ignore_span_metric:
            # Return an empty dictionary if ignoring the
            # span metric
            return {}

        else:
            metric_dict = self.span_metric.get_metric(reset=reset)

            # This can be a lot of metrics, as there are three per class.
            # We only really care about the overall metrics, so we filter for them here.
            return {x: y for x, y in metric_dict.items() if "overall" in x}

    def get_viterbi_pairwise_potentials(self):
        """
        Generate a matrix of pairwise transition potentials for the BIO labels.
        The only constraint implemented here is that I-XXX labels must be preceded
        by either an identical I-XXX tag or a B-XXX tag. In order to achieve this
        constraint, pairs of labels which do not satisfy this constraint have a
        pairwise potential of -inf.

        # Returns

        transition_matrix : `torch.Tensor`
            A `(num_labels, num_labels)` matrix of pairwise potentials.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])

        for i, previous_label in all_labels.items():
            for j, label in all_labels.items():
                # I labels can only be preceded by themselves or
                # their corresponding B tag.
                if (i != j and label[0] == "I"
                        and previous_label != "B" + label[1:]):
                    transition_matrix[i, j] = float("-inf")
        return transition_matrix

    def get_start_transitions(self):
        """
        A BIO sequence cannot start with an I-XXX tag, so the start-transition
        potentials returned here are passed to viterbi_decode to enforce that constraint.

        # Returns

        start_transitions : `torch.Tensor`
            The pairwise potentials between a START token and
            the first token of the sequence.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)

        start_transitions = torch.zeros(num_labels)

        for i, label in all_labels.items():
            if label[0] == "I":
                start_transitions[i] = float("-inf")

        return start_transitions

    default_predictor = "semantic_role_labeling"
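
# A minimal, self-contained sketch (not part of the model above) showing how
# get_viterbi_pairwise_potentials and get_start_transitions constrain BIO
# decoding. The toy label set here is hypothetical; the model builds its
# labels from the vocabulary instead.
def _demo_bio_constraints():
    labels = {0: "O", 1: "B-ARG0", 2: "I-ARG0"}
    transitions = torch.zeros(len(labels), len(labels))
    for i, prev in labels.items():
        for j, cur in labels.items():
            # Transitions into I-XXX are only legal from B-XXX or I-XXX.
            if i != j and cur[0] == "I" and prev != "B" + cur[1:]:
                transitions[i, j] = float("-inf")
    start = torch.zeros(len(labels))
    for i, label in labels.items():
        # A sequence may not begin with an I-XXX tag.
        if label[0] == "I":
            start[i] = float("-inf")
    assert transitions[0, 2] == float("-inf")  # "O" -> "I-ARG0" is illegal.
    assert transitions[1, 2] == 0.0  # "B-ARG0" -> "I-ARG0" is legal.
    assert start[2] == float("-inf")  # Cannot start with "I-ARG0".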