    def test_span_metrics_are_computed_correctly(self):
        from allennlp_models.structured_prediction.models.srl import (
            convert_bio_tags_to_conll_format, )

        batch_verb_indices = [2]
        batch_sentences = [["The", "cat", "loves", "hats", "."]]
        batch_bio_predicted_tags = [["B-ARG0", "B-ARG1", "B-V", "B-ARG1", "O"]]
        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "O"]]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_gold_tags
        ]

        srl_scorer = SrlEvalScorer(ignore_classes=["V"])
        srl_scorer(batch_verb_indices, batch_sentences,
                   batch_conll_predicted_tags, batch_conll_gold_tags)
        metrics = srl_scorer.get_metric()
        assert len(metrics) == 9
        assert_allclose(metrics["precision-ARG0"], 0.0)
        assert_allclose(metrics["recall-ARG0"], 0.0)
        assert_allclose(metrics["f1-measure-ARG0"], 0.0)
        assert_allclose(metrics["precision-ARG1"], 0.5)
        assert_allclose(metrics["recall-ARG1"], 1.0)
        assert_allclose(metrics["f1-measure-ARG1"], 2 / 3)
        assert_allclose(metrics["precision-overall"], 1 / 3)
        assert_allclose(metrics["recall-overall"], 1 / 2)
        assert_allclose(metrics["f1-measure-overall"],
                        (2 * (1 / 3) * (1 / 2)) / ((1 / 3) + (1 / 2)))
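
For reference, a minimal sketch (not part of the original test) of the CoNLL-style strings the conversion is expected to produce for the gold tags above; the bracket notation follows srl-eval.pl and the printed output shown is illustrative:

from allennlp_models.structured_prediction.models.srl import (
    convert_bio_tags_to_conll_format,
)

# Each BIO span opens a bracket on its first token and closes it on its last;
# "O" tokens become "*".
gold_conll = convert_bio_tags_to_conll_format(["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "O"])
print(gold_conll)  # expected to look roughly like ["(ARG0*", "*)", "(V*)", "(ARG1*)", "*"]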
Example #2
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric
        initializer(self)
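
A hedged construction sketch for this `__init__`; the vocabulary path and BERT name are placeholders, and the class name `SrlBert` and its import path are assumptions (the class is only shown explicitly in a later example):

from allennlp.data import Vocabulary
from allennlp_models.structured_prediction.models.srl_bert import SrlBert  # assumed import path

# Hypothetical usage: the vocabulary must already contain a "labels" namespace
# with the BIO role tags.
vocab = Vocabulary.from_files("/path/to/vocabulary")
model = SrlBert(vocab=vocab, bert_model="bert-base-uncased", embedding_dropout=0.1)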
Example #3
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        binary_feature_dim: int,
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")

        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None

        self.encoder = encoder
        # There are exactly 2 binary features for the verb predicate embedding.
        self.binary_feature_embedding = Embedding(
            num_embeddings=2, embedding_dim=binary_feature_dim)
        self.tag_projection_layer = TimeDistributed(
            Linear(self.encoder.get_output_dim(), self.num_classes))
        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric

        check_dimensions_match(
            text_field_embedder.get_output_dim() + binary_feature_dim,
            encoder.get_input_dim(),
            "text embedding dim + verb indicator embedding dim",
            "encoder input dim",
        )
        initializer(self)
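
The `check_dimensions_match` call encodes a contract: the encoder input size must equal the embedder output plus `binary_feature_dim`. A minimal sketch of a consistent pair of sizes (the encoder class and the numbers are illustrative, not taken from the snippet):

from allennlp.modules.seq2seq_encoders import LstmSeq2SeqEncoder

embedding_dim = 100       # stands in for text_field_embedder.get_output_dim()
binary_feature_dim = 100  # size of the verb indicator embedding
# The encoder must be built with input_size = embedding_dim + binary_feature_dim,
# otherwise check_dimensions_match raises a ConfigurationError at construction time.
encoder = LstmSeq2SeqEncoder(input_size=embedding_dim + binary_feature_dim,
                             hidden_size=300)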
Example #4
 def __init__(
     self,
     vocab: Vocabulary,
     bert_model: Union[str, AutoModel],
     embedding_dropout: float = 0.0,
     initializer: InitializerApplicator = InitializerApplicator(),
     label_smoothing: float = None,
     ignore_span_metric: bool = False,
     srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
     restrict_frames: bool = False,
     restrict_roles: bool = False,
     inventory: str = "verbatlas",
     **kwargs,
 ) -> None:
     # bypass SrlBert constructor
     Model.__init__(self, vocab, **kwargs)
     self.lemma_frame_dict = load_lemma_frame(LEMMA_FRAME_PATH)
     self.frame_role_dict = load_role_frame(FRAME_ROLE_PATH)
     self.restrict_frames = restrict_frames
     self.restrict_roles = restrict_roles
     self.transformer = AutoModel.from_pretrained(bert_model)
     self.frame_criterion = nn.CrossEntropyLoss()
     if inventory == "verbatlas":
         # add missing labels
         frame_list = load_label_list(FRAME_LIST_PATH)
         self.vocab.add_tokens_to_namespace(frame_list, "frames_labels")
     self.num_classes = self.vocab.get_vocab_size("labels")
     self.frame_num_classes = self.vocab.get_vocab_size("frames_labels")
     if srl_eval_path is not None:
         # For the span based evaluation, we don't want to consider labels
         # for verb, because the verb index is provided to the model.
         self.span_metric = SrlEvalScorer(srl_eval_path, ignore_classes=["V"])
     else:
         self.span_metric = None
     self.f1_frame_metric = FBetaMeasure(average="micro")
     self.tag_projection_layer = nn.Linear(self.transformer.config.hidden_size, self.num_classes)
     self.frame_projection_layer = nn.Linear(
         self.transformer.config.hidden_size, self.frame_num_classes
     )
     self.embedding_dropout = nn.Dropout(p=embedding_dropout)
     self._label_smoothing = label_smoothing
     self.ignore_span_metric = ignore_span_metric
     initializer(self)
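
A small sketch of the vocabulary step above, which extends the `frames_labels` namespace with labels not observed in training; the label strings are illustrative (in the model they come from `load_label_list(FRAME_LIST_PATH)`):

from allennlp.data import Vocabulary

vocab = Vocabulary()
# Hypothetical frame labels standing in for the VerbAtlas inventory.
vocab.add_tokens_to_namespace(["AGREE_ACCEPT", "EAT_BITE", "MOVE-ONESELF"], "frames_labels")
print(vocab.get_vocab_size("frames_labels"))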
Example #5
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, Dict[str, Any], BertModel],
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        elif isinstance(bert_model, dict):
            warnings.warn(
                "Initializing BertModel without pretrained weights. This is fine if you're loading "
                "from an AllenNLP archive, but not if you're training.",
                UserWarning,
            )
            bert_config = BertConfig.from_dict(bert_model)
            self.bert_model = BertModel(bert_config)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric
        initializer(self)
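
For the `Dict[str, Any]` branch above, a hedged sketch of constructing the model from a serialized BERT config instead of pretrained weights; the config keys and paths are illustrative, and `SrlBert` with its import path is again an assumption:

from allennlp.data import Vocabulary
from allennlp_models.structured_prediction.models.srl_bert import SrlBert  # assumed import path

vocab = Vocabulary.from_files("/path/to/vocabulary")  # placeholder path
bert_config_dict = {
    "hidden_size": 768,
    "num_hidden_layers": 12,
    "num_attention_heads": 12,
    "vocab_size": 30522,
}
# Takes the BertConfig.from_dict path and emits the UserWarning about
# initializing BertModel without pretrained weights.
model = SrlBert(vocab=vocab, bert_model=bert_config_dict)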
    def test_distributed_setting_throws_an_error(self):
        from allennlp_models.structured_prediction.models.srl import (
            convert_bio_tags_to_conll_format, )

        batch_verb_indices = [2]
        batch_sentences = [["The", "cat", "loves", "hats", "."]]
        batch_bio_predicted_tags = [["B-ARG0", "B-ARG1", "B-V", "B-ARG1", "O"]]
        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "O"]]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_gold_tags
        ]

        metric_kwargs = {
            "batch_verb_indices": [batch_verb_indices, batch_verb_indices],
            "batch_sentences": [batch_sentences, batch_sentences],
            "batch_conll_formatted_predicted_tags": [
                batch_conll_predicted_tags,
                batch_conll_predicted_tags,
            ],
            "batch_conll_formatted_gold_tags":
            [batch_conll_gold_tags, batch_conll_gold_tags],
        }

        desired_values = {}  # it does not matter, we expect the run to fail.

        with pytest.raises(Exception) as exc:
            run_distributed_test(
                [-1, -1],
                global_distributed_metric,
                SrlEvalScorer(ignore_classes=["V"]),
                metric_kwargs,
                desired_values,
                exact=True,
            )
        # Check the error outside the `with` block; code placed after the
        # raising call inside the block would never run.
        assert (
            "RuntimeError: Distributed aggregation for `SrlEvalScorer` is currently not supported."
            in str(exc.value))
Example #7
class TransformerSrlSpan(SrlBert):
    """

    # Parameters

    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    bert_model : `Union[str, AutoModel]`, required.
        A string describing the BERT model to load or an already constructed AutoModel.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    label_smoothing : `float`, optional (default = `None`)
        The amount of label smoothing to apply to the labels when computing cross entropy loss.
    ignore_span_metric : `bool`, optional (default = `False`)
        Whether to skip computing the span metric, which is irrelevant when predicting BIO for Open Information Extraction.
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
        which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not used.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, AutoModel],
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        inventory: str = "verbatlas",
        **kwargs,
    ) -> None:
        # bypass SrlBert constructor
        Model.__init__(self, vocab, **kwargs)
        self.transformer = AutoModel.from_pretrained(bert_model)
        self.frame_criterion = nn.CrossEntropyLoss()
        if inventory == "verbatlas":
            # add missing frame labels
            frame_list = load_label_list(FRAME_LIST_PATH)
            self.vocab.add_tokens_to_namespace(frame_list, "frames_labels")
            # add missing role labels
            role_list = load_label_list(ROLE_LIST_PATH)
            self.vocab.add_tokens_to_namespace(role_list, "labels")
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.frame_num_classes = self.vocab.get_vocab_size("frames_labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.f1_frame_metric = FBetaMeasure(average="micro")
        self.tag_projection_layer = nn.Linear(
            self.transformer.config.hidden_size, self.num_classes)
        self.frame_projection_layer = nn.Linear(
            self.transformer.config.hidden_size, self.frame_num_classes)
        self.embedding_dropout = nn.Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric
        initializer(self)

    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        frame_indicator: torch.Tensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
        frame_tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        frame_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the frame
            in the sentence. This should have shape (batch_size, num_tokens). Similar to `verb_indicator`,
            but accounts for the BERT wordpiece tokenizer by marking only the first subtoken of the frame.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        frame_tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the gold frames
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
        input_ids = util.get_token_ids_from_text_field_tensors(tokens)
        bert_embeddings, _ = self.transformer(
            input_ids=input_ids,
            token_type_ids=verb_indicator,
            attention_mask=mask,
            return_dict=False,
        )
        # extract embeddings
        embedded_text_input = self.embedding_dropout(bert_embeddings)
        frame_embeddings = embedded_text_input[frame_indicator == 1]
        # get sizes
        batch_size, sequence_length, _ = embedded_text_input.size()
        # outputs
        logits = self.tag_projection_layer(embedded_text_input)
        frame_logits = self.frame_projection_layer(frame_embeddings)

        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])

        frame_probabilities = F.softmax(frame_logits, dim=-1)
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict = {
            "logits": logits,
            "frame_logits": frame_logits,
            "class_probabilities": class_probabilities,
            "frame_probabilities": frame_probabilities,
            "mask": mask,
        }
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        lemmas = [l for x in metadata for l in x["lemmas"]]
        output_dict["words"] = list(words)
        output_dict["lemma"] = list(lemmas)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            # compute role loss
            role_loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
            # compute frame loss
            frame_tags_filtered = frame_tags[frame_indicator == 1]
            frame_loss = self.frame_criterion(frame_logits,
                                              frame_tags_filtered)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                batch_bio_predicted_tags = self.make_output_human_readable(
                    output_dict).pop("tags")
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            self.f1_frame_metric(frame_logits, frame_tags_filtered)
            output_dict["frame_loss"] = frame_loss
            output_dict["role_loss"] = role_loss
            output_dict["loss"] = (role_loss + frame_loss) / 2
        return output_dict

    def decode_frames(
            self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # frame prediction
        frame_probabilities = output_dict["frame_probabilities"]
        frame_predictions = frame_probabilities.argmax(
            dim=-1).cpu().data.numpy()
        output_dict["frame_tags"] = [
            self.vocab.get_token_from_index(f, namespace="frames_labels")
            for f in frame_predictions
        ]
        output_dict["frame_scores"] = [
            fp[f] for f, fp in zip(frame_predictions, frame_probabilities)
        ]
        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        output_dict = self.decode_frames(output_dict)
        output_dict = super().make_output_human_readable(output_dict)
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False):
        if self.ignore_span_metric:
            # Return an empty dictionary if ignoring the
            # span metric
            return {}

        else:
            metric_dict = self.span_metric.get_metric(reset=reset)
            frame_metric_dict = self.f1_frame_metric.get_metric(reset=reset)
            # This can be a lot of metrics, as there are 3 per class.
            # we only really care about the overall metrics, so we filter for them here.
            metric_dict_filtered = {
                x.split("-")[0] + "_role": y
                for x, y in metric_dict.items() if "overall" in x
            }
            frame_metric_dict = {
                x + "_frame": y
                for x, y in frame_metric_dict.items()
            }
            return {**metric_dict_filtered, **frame_metric_dict}

    def _get_label_tokens(self, namespace: str = "labels"):
        return self.vocab.get_token_to_index_vocabulary(namespace).keys()

    def _get_label_ids(self, namespace: str = "labels"):
        return self.vocab.get_index_to_token_vocabulary(namespace).keys()

    default_predictor = "transformer_srl"
Example #8
class SemanticRoleLabeler(Model):
    """
    This model performs semantic role labeling with BIO tags, using PropBank semantic roles.
    Specifically, it is an implementation of [Deep Semantic Role Labeling - What works
    and what's next](https://www.aclweb.org/anthology/P17-1044).

    This implementation is effectively a series of stacked interleaved LSTMs with highway
    connections, applied to embedded sequences of words concatenated with a binary indicator
    containing whether or not a word is the verbal predicate to generate predictions for in
    the sentence. Additionally, during inference, Viterbi decoding is applied to constrain
    the predictions to contain valid BIO sequences.

    Specifically, the model expects and outputs IOB2-formatted tags, where the
    B- tag is used at the beginning of every chunk (i.e. all chunks start with the B- tag).

    # Parameters

    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : `TextFieldEmbedder`, required
        Used to embed the `tokens` `TextField` we get as input to the model.
    encoder : `Seq2SeqEncoder`
        The encoder (with its own internal stacking) that we will use in between embedding tokens
        and predicting output tags.
    binary_feature_dim : `int`, required.
        The dimensionality of the embedding of the binary verb predicate features.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    label_smoothing : `float`, optional (default = `None`)
        The amount of label smoothing to apply to the labels when computing cross entropy loss.
    ignore_span_metric : `bool`, optional (default = `False`)
        Whether to skip computing the span metric, which is irrelevant when predicting BIO for Open Information Extraction.
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
        which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not used.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        encoder: Seq2SeqEncoder,
        binary_feature_dim: int,
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")

        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None

        self.encoder = encoder
        # There are exactly 2 binary features for the verb predicate embedding.
        self.binary_feature_embedding = Embedding(
            num_embeddings=2, embedding_dim=binary_feature_dim)
        self.tag_projection_layer = TimeDistributed(
            Linear(self.encoder.get_output_dim(), self.num_classes))
        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric

        check_dimensions_match(
            text_field_embedder.get_output_dim() + binary_feature_dim,
            encoder.get_input_dim(),
            "text embedding dim + verb indicator embedding dim",
            "encoder input dim",
        )
        initializer(self)

    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.LongTensor,
        tags: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : TextFieldTensors, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. This output is a dictionary mapping keys to `TokenIndexer`
            tensors.  At its most basic, using a `SingleIdTokenIndexer` this is : `{"tokens":
            Tensor(batch_size, num_tokens)}`. This dictionary will have the same keys as were used
            for the `TokenIndexers` when you created the `TextField` representing your
            sequence.  The dictionary is designed to be passed directly to a `TextFieldEmbedder`,
            which knows how to combine different word representations into a single vector per
            token in your input.
        verb_indicator: torch.LongTensor, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = None)
            metadata containing the original words in the sentence and the verb to compute the
            frame for, under 'words' and 'verb' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.

        """
        embedded_text_input = self.embedding_dropout(
            self.text_field_embedder(tokens))
        mask = get_text_field_mask(tokens)
        embedded_verb_indicator = self.binary_feature_embedding(
            verb_indicator.long())
        # Concatenate the verb feature onto the embedded text. This now
        # has shape (batch_size, sequence_length, embedding_dim + binary_feature_dim).
        embedded_text_with_verb_indicator = torch.cat(
            [embedded_text_input, embedded_verb_indicator], -1)
        batch_size, sequence_length, _ = embedded_text_with_verb_indicator.size()

        encoded_text = self.encoder(embedded_text_with_verb_indicator, mask)

        logits = self.tag_projection_layer(encoded_text)
        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.make_output_human_readable(
                    output_dict).pop("tags")
                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss

        if metadata is not None:
            words, verbs = zip(*[(x["words"], x["verb"]) for x in metadata])
            output_dict["words"] = list(words)
            output_dict["verb"] = list(verbs)
        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        `"tags"` key to the dictionary with the result.
        """
        all_predictions = output_dict["class_probabilities"]
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        all_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        for predictions, length in zip(predictions_list, sequence_lengths):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length],
                transition_matrix,
                allowed_start_transitions=start_transitions)
            tags = [
                self.vocab.get_token_from_index(x, namespace="labels")
                for x in max_likelihood_sequence
            ]
            all_tags.append(tags)
        output_dict["tags"] = all_tags
        return output_dict

    def get_metrics(self, reset: bool = False):
        if self.ignore_span_metric:
            # Return an empty dictionary if ignoring the
            # span metric
            return {}

        else:
            metric_dict = self.span_metric.get_metric(reset=reset)

            # This can be a lot of metrics, as there are 3 per class.
            # we only really care about the overall metrics, so we filter for them here.
            return {x: y for x, y in metric_dict.items() if "overall" in x}

    def get_viterbi_pairwise_potentials(self):
        """
        Generate a matrix of pairwise transition potentials for the BIO labels.
        The only constraint implemented here is that I-XXX labels must be preceded
        by either an identical I-XXX tag or a B-XXX tag. In order to achieve this
        constraint, pairs of labels which do not satisfy this constraint have a
        pairwise potential of -inf.

        # Returns

        transition_matrix : torch.Tensor
            A (num_labels, num_labels) matrix of pairwise potentials.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])

        for i, previous_label in all_labels.items():
            for j, label in all_labels.items():
                # I labels can only be preceded by themselves or
                # their corresponding B tag.
                if i != j and label[0] == "I" and previous_label != "B" + label[1:]:
                    transition_matrix[i, j] = float("-inf")
        return transition_matrix

    def get_start_transitions(self):
        """
        In the BIO sequence, we cannot start the sequence with an I-XXX tag.
        This transition sequence is passed to viterbi_decode to specify this constraint.

        # Returns

        start_transitions : torch.Tensor
            The pairwise potentials between a START token and
            the first token of the sequence.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)

        start_transitions = torch.zeros(num_labels)

        for i, label in all_labels.items():
            if label[0] == "I":
                start_transitions[i] = float("-inf")

        return start_transitions

    default_predictor = "semantic-role-labeling"
    def test_srl_eval_correctly_scores_identical_tags(self):
        batch_verb_indices = [3, 8, 2, 0]
        batch_sentences = [
            [
                "Mali",
                "government",
                "officials",
                "say",
                "the",
                "woman",
                "'s",
                "confession",
                "was",
                "forced",
                ".",
            ],
            [
                "Mali",
                "government",
                "officials",
                "say",
                "the",
                "woman",
                "'s",
                "confession",
                "was",
                "forced",
                ".",
            ],
            [
                "The",
                "prosecution",
                "rested",
                "its",
                "case",
                "last",
                "month",
                "after",
                "four",
                "months",
                "of",
                "hearings",
                ".",
            ],
            ["Come", "in", "and", "buy", "."],
        ]
        batch_bio_predicted_tags = [
            [
                "B-ARG0",
                "I-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "O",
            ],
            [
                "O", "O", "O", "O", "B-ARG1", "I-ARG1", "I-ARG1", "I-ARG1",
                "B-V", "B-ARG2", "O"
            ],
            [
                "B-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "O",
            ],
            ["B-V", "B-AM-DIR", "O", "O", "O"],
        ]
        from allennlp_models.structured_prediction.models.srl import (
            convert_bio_tags_to_conll_format, )

        batch_conll_predicted_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_predicted_tags
        ]
        batch_bio_gold_tags = [
            [
                "B-ARG0",
                "I-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "I-ARG1",
                "O",
            ],
            [
                "O", "O", "O", "O", "B-ARG1", "I-ARG1", "I-ARG1", "I-ARG1",
                "B-V", "B-ARG2", "O"
            ],
            [
                "B-ARG0",
                "I-ARG0",
                "B-V",
                "B-ARG1",
                "I-ARG1",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "B-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "I-ARGM-TMP",
                "O",
            ],
            ["B-V", "B-AM-DIR", "O", "O", "O"],
        ]
        batch_conll_gold_tags = [
            convert_bio_tags_to_conll_format(tags)
            for tags in batch_bio_gold_tags
        ]

        srl_scorer = SrlEvalScorer(ignore_classes=["V"])
        srl_scorer(batch_verb_indices, batch_sentences,
                   batch_conll_predicted_tags, batch_conll_gold_tags)
        metrics = srl_scorer.get_metric()
        assert len(metrics) == 18
        assert_allclose(metrics["precision-ARG0"], 1.0)
        assert_allclose(metrics["recall-ARG0"], 1.0)
        assert_allclose(metrics["f1-measure-ARG0"], 1.0)
        assert_allclose(metrics["precision-ARG1"], 1.0)
        assert_allclose(metrics["recall-ARG1"], 1.0)
        assert_allclose(metrics["f1-measure-ARG1"], 1.0)
        assert_allclose(metrics["precision-ARG2"], 1.0)
        assert_allclose(metrics["recall-ARG2"], 1.0)
        assert_allclose(metrics["f1-measure-ARG2"], 1.0)
        assert_allclose(metrics["precision-ARGM-TMP"], 1.0)
        assert_allclose(metrics["recall-ARGM-TMP"], 1.0)
        assert_allclose(metrics["f1-measure-ARGM-TMP"], 1.0)
        assert_allclose(metrics["precision-AM-DIR"], 1.0)
        assert_allclose(metrics["recall-AM-DIR"], 1.0)
        assert_allclose(metrics["f1-measure-AM-DIR"], 1.0)
        assert_allclose(metrics["precision-overall"], 1.0)
        assert_allclose(metrics["recall-overall"], 1.0)
        assert_allclose(metrics["f1-measure-overall"], 1.0)
Example #10
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, AutoModel],
        mismatched_embedder: TokenEmbedder = None,
        lp: bool = False,
        lpsmap: bool = False,
        lpsmap_core_roles_only: bool = True,
        validation_inference: bool = True,
        batch_size: int = None,
        encoder: Seq2SeqEncoder = None,
        reinitialize_pos_embedding: bool = False,
        embedding_dropout: float = 0.0,
        mlp_hidden_size: int = 300,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        label_encoding: str = "BIO",
        constrain_crf_decoding: bool = None,
        include_start_end_transitions: bool = True,
        label_namespace: str = "labels",
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            if mismatched_embedder is None:
                self.bert_model = AutoModel.from_pretrained(bert_model)
            self.bert_config = AutoConfig.from_pretrained(bert_model)
        else:
            if mismatched_embedder is None:
                self.bert_model = bert_model
            self.bert_config = bert_model.config
        if reinitialize_pos_embedding:
            self.bert_model._init_weights(
                self.bert_model.embeddings.position_embeddings)
            # self.bert_model._init_weights(self.bert_model.embeddings.token_type_embeddings)
        if mismatched_embedder is not None:
            self.bert_model = mismatched_embedder

        self._label_namespace = label_namespace
        self.num_classes = self.vocab.get_vocab_size(label_namespace)
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None

        if constrain_crf_decoding is None:
            constrain_crf_decoding = label_encoding is not None

        self.label_encoding = label_encoding
        self.constrain_crf_decoding = constrain_crf_decoding
        if constrain_crf_decoding:
            if not label_encoding:
                raise ConfigurationError(
                    "constrain_crf_decoding is True, but no label_encoding was specified."
                )
            labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
            constraints = allowed_transitions(label_encoding, labels)
        else:
            constraints = None

        self.include_start_end_transitions = include_start_end_transitions
        self.crf = ConditionalRandomField(
            self.num_classes,
            constraints,
            include_start_end_transitions=include_start_end_transitions)
        self._encoder = encoder
        representation_size = self.bert_config.hidden_size
        if self.bert_config.type_vocab_size == 1:
            representation_size = self.bert_config.hidden_size * 2
        if encoder is None:
            self.tag_projection_layer = torch.nn.Sequential(
                Linear(representation_size, mlp_hidden_size), torch.nn.ReLU(),
                Linear(mlp_hidden_size, self.num_classes))
        else:
            self.tag_projection_layer = torch.nn.Sequential(
                Linear(encoder.get_output_dim() * 2, mlp_hidden_size),
                torch.nn.ReLU(), Linear(mlp_hidden_size, self.num_classes))

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self.predicate_embedding = torch.nn.Embedding(num_embeddings=2,
                                                      embedding_dim=10)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric
        self._lp = lp
        self._lpsmap = lpsmap
        self._lpsmap_core_only = lpsmap_core_roles_only
        self._val_inference = validation_inference
        if self._lpsmap:
            self._core_roles = []
            for i in range(6):
                try:
                    self._core_roles.append(
                        self.vocab.get_token_index(
                            "B-ARG" + str(i), namespace=self._label_namespace))
                except KeyError:
                    logger.info("B-ARG" + str(i) + " is not in labels")
            self._r_roles = []
            self._c_roles = []
            for i in range(self.num_classes):
                token = self.vocab.get_token_from_index(
                    i, namespace=self._label_namespace)
                if token[:4] == "B-R-" and token[4:] != "ARG1":
                    try:
                        base_arg_index = self.vocab.get_token_index(
                            "B-" + token[4:], namespace=self._label_namespace)
                        self._r_roles.append((i, base_arg_index))
                    except KeyError:
                        logger.info("B-" + token[4:] + " is not in labels")
                elif token[:4] == "B-C-" and token[4:] != "ARG1":
                    try:
                        base_arg_index = self.vocab.get_token_index(
                            "B-" + token[4:], namespace=self._label_namespace)
                        self._c_roles.append((i, base_arg_index))
                    except KeyError:
                        logger.info("B-" + token[4:] + " is not in labels")
            # self._core_roles = [index for index in range(self.vocab.get_vocab_size("labels")) if index in [self.vocab.get_token_index("B-ARG"+str(i), namespace="labels") for i in range(3)]]
            self.lpsmap = None
        if lp:
            """self._layer_list = []
            self.length_map = {}
            self.lengths = []
            for max_sequence_length in [70, 100, 200, 300]:
                x = cp.Variable((max_sequence_length, self.vocab.get_vocab_size(namespace="labels")))
                S = cp.Parameter((max_sequence_length, self.vocab.get_vocab_size(namespace="labels")))
                constraints = [x >= 0, cp.sum(x, axis=1) == 1]
                objective = cp.Maximize(cp.sum(cp.multiply(x, S)))
                problem = cp.Problem(objective, constraints)
                assert problem.is_dpp()
                lp_layer = CvxpyLayer(problem, parameters=[S], variables=[x])
                self._layer_list.append(lp_layer)
                self.length_map[max_sequence_length] = len(self._layer_list)-1
                self.lengths.append(max_sequence_length)
            self._layer_list = torch.nn.ModuleList(self._layer_list)"""
            pass
        initializer(self)
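
A small sketch of what `allowed_transitions` contributes when `label_encoding="BIO"`; the label map is illustrative rather than taken from the model's vocabulary:

from allennlp.modules.conditional_random_field import allowed_transitions

labels = {0: "O", 1: "B-ARG0", 2: "I-ARG0", 3: "B-ARG1", 4: "I-ARG1"}
constraints = allowed_transitions("BIO", labels)
# `constraints` lists the permitted (from_index, to_index) pairs (including the
# special start/end indices); pairs such as O -> I-ARG0 are omitted, which is
# how the ConditionalRandomField above forbids invalid BIO sequences.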
Example #11
class SrlBert(Model):
    """

    A BERT based model [Simple BERT Models for Relation Extraction and Semantic Role Labeling (Shi et al, 2019)]
    (https://arxiv.org/abs/1904.05255) with some modifications (no additional parameters apart from a linear
    classification layer), which is currently the state-of-the-art single model for English PropBank SRL
    (Newswire sentences).

    # Parameters

    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    bert_model : `Union[str, BertModel]`, required.
        A string describing the BERT model to load or an already constructed BertModel.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    label_smoothing : `float`, optional (default = `None`)
        The amount of label smoothing to apply to the labels when computing cross entropy loss.
    ignore_span_metric : `bool`, optional (default = `False`)
        Whether to skip computing the span metric, which is irrelevant when predicting BIO for Open Information Extraction.
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
        which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not used.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, AutoModel],
        mismatched_embedder: TokenEmbedder = None,
        lp: bool = False,
        lpsmap: bool = False,
        lpsmap_core_roles_only: bool = True,
        validation_inference: bool = True,
        batch_size: int = None,
        encoder: Seq2SeqEncoder = None,
        reinitialize_pos_embedding: bool = False,
        embedding_dropout: float = 0.0,
        mlp_hidden_size: int = 300,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        label_encoding: str = "BIO",
        constrain_crf_decoding: bool = None,
        include_start_end_transitions: bool = True,
        label_namespace: str = "labels",
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            if mismatched_embedder is None:
                self.bert_model = AutoModel.from_pretrained(bert_model)
            self.bert_config = AutoConfig.from_pretrained(bert_model)
        else:
            if mismatched_embedder is None:
                self.bert_model = bert_model
            self.bert_config = bert_model.config
        if reinitialize_pos_embedding:
            self.bert_model._init_weights(
                self.bert_model.embeddings.position_embeddings)
            # self.bert_model._init_weights(self.bert_model.embeddings.token_type_embeddings)
        if mismatched_embedder is not None:
            self.bert_model = mismatched_embedder

        self._label_namespace = label_namespace
        self.num_classes = self.vocab.get_vocab_size(label_namespace)
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None

        if constrain_crf_decoding is None:
            constrain_crf_decoding = label_encoding is not None

        self.label_encoding = label_encoding
        self.constrain_crf_decoding = constrain_crf_decoding
        if constrain_crf_decoding:
            if not label_encoding:
                raise ConfigurationError(
                    "constrain_crf_decoding is True, but no label_encoding was specified."
                )
            labels = self.vocab.get_index_to_token_vocabulary(label_namespace)
            constraints = allowed_transitions(label_encoding, labels)
        else:
            constraints = None

        self.include_start_end_transitions = include_start_end_transitions
        self.crf = ConditionalRandomField(
            self.num_classes,
            constraints,
            include_start_end_transitions=include_start_end_transitions)
        self._encoder = encoder
        representation_size = self.bert_config.hidden_size
        if self.bert_config.type_vocab_size == 1:
            representation_size = self.bert_config.hidden_size * 2
        if encoder is None:
            self.tag_projection_layer = torch.nn.Sequential(
                Linear(representation_size, mlp_hidden_size), torch.nn.ReLU(),
                Linear(mlp_hidden_size, self.num_classes))
        else:
            self.tag_projection_layer = torch.nn.Sequential(
                Linear(encoder.get_output_dim() * 2, mlp_hidden_size),
                torch.nn.ReLU(), Linear(mlp_hidden_size, self.num_classes))

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self.predicate_embedding = torch.nn.Embedding(num_embeddings=2,
                                                      embedding_dim=10)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric
        self._lp = lp
        self._lpsmap = lpsmap
        self._lpsmap_core_only = lpsmap_core_roles_only
        self._val_inference = validation_inference
        if self._lpsmap:
            self._core_roles = []
            for i in range(6):
                try:
                    self._core_roles.append(
                        self.vocab.get_token_index(
                            "B-ARG" + str(i), namespace=self._label_namespace))
                except KeyError:
                    logger.info("B-ARG" + str(i) + " is not in labels")
            self._r_roles = []
            self._c_roles = []
            for i in range(self.num_classes):
                token = self.vocab.get_token_from_index(
                    i, namespace=self._label_namespace)
                if token[:4] == "B-R-" and token[4:] != "ARG1":
                    try:
                        base_arg_index = self.vocab.get_token_index(
                            "B-" + token[4:], namespace=self._label_namespace)
                        self._r_roles.append((i, base_arg_index))
                    except KeyError:
                        logger.info("B-" + token[4:] + " is not in labels")
                elif token[:4] == "B-C-" and token[4:] != "ARG1":
                    try:
                        base_arg_index = self.vocab.get_token_index(
                            "B-" + token[4:], namespace=self._label_namespace)
                        self._c_roles.append((i, base_arg_index))
                    except KeyError:
                        logger.info("B-" + token[4:] + " is not in labels")
            # self._core_roles = [index for index in range(self.vocab.get_vocab_size("labels")) if index in [self.vocab.get_token_index("B-ARG"+str(i), namespace="labels") for i in range(3)]]
            self.lpsmap = None
        if lp:
            """self._layer_list = []
            self.length_map = {}
            self.lengths = []
            for max_sequence_length in [70, 100, 200, 300]:
                x = cp.Variable((max_sequence_length, self.vocab.get_vocab_size(namespace="labels")))
                S = cp.Parameter((max_sequence_length, self.vocab.get_vocab_size(namespace="labels")))
                constraints = [x >= 0, cp.sum(x, axis=1) == 1]
                objective = cp.Maximize(cp.sum(cp.multiply(x, S)))
                problem = cp.Problem(objective, constraints)
                assert problem.is_dpp()
                lp_layer = CvxpyLayer(problem, parameters=[S], variables=[x])
                self._layer_list.append(lp_layer)
                self.length_map[max_sequence_length] = len(self._layer_list)-1
                self.lengths.append(max_sequence_length)
            self._layer_list = torch.nn.ModuleList(self._layer_list)"""
            pass
        initializer(self)

    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors,
            verb_indicator: torch.Tensor,
            sentence_end: torch.LongTensor,
            metadata: List[Any],
            tags: torch.LongTensor = None,
            offsets: torch.LongTensor = None):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """

        if isinstance(self.bert_model,
                      PretrainedTransformerMismatchedEmbedder):
            encoder_inputs = tokens["tokens"]
            if self.bert_config.type_vocab_size > 1:
                encoder_inputs["type_ids"] = verb_indicator
            encoded_text = self.bert_model(**encoder_inputs)
            batch_size = encoded_text.shape[0]
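            # If the underlying transformer has no segment (token type) embeddings, the
            # predicate position cannot be signalled through `type_ids` above, so the
            # predicate's contextual embedding is concatenated to every token instead.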
            if self.bert_config.type_vocab_size == 1:
                verb_embeddings = encoded_text[
                    torch.arange(batch_size).to(encoded_text.device),
                    verb_indicator.argmax(1), :]
                verb_embeddings = torch.where(
                    (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                        1, verb_embeddings.shape[-1]), verb_embeddings,
                    torch.zeros_like(verb_embeddings))
                encoded_text = torch.cat(
                    (encoded_text, verb_embeddings.unsqueeze(1).repeat(
                        1, encoded_text.shape[1], 1)),
                    dim=2)
            mask = tokens["tokens"]["mask"]
            index = mask.sum(1).argmax().item()
            # print(mask.shape, encoded_text.shape, tokens["tokens"]["token_ids"].shape, tags.shape, max([len(x['words']) for x in metadata]), mask.sum(1)[index].item())
            # print(tokens["tokens"]["token_ids"][index,:])
        else:
            mask = get_text_field_mask(tokens)
            bert_embeddings, _ = self.bert_model(
                input_ids=util.get_token_ids_from_text_field_tensors(tokens),
                # token_type_ids=verb_indicator,
                attention_mask=mask,
            )

            batch_size, _ = mask.size()
            embedded_text_input = self.embedding_dropout(bert_embeddings)
            # Restrict to sentence part
            sentence_mask = (torch.arange(mask.shape[1]).unsqueeze(0).repeat(
                batch_size, 1).to(mask.device) <
                             sentence_end.unsqueeze(1).repeat(
                                 1, mask.shape[1])).long()
            cutoff = sentence_end.max().item()
            if self._encoder is None:
                encoded_text = embedded_text_input
                mask = sentence_mask[:, :cutoff].contiguous()
                encoded_text = encoded_text[:, :cutoff, :]
                tags = tags[:, :cutoff].contiguous()
            else:
                predicate_embeddings = self.predicate_embedding(verb_indicator)
                encoder_inputs = torch.cat(
                    (embedded_text_input, predicate_embeddings), dim=-1)
                encoded_text = self._encoder(encoder_inputs,
                                             mask=sentence_mask.bool())
                # print(verb_indicator)
                predicate_index = (verb_indicator * torch.arange(
                    start=verb_indicator.shape[-1] - 1, end=-1,
                    step=-1).to(mask.device).unsqueeze(0).repeat(
                        batch_size, 1)).argmax(1)
                # print(predicate_index)
                predicate_hidden = encoded_text[
                    torch.arange(batch_size).to(mask.device), predicate_index]
                predicate_exists, _ = verb_indicator.max(1)
                encoded_text = encoded_text[:, :cutoff, :]
                tags = tags[:, :cutoff].contiguous()
                mask = sentence_mask[:, :cutoff].contiguous()
                predicate_exists = predicate_exists.unsqueeze(1).repeat(
                    1, encoded_text.shape[-1])
                predicate_hidden = torch.where(
                    predicate_exists > 0, predicate_hidden,
                    torch.zeros_like(predicate_hidden))
                encoded_text = torch.cat(
                    (encoded_text, predicate_hidden.unsqueeze(1).repeat(
                        1, encoded_text.shape[1], 1)),
                    dim=-1)

        sequence_length = encoded_text.shape[1]
        logits = self.tag_projection_layer(encoded_text)
        # print(mask, logits)
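        # Differentiable QP relaxation via qpth: QPFunction()(Q, p, G, h, A, b) solves
        # min_z 1/2 z^T Q z + p^T z subject to G z <= h and A z = b.  Here Q is a small
        # multiple of the identity, A and b add one equality constraint per token forcing
        # its num_classes entries to sum to one, and the inequality constraints built
        # below are replaced by empty tensors in the call, so non-negativity is only
        # imposed afterwards by the relu before taking logs.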
        if self._lp and sequence_length <= 100:
            eps = 1e-4
            Q = eps * torch.eye(
                sequence_length * self.num_classes,
                sequence_length * self.num_classes).unsqueeze(0).repeat(
                    batch_size, 1, 1).to(logits.device).float()
            p = logits.view(batch_size, -1)
            G = -1 * torch.eye(
                sequence_length * self.num_classes).unsqueeze(0).repeat(
                    batch_size, 1, 1).to(logits.device).float()
            h = torch.zeros_like(p)
            A = torch.arange(sequence_length *
                             self.num_classes).unsqueeze(0).repeat(
                                 sequence_length, 1)
            A2 = torch.arange(sequence_length).unsqueeze(1).repeat(
                1, sequence_length * self.num_classes) * self.num_classes
            A = torch.where((A >= A2) & (A < A2 + self.num_classes),
                            torch.ones_like(A), torch.zeros_like(A))
            A = A.unsqueeze(0).repeat(batch_size, 1,
                                      1).to(logits.device).float()
            b = torch.ones_like(A[:, :, 0])
            probs = QPFunction()(Q, p, torch.Tensor(), torch.Tensor(), A, b)
            probs = probs.view(batch_size, sequence_length, self.num_classes)
            """logits_shape = logits.shape
            logits = torch.where(mask.bool().unsqueeze(-1).repeat(1, 1, logits.shape[-1]), logits, logits-10000)
            max_sequence_length = min([l for l in self.lengths if l >= sequence_length])
            if max_sequence_length > logits_shape[1]:
                logits = torch.cat((logits, torch.zeros((batch_size, max_sequence_length-logits_shape[1], logits_shape[2])).to(logits.device)), dim=1)
            lp_layer = self._layer_list[self.length_map[max_sequence_length]]
            probs, = lp_layer(logits)
            print(torch.isnan(probs).any())
            if max_sequence_length > logits_shape[1]:
                probs = probs[:,:logits_shape[1],:]"""
            logits = (torch.nn.functional.relu(probs) + 1e-4).log()
        if self._lpsmap:
            if self._lpsmap_core_only:
                all_logits = logits
            else:
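                # The extra row of constant 0.5 scores appended at position `sequence_length`
                # provides one auxiliary variable per label; the OrOut/Or factors below use it
                # as a "label occurs somewhere in the sequence" indicator so that R- and C-
                # roles can be tied to the presence of their base roles.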
                all_logits = torch.cat((logits, 0.5 * torch.ones(
                    (batch_size, 1, logits.shape[-1])).to(logits.device)),
                                       dim=1)
            probs = []
            for i in range(batch_size):
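                # Two LP-SparseMAP variants: with `constrain_crf_decoding`, a linear-chain
                # factor reuses the (constraint-masked) CRF transition scores and an AtMostOne
                # budget is added per core role; otherwise each token gets an Xor
                # (exactly-one-label) factor plus the same per-core-role AtMostOne budgets.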
                if self.constrain_crf_decoding:
                    unaries = logits[i, :, :].view(-1).cpu()
                    additionals = self.crf.transitions.view(-1).repeat(
                        sequence_length) + 10000 * (
                            self.crf._constraint_mask[:-2, :-2] -
                            1).view(-1).repeat(sequence_length)
                    start_transitions = self.crf.start_transitions + 10000 * (
                        self.crf._constraint_mask[-2, :-2] - 1)
                    end_transitions = self.crf.end_transitions + 10000 * (
                        self.crf._constraint_mask[-1, :-2] - 1)
                    additionals = torch.cat(
                        (additionals, start_transitions, end_transitions),
                        dim=0).cpu()
                    fg = TorchFactorGraph()
                    x = fg.variable_from(unaries)
                    f = PFactorSequence()

                    f.initialize(
                        [self.num_classes for _ in range(sequence_length)])
                    factor = TorchOtherFactor(f, x, additionals)
                    fg.add(factor)
                    # add budget constraint for each state
                    for state in self._core_roles:
                        vars_state = x[state::self.num_classes]
                        fg.add(AtMostOne(vars_state))
                    # solve SparseMAP
                    fg.solve(max_iter=200)
                    probs.append(
                        unaries.to(logits.device).view(sequence_length,
                                                       self.num_classes))
                else:
                    fg = TorchFactorGraph()
                    x = fg.variable_from(all_logits[i, :, :].cpu())
                    for j in range(sequence_length):
                        fg.add(Xor(x[j, :]))
                    for j in self._core_roles:
                        fg.add(AtMostOne(x[:sequence_length, j]))
                    if not self._lpsmap_core_only:
                        full_sequence = list(range(sequence_length))
                        base_roles = set([
                            second
                            for (_, second) in self._r_roles + self._c_roles
                        ])
                        """for (r_role, base_role) in self._r_roles+self._c_roles:
                            for j in range(sequence_length):
                                fg.add(Imply(x[full_sequence+[j],[base_role]*sequence_length+[r_role]], negated=[True]*(sequence_length+1)))"""
                        for base_role in base_roles:
                            fg.add(OrOut(x[:, base_role]))
                        for (r_role,
                             base_role) in self._r_roles + self._c_roles:
                            fg.add(OrOut(x[:, r_role]))
                            fg.add(
                                Or(x[[sequence_length, sequence_length],
                                     [r_role, base_role]],
                                   negated=[True, False]))
                    max_iter = 100
                    if not self._lpsmap_core_only:
                        max_iter = min(max_iter, 400)
                    elif (not self.training) and not self._val_inference:
                        max_iter = min(max_iter, 200)
                    fg.solve(max_iter=max_iter)
                    probs.append(x.value[:sequence_length, :].contiguous().to(
                        logits.device))
            class_probabilities = torch.stack(probs)
            # class_probabilities = self.lpsmap(logits)
            max_seq_length = 200
            # if self.lpsmap is None:
            """with torch.no_grad():
                # self.lpsmap = LpSparseMap(num_rows=sequence_length, num_cols=self.num_classes, batch_size=batch_size, device=logits.device, constraints=[('xor', ('row', list(range(sequence_length)))), ('budget', ('col', self._core_roles))])
                max_iter = 1000
                constraint_types = ["xor", "budget"]
                constraint_dims = ["row", "col"]
                constraint_sets = [list(range(sequence_length)), self._core_roles]
                class_probabilities = lpsmap(logits, constraint_types, constraint_dims, constraint_sets, max_iter)
                # if max_seq_length > sequence_length:
                #     logits = torch.cat((logits, -9999.*torch.ones((batch_size, max_seq_length-sequence_length, self.num_classes)).to(logits.device)), dim=1)
                # class_probabilities = self.lpsmap.solve(logits, max_iter=max_iter)"""
            # logits = (class_probabilities+1e-4).log()
        else:
            reshaped_log_probs = logits.view(-1, self.num_classes)
            class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
                [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            # print(mask.shape, tags.shape, logits.shape, tags.max(), tags.min())
            if self._lpsmap:
                loss = LpsmapLoss.apply(logits, class_probabilities, tags,
                                        mask)
                # tags_1hot = torch.zeros_like(class_probabilities).scatter_(2, tags.unsqueeze(-1), torch.ones_like(class_probabilities))
                # loss = -(tags_1hot*class_probabilities*mask.unsqueeze(-1).repeat(1, 1, class_probabilities.shape[-1])).sum()
            elif self.constrain_crf_decoding:
                loss = -self.crf(logits, tags, mask)
            else:
                loss = sequence_cross_entropy_with_logits(
                    logits, tags, mask, label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.make_output_human_readable(
                    output_dict).pop("tags")
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                if self.constrain_crf_decoding and not self._lpsmap:
                    batch_conll_predicted_tags = [
                        convert_bio_tags_to_conll_format([
                            self.vocab.get_token_from_index(
                                tag, namespace=self._label_namespace)
                            for tag in seq
                        ]) for (seq, _) in self.crf.viterbi_tags(logits, mask)
                    ]
                else:
                    batch_conll_predicted_tags = [
                        convert_bio_tags_to_conll_format(tags)
                        for tags in batch_bio_predicted_tags
                    ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                # print(batch_bio_gold_tags)
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
            output_dict["gold_tags"] = [x["gold_tags"] for x in metadata]
        return output_dict

    def lpsmap(self, scores: torch.Tensor):
        batch_size, sequence_length, num_classes = scores.shape
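        # A batched ADMM-style solver with one XOR (exactly-one-label) factor per token.
        # C and M map factor copies to variables (the identity here), D holds per-variable
        # degrees, and each iteration projects the factor-local scores onto the probability
        # simplex before recombining the copies and updating the dual variables `lambd`.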
        C_indices = torch.arange(sequence_length * num_classes).to(
            scores.device)
        batch_range = torch.arange(batch_size).to(scores.device)
        C_indices = torch.cat(
            (batch_range.unsqueeze(1).repeat(
                1, C_indices.numel()).view(-1).unsqueeze(0),
             C_indices.repeat(batch_size).unsqueeze(0).repeat(2, 1)),
            dim=0)
        C_values = torch.ones_like(C_indices[0, :]).float()
        C = torch.sparse.FloatTensor(
            C_indices, C_values,
            torch.Size([
                batch_size, sequence_length * num_classes,
                sequence_length * num_classes
            ])).to(scores.device)
        delta = torch.ones_like(scores).view(batch_size, -1, 1)
        M = C.clone()
        D = torch.bmm(
            C, delta)  # [batch_size, num_factors*num_variables_per_factor, 1]
        D_inverse = 1. / D
        D_inverse_per_factor = D_inverse.view(
            -1, num_classes,
            1)  # [batch_size*num_factors, num_variables_per_factor, 1]
        # D_block = torch.diag_embed(D)
        # D_inverse_block = torch.diag_embed(D_inverse)
        # M_tilde = torch.bmm(M, D_inverse).t()
        mu = torch.ones_like(scores).view(batch_size, -1, 1)
        lambd = torch.zeros(
            (batch_size * sequence_length, num_classes, 1)).to(scores.device)
        gamma = 1
        T = 10
        # Reshape D, eta, and C_tilde so that a batch has only the variables for a single factor
        eta = scores.view(batch_size * sequence_length, num_classes, 1)
        D = D.view(batch_size * sequence_length, num_classes, 1)
        D_eta_product = D * eta  # [batch_size*num_factors, num_variables_per_factor]
        # C_tilde = C_tilde.view(batch_size*sequence_length, num_classes, sequence_length*num_classes)
        range_per_factor = torch.arange(num_classes).to(
            scores.device).unsqueeze(0).repeat(batch_size * sequence_length,
                                               1).float()
        eps = 1e-6
        for t in range(T):
            Cmu = torch.bmm(C, mu).view(
                -1, num_classes,
                1)  # [batch_size*num_factors, num_variables_per_factor, 1]
            eta_tilde = (
                D_eta_product - lambd + gamma * D_inverse_per_factor * Cmu) / (
                    gamma + 1
                )  # [batch_size*num_factors, num_variables_per_factor, 1]
            # Procedure for solving XOR QP according to Duchi et al. 2008
            # TODO: adjust this for when Mf_tilde \neq I (i.e. when there are more factors)
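            # Sort each factor's scores, form the running thresholds (cumsum - 1) / k, pick
            # the largest k whose k-th sorted score still exceeds its threshold, and clip at
            # that threshold: the result is the Euclidean projection of eta_tilde onto the
            # probability simplex.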
            eta_tilde_sorted, eta_indices = eta_tilde.view(
                batch_size * sequence_length, -1
            ).sort(
                dim=-1, descending=True
            )  # [batch_size*num_factors, num_variables_per_factor], [batch_size*num_factors, num_variables_per_factor]
            eta_tilde_cumsum = torch.cumsum(
                eta_tilde_sorted,
                dim=1)  # [batch_size*num_factors, num_variables_per_factor]
            eta_tilde_cumsum = (eta_tilde_cumsum - 1) / (
                range_per_factor + 1
            )  # [batch_size*num_factors, num_variables_per_factor]
            rho = ((eta_tilde_sorted - eta_tilde_cumsum > 0).float() *
                   (1 + range_per_factor)).argmax(
                       1)  # [batch_size*num_factors]
            theta = torch.gather(
                eta_tilde_cumsum, 1,
                rho.unsqueeze(-1)).repeat(1, num_classes).unsqueeze(
                    -1)  #[batch_size*num_factors, num_variables_per_factor, 1]
            p = torch.nn.functional.relu(
                eta_tilde - theta
            )  # [batch_size*num_factors, num_variables_per_factor = num_assignments_per_factor, 1]
            Mp = torch.bmm(M, p.view(
                batch_size, -1,
                1))  # [batch_size, num_factors*num_variables_per_factor, 1]
            Dinv_Mp = D_inverse * Mp  # [batch_size, num_factors*num_variables_per_factor, 1]
            Dinv2_Mp = D_inverse * Dinv_Mp  # [batch_size, num_factors*num_variables_per_factor, 1]
            mu_new = torch.bmm(C.transpose(1, 2),
                               Dinv2_Mp)  # [batch_size, num_variables, 1]
            lambd_diff = D_inverse * torch.bmm(
                C, mu_new
            ) - Dinv_Mp  # [batch_size, num_factors*num_variables_per_factor, 1]
            lambd += gamma * lambd_diff.view(
                -1, num_classes,
                1)  # [batch_size*num_factors, num_variables_per_factor]
            mu_diff_norm = torch.norm((mu_new - mu).squeeze(-1),
                                      dim=1)  # [batch_size]
            lambd_diff_norm = torch.norm(lambd_diff.squeeze(-1),
                                         dim=1)  # [batch_size]
            if mu_diff_norm.max().item() < eps and lambd_diff_norm.max().item(
            ) < eps:
                break
            mu = mu_new
        mu = mu.view(batch_size, sequence_length, num_classes)
        return mu

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        `"tags"` key to the dictionary with the result.

        NOTE: First, we decode a BIO sequence on top of the wordpieces. This is important; viterbi
        decoding produces low quality output if you decode on top of word representations directly,
        because the model gets confused by the 'missing' positions (which is sensible as it is trained
        to perform tagging on wordpieces, not words).

        Secondly, it's important that the indices we use to recover words from the wordpieces are the
        start_offsets (i.e. offsets which correspond to using the first wordpiece of words which are
        tokenized into multiple wordpieces) as otherwise, we might get an ill-formed BIO sequence
        when we select out the word tags from the wordpiece tags. This happens in the case that a word
        is split into multiple word pieces, and then we take the last tag of the word, which might
        correspond to, e.g., I-V, which would not be allowed as it is not preceded by a B tag.
        """
        all_predictions = output_dict["class_probabilities"]
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        wordpiece_tags = []
        word_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        # **************** Different ********************
        # We add in the offsets here so we can compute the un-wordpieced tags.
        for predictions, length, offsets in zip(
                predictions_list, sequence_lengths,
                output_dict["wordpiece_offsets"]):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length],
                transition_matrix,
                allowed_start_transitions=start_transitions)
            tags = [
                self.vocab.get_token_from_index(
                    x, namespace=self._label_namespace)
                for x in max_likelihood_sequence
            ]

            wordpiece_tags.append(tags)
            if isinstance(self.bert_model,
                          PretrainedTransformerMismatchedEmbedder):
                word_tags.append(tags)
            else:
                word_tags.append([tags[i] for i in offsets])
            # print(word_tags)
        output_dict["wordpiece_tags"] = wordpiece_tags
        output_dict["tags"] = word_tags
        return output_dict

    def get_metrics(self, reset: bool = False):
        if self.ignore_span_metric:
            # Return an empty dictionary if ignoring the
            # span metric
            return {}

        else:
            metric_dict = self.span_metric.get_metric(reset=reset)

            # This can be a lot of metrics, as there are 3 per class.
            # we only really care about the overall metrics, so we filter for them here.
            return {x: y for x, y in metric_dict.items() if "overall" in x}

    def get_viterbi_pairwise_potentials(self):
        """
        Generate a matrix of pairwise transition potentials for the BIO labels.
        The only constraint implemented here is that I-XXX labels must be preceded
        by either an identical I-XXX tag or a B-XXX tag. In order to achieve this
        constraint, pairs of labels which do not satisfy this constraint have a
        pairwise potential of -inf.

        # Returns

        transition_matrix : `torch.Tensor`
            A `(num_labels, num_labels)` matrix of pairwise potentials.
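
        # Example

        With a hypothetical three-label vocabulary `{0: "O", 1: "B-ARG0", 2: "I-ARG0"}`, every
        entry of the returned matrix is `0` except `transition_matrix[0, 2] == -inf`: the
        transition `O -> I-ARG0` is forbidden, while `B-ARG0 -> I-ARG0` and `I-ARG0 -> I-ARG0`
        remain allowed.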
        """
        all_labels = self.vocab.get_index_to_token_vocabulary(
            self._label_namespace)
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])

        for i, previous_label in all_labels.items():
            for j, label in all_labels.items():
                # I labels can only be preceded by themselves or
                # their corresponding B tag.
                if i != j and label[0] == "I" and previous_label != "B" + label[1:]:
                    transition_matrix[i, j] = float("-inf")
        return transition_matrix

    def get_start_transitions(self):
        """
        In the BIO sequence, we cannot start the sequence with an I-XXX tag.
        This transition sequence is passed to viterbi_decode to specify this constraint.

        # Returns

        start_transitions : `torch.Tensor`
            The pairwise potentials between a START token and
            the first token of the sequence.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary(
            self._label_namespace)
        num_labels = len(all_labels)

        start_transitions = torch.zeros(num_labels)

        for i, label in all_labels.items():
            if label[0] == "I":
                start_transitions[i] = float("-inf")

        return start_transitions

    default_predictor = "semantic_role_labeling"
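
# A minimal, self-contained sketch of the simplex projection used inside lpsmap() above
# (the sort / cumulative-sum / threshold step following Duchi et al., 2008). The function
# and variable names below are illustrative and are not part of the model code.
import torch


def project_onto_simplex(v: torch.Tensor) -> torch.Tensor:
    """Project each row of `v` onto the probability simplex."""
    sorted_v, _ = v.sort(dim=-1, descending=True)
    cumsum = sorted_v.cumsum(dim=-1)
    k = torch.arange(1, v.shape[-1] + 1, device=v.device, dtype=v.dtype)
    thresholds = (cumsum - 1.0) / k
    # Largest k whose k-th largest entry still exceeds its threshold.
    rho = (((sorted_v - thresholds) > 0).to(v.dtype) * k).argmax(dim=-1, keepdim=True)
    theta = thresholds.gather(-1, rho)
    return torch.clamp(v - theta, min=0.0)


# Each row of the result is non-negative and sums to one, mirroring the per-token
# Xor factor used in the LP-SparseMAP decoding above.
probs = project_onto_simplex(torch.tensor([[2.0, 1.0, -1.0], [0.1, 0.1, 0.1]]))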
Exemple #12
0
class SrlBert(Model):
    """

    A BERT based model [Simple BERT Models for Relation Extraction and Semantic Role Labeling (Shi et al, 2019)]
    (https://arxiv.org/abs/1904.05255) with some modifications (no additional parameters apart from a linear
    classification layer), which is currently the state-of-the-art single model for English PropBank SRL
    (Newswire sentences).

    # Parameters

    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    bert_model : `Union[str, BertModel]`, required.
        A string describing the BERT model to load or an already constructed BertModel.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    label_smoothing : `float`, optional (default = `None`)
        The amount of label smoothing to apply when computing the cross entropy loss.
    ignore_span_metric : `bool`, optional (default = `False`)
        Whether to skip the span-based evaluation metric, which is irrelevant when predicting BIO
        for Open Information Extraction.
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
        which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not used.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric
        initializer(self)

    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            token_type_ids=verb_indicator,
            attention_mask=mask,
        )

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        batch_size, sequence_length, _ = embedded_text_input.size()
        logits = self.tag_projection_layer(embedded_text_input)

        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.make_output_human_readable(
                    output_dict).pop("tags")
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        `"tags"` key to the dictionary with the result.

        NOTE: First, we decode a BIO sequence on top of the wordpieces. This is important; viterbi
        decoding produces low quality output if you decode on top of word representations directly,
        because the model gets confused by the 'missing' positions (which is sensible as it is trained
        to perform tagging on wordpieces, not words).

        Secondly, it's important that the indices we use to recover words from the wordpieces are the
        start_offsets (i.e. offsets which correspond to using the first wordpiece of words which are
        tokenized into multiple wordpieces) as otherwise, we might get an ill-formed BIO sequence
        when we select out the word tags from the wordpiece tags. This happens in the case that a word
        is split into multiple word pieces, and then we take the last tag of the word, which might
        correspond to, e.g., I-V, which would not be allowed as it is not preceded by a B tag.
        """
        all_predictions = output_dict["class_probabilities"]
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        wordpiece_tags = []
        word_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        # **************** Different ********************
        # We add in the offsets here so we can compute the un-wordpieced tags.
        for predictions, length, offsets in zip(
                predictions_list, sequence_lengths,
                output_dict["wordpiece_offsets"]):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length],
                transition_matrix,
                allowed_start_transitions=start_transitions)
            tags = [
                self.vocab.get_token_from_index(x, namespace="labels")
                for x in max_likelihood_sequence
            ]

            wordpiece_tags.append(tags)
            word_tags.append([tags[i] for i in offsets])
        output_dict["wordpiece_tags"] = wordpiece_tags
        output_dict["tags"] = word_tags
        return output_dict

    def get_metrics(self, reset: bool = False):
        if self.ignore_span_metric:
            # Return an empty dictionary if ignoring the
            # span metric
            return {}

        else:
            metric_dict = self.span_metric.get_metric(reset=reset)

            # This can be a lot of metrics, as there are 3 per class.
            # we only really care about the overall metrics, so we filter for them here.
            return {x: y for x, y in metric_dict.items() if "overall" in x}

    def get_viterbi_pairwise_potentials(self):
        """
        Generate a matrix of pairwise transition potentials for the BIO labels.
        The only constraint implemented here is that I-XXX labels must be preceded
        by either an identical I-XXX tag or a B-XXX tag. In order to achieve this
        constraint, pairs of labels which do not satisfy this constraint have a
        pairwise potential of -inf.

        # Returns

        transition_matrix : `torch.Tensor`
            A `(num_labels, num_labels)` matrix of pairwise potentials.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])

        for i, previous_label in all_labels.items():
            for j, label in all_labels.items():
                # I labels can only be preceded by themselves or
                # their corresponding B tag.
                if i != j and label[0] == "I" and previous_label != "B" + label[1:]:
                    transition_matrix[i, j] = float("-inf")
        return transition_matrix

    def get_start_transitions(self):
        """
        In the BIO sequence, we cannot start the sequence with an I-XXX tag.
        This transition sequence is passed to viterbi_decode to specify this constraint.

        # Returns

        start_transitions : `torch.Tensor`
            The pairwise potentials between a START token and
            the first token of the sequence.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)

        start_transitions = torch.zeros(num_labels)

        for i, label in all_labels.items():
            if label[0] == "I":
                start_transitions[i] = float("-inf")

        return start_transitions

    default_predictor = "semantic_role_labeling"
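
# A hedged usage sketch for the SrlBert example above. The label inventory and the
# pretrained model name ("bert-base-uncased") are illustrative assumptions rather than
# values taken from this code, and constructing the model this way downloads BERT weights.
from allennlp.data import Vocabulary

vocab = Vocabulary()
vocab.add_tokens_to_namespace(
    ["O", "B-V", "B-ARG0", "I-ARG0", "B-ARG1", "I-ARG1"], namespace="labels")
srl_model = SrlBert(vocab=vocab, bert_model="bert-base-uncased")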
Exemple #13
0
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        mention_feedforward: FeedForward,
        context_layer: Seq2SeqEncoder = None,
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        max_span_width: int = 30,
        feature_size: int = 10,
        spans_per_word: float = 100,
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("span_labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric

        self._mention_feedforward = TimeDistributed(mention_feedforward)
        self._mention_scorer = TimeDistributed(
            torch.nn.Linear(mention_feedforward.get_output_dim(), 1))

        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=self.bert_model.config.hidden_size)
        self.span_representation_dim = self._attentive_span_extractor.get_output_dim(
        )
        self._context_layer = context_layer
        if context_layer is not None:
            self._endpoint_span_extractor = EndpointSpanExtractor(
                context_layer.get_output_dim(),
                combination="x,y",
                num_width_embeddings=max_span_width,
                span_width_embedding_dim=feature_size,
                bucket_widths=False,
            )
            self.span_representation_dim = self._endpoint_span_extractor.get_output_dim(
            )

        self.hidden_layer = torch.nn.Sequential(
            torch.nn.Linear(self.span_representation_dim +
                            self.bert_model.config.hidden_size,
                            self.span_representation_dim,
                            bias=False), torch.nn.ReLU())
        self.output_layer = torch.nn.Linear(self.span_representation_dim,
                                            self.num_classes - 1,
                                            bias=False)

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self._bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        initializer(self)
Exemple #14
0
class SrlE2e(Model):
    """

    # Parameters

    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    bert_model : `Union[str, BertModel]`, required.
        A string describing the BERT model to load or an already constructed BertModel.
    mention_feedforward : `FeedForward`, required.
        The feedforward network applied to candidate span representations before they are scored
        for pruning.
    context_layer : `Seq2SeqEncoder`, optional (default = `None`)
        If provided, an encoder applied on top of the BERT embeddings; span representations are
        then built with an endpoint span extractor over its output.
    max_span_width : `int`, optional (default = `30`)
        The maximum width of candidate spans, used for the span width embeddings.
    feature_size : `int`, optional (default = `10`)
        The dimension of the span width embedding.
    spans_per_word : `float`, optional (default = `100`)
        Controls how many candidate spans are kept after pruning: `floor(spans_per_word * sequence_length)`.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    label_smoothing : `float`, optional (default = `None`)
        The amount of label smoothing to apply when computing the cross entropy loss.
    ignore_span_metric : `bool`, optional (default = `False`)
        Whether to skip the span-based evaluation metric, which is irrelevant when predicting BIO
        for Open Information Extraction.
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
        which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not used.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        mention_feedforward: FeedForward,
        context_layer: Seq2SeqEncoder = None,
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        max_span_width: int = 30,
        feature_size: int = 10,
        spans_per_word: float = 100,
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("span_labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric

        self._mention_feedforward = TimeDistributed(mention_feedforward)
        self._mention_scorer = TimeDistributed(
            torch.nn.Linear(mention_feedforward.get_output_dim(), 1))

        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=self.bert_model.config.hidden_size)
        self.span_representation_dim = self._attentive_span_extractor.get_output_dim(
        )
        self._context_layer = context_layer
        if context_layer is not None:
            self._endpoint_span_extractor = EndpointSpanExtractor(
                context_layer.get_output_dim(),
                combination="x,y",
                num_width_embeddings=max_span_width,
                span_width_embedding_dim=feature_size,
                bucket_widths=False,
            )
            self.span_representation_dim = self._endpoint_span_extractor.get_output_dim(
            )

        self.hidden_layer = torch.nn.Sequential(
            torch.nn.Linear(self.span_representation_dim +
                            self.bert_model.config.hidden_size,
                            self.span_representation_dim,
                            bias=False), torch.nn.ReLU())
        self.output_layer = torch.nn.Linear(self.span_representation_dim,
                                            self.num_classes - 1,
                                            bias=False)

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self._bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        initializer(self)

    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        sentence_end: torch.LongTensor,
        spans: torch.LongTensor,
        span_labels: torch.LongTensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        spans : `torch.LongTensor`, required.
            A tensor of shape `(batch_size, num_spans, 2)` giving the inclusive (start, end)
            wordpiece indices of candidate spans; padded spans are `(-1, -1)`.
        span_labels : `torch.LongTensor`, required.
            A tensor of shape `(batch_size, num_spans)` containing, for each candidate span, the
            index of its gold label in the `"span_labels"` namespace.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
        start = time.time()
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            # token_type_ids=verb_indicator,
            attention_mask=mask,
        )

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        batch_size, sequence_length, _ = embedded_text_input.size()
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            bert_embeddings, spans)

        if self._context_layer is not None:
            contextualized_embeddings = self._context_layer(
                embedded_text_input, mask)
            # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
            endpoint_span_embeddings = self._endpoint_span_extractor(
                contextualized_embeddings, spans)

            # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
            # span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)
            span_embeddings = endpoint_span_embeddings
        else:
            span_embeddings = attended_span_embeddings

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * sequence_length))
        num_spans = spans.shape[1]
        num_spans_to_keep = min(num_spans_to_keep, num_spans)

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)).squeeze(-1)
        # Shape: (batch_size, num_spans) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep)
        verb_index = verb_indicator.argmax(1).unsqueeze(1).unsqueeze(2).repeat(
            1, 1, embedded_text_input.shape[-1])
        verb_embeddings = torch.gather(embedded_text_input, 1, verb_index)
        assert len(
            verb_embeddings.shape) == 3 and verb_embeddings.shape[1] == 1
        verb_embeddings = verb_embeddings.squeeze(1)
        # print(verb_indicator.sum(1, keepdim=True) > 0)
        verb_embeddings = torch.where(
            (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                1, verb_embeddings.shape[-1]), verb_embeddings,
            torch.zeros_like(verb_embeddings))
        # print(verb_embeddings)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, spans.shape[1])
        span_embeddings = util.batched_index_select(span_embeddings,
                                                    top_span_indices,
                                                    flat_top_span_indices)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)
        top_span_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices).squeeze(-1)
        concatenated_span_embeddings = torch.cat(
            (span_embeddings, verb_embeddings.unsqueeze(1).repeat(
                1, span_embeddings.shape[1], 1)),
            dim=2)
        # print(concatenated_span_embeddings[:,:,:])
        hidden = self.hidden_layer(concatenated_span_embeddings)
        # print(hidden[1,:,:])
        # print(top_span_indices)
        # print([[span_mention_scores[i,top_span_indices[i,j]].item() for j in range(top_span_indices.shape[1])] for i in range(top_span_labels.shape[0])])
        # print(top_span_mention_scores, self.vocab.get_token_index("O", namespace="span_labels"))
        predictions = self.output_layer(hidden)
        # predictions += top_span_mention_scores.unsqueeze(-1).repeat(1, 1, self.num_classes-1)
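        # The output layer scores only the num_classes - 1 non-null labels; a constant zero
        # logit is prepended at index 0 (presumably the null "O" span label), so a span only
        # receives a real label when some label score exceeds zero.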
        predictions = torch.cat(
            (torch.zeros_like(predictions[:, :, :1]), predictions), dim=-1)
        # print(top_span_mention_scores.unsqueeze(-1).repeat(1, 1, self.num_classes-1))

        output_dict = {}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            loss = (self._ce_loss(predictions.view(-1, predictions.shape[-1]),
                                  top_span_labels.view(-1)) *
                    top_span_mask.float().view(-1)
                    ).sum() / top_span_mask.float().sum()
            # print(top_span_labels)
            # print(predictions.argmax(-1))
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.get_tags(
                    top_spans, predictions, mask.shape[1], top_span_mask,
                    output_dict)
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                # print('G', batch_bio_gold_tags)
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
        return output_dict

    def get_tags(self, spans, logits, sequence_length, span_mask, output_dict):
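        # Greedily convert span-level predictions into word-level BIO tags: every span
        # predicted as a non-"O" label is written out as B-/I- tags, but only if none of its
        # positions have already been claimed by a previously written span.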
        predicted_tag_ids = logits.argmax(2)
        predicted_tags = []
        for i in range(spans.shape[0]):
            sequence = ["O" for _ in range(sequence_length)]
            for j in range(spans.shape[1]):
                if span_mask[i, j].item() == 0:
                    continue
                tag = predicted_tag_ids[i, j].item()
                if tag != self.vocab.get_token_index("O",
                                                     namespace="span_labels"):
                    start = spans[i, j, 0].item()
                    end = spans[i, j, 1].item()
                    if all([el == "O" for el in sequence[start:end + 1]]):
                        tag = self.vocab.get_token_from_index(
                            tag, namespace="span_labels")
                        sequence[start] = "B-" + tag
                        for index in range(start + 1, end + 1):
                            sequence[index] = "I-" + tag
            predicted_tags.append(
                [sequence[ind] for ind in output_dict["wordpiece_offsets"][i]])
        return predicted_tags

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        `"tags"` key to the dictionary with the result.

        NOTE: First, we decode a BIO sequence on top of the wordpieces. This is important; viterbi
        decoding produces low quality output if you decode on top of word representations directly,
        because the model gets confused by the 'missing' positions (which is sensible as it is trained
        to perform tagging on wordpieces, not words).

        Secondly, it's important that the indices we use to recover words from the wordpieces are the
        start_offsets (i.e. offsets which correspond to using the first wordpiece of words which are
        tokenized into multiple wordpieces) as otherwise, we might get an ill-formed BIO sequence
        when we select out the word tags from the wordpiece tags. This happens in the case that a word
        is split into multiple word pieces, and then we take the last tag of the word, which might
        correspond to, e.g., I-V, which would not be allowed as it is not preceded by a B tag.
        """
        all_predictions = output_dict["class_probabilities"]
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        wordpiece_tags = []
        word_tags = []
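        # Transition potentials enforcing valid BIO sequences during decoding;
        # they depend only on the label vocabulary.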
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        # NOTE: unlike the base SRL model, we also iterate over the stored wordpiece
        # offsets so that we can recover word-level tags from the wordpiece tags.
        for predictions, length, offsets in zip(
                predictions_list, sequence_lengths,
                output_dict["wordpiece_offsets"]):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length],
                transition_matrix,
                allowed_start_transitions=start_transitions)
            tags = [
                self.vocab.get_token_from_index(x, namespace="labels")
                for x in max_likelihood_sequence
            ]

            wordpiece_tags.append(tags)
            word_tags.append([tags[i] for i in offsets])
        output_dict["wordpiece_tags"] = wordpiece_tags
        output_dict["tags"] = word_tags
        return output_dict

    def get_metrics(self, reset: bool = False):
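        """
        Return the overall precision/recall/F1 computed by the span metric,
        or an empty dictionary when `ignore_span_metric` is set.
        """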
        if self.ignore_span_metric:
            # Return an empty dictionary if ignoring the
            # span metric
            return {}

        else:
            metric_dict = self.span_metric.get_metric(reset=reset)

            # This can be a lot of metrics, as there are 3 per class.
            # We only really care about the overall metrics, so we filter for them here.
            return {x: y for x, y in metric_dict.items() if "overall" in x}

    def get_viterbi_pairwise_potentials(self):
        """
        Generate a matrix of pairwise transition potentials for the BIO labels.
        The only constraint implemented here is that I-XXX labels must be preceded
        by either an identical I-XXX tag or a B-XXX tag. In order to achieve this
        constraint, pairs of labels which do not satisfy this constraint have a
        pairwise potential of -inf.

        # Returns

        transition_matrix : `torch.Tensor`
            A `(num_labels, num_labels)` matrix of pairwise potentials.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])

        for i, previous_label in all_labels.items():
            for j, label in all_labels.items():
                # I labels can only be preceded by themselves or
                # their corresponding B tag.
                if (i != j and label[0] == "I"
                        and previous_label != "B" + label[1:]):
                    transition_matrix[i, j] = float("-inf")
        return transition_matrix

    def get_start_transitions(self):
        """
        In a BIO sequence, we cannot start with an I-XXX tag, so these start-transition
        potentials are passed to viterbi_decode to enforce that constraint.

        # Returns

        start_transitions : `torch.Tensor`
            The pairwise potentials between a START token and
            the first token of the sequence.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)

        start_transitions = torch.zeros(num_labels)

        for i, label in all_labels.items():
            if label[0] == "I":
                start_transitions[i] = float("-inf")

        return start_transitions

    default_predictor = "semantic_role_labeling"
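

# --- Illustrative sketch (not part of the original model) ---------------------
# A minimal, self-contained demonstration of how the BIO constraints built by
# get_viterbi_pairwise_potentials() and get_start_transitions() interact with
# allennlp.nn.util.viterbi_decode. The toy label set and random logits below are
# hypothetical and exist only for illustration.
if __name__ == "__main__":
    import torch
    from allennlp.nn.util import viterbi_decode

    toy_labels = ["O", "B-ARG0", "I-ARG0", "B-ARG1", "I-ARG1"]
    num_labels = len(toy_labels)

    # Same constraint as get_viterbi_pairwise_potentials(): I-XXX may only
    # follow B-XXX or I-XXX of the same type.
    transition_matrix = torch.zeros(num_labels, num_labels)
    for i, previous_label in enumerate(toy_labels):
        for j, label in enumerate(toy_labels):
            if (i != j and label[0] == "I"
                    and previous_label != "B" + label[1:]):
                transition_matrix[i, j] = float("-inf")

    # Same constraint as get_start_transitions(): a sequence may not begin
    # with an I-XXX tag.
    start_transitions = torch.zeros(num_labels)
    for i, label in enumerate(toy_labels):
        if label[0] == "I":
            start_transitions[i] = float("-inf")

    logits = torch.randn(6, num_labels)  # 6 wordpieces, random tag scores.
    best_path, _ = viterbi_decode(logits, transition_matrix,
                                  allowed_start_transitions=start_transitions)
    # The decoded path is always a valid BIO sequence, e.g. it never starts with I-.
    print([toy_labels[index] for index in best_path])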