    def test_widths_are_embedded_correctly(self):
        input_dim = 7
        max_span_width = 5
        span_width_embedding_dim = 3
        output_dim = input_dim + span_width_embedding_dim
        extractor = SelfAttentiveSpanExtractor(
            input_dim=input_dim,
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=span_width_embedding_dim,
        )
        assert extractor.get_output_dim() == output_dim
        assert extractor.get_input_dim() == input_dim

        sequence_tensor = torch.randn([2, max_span_width, input_dim])
        indices = torch.LongTensor([[[1, 3], [0, 4], [0, 0]],
                                    [[0, 2], [1, 4],
                                     [2, 2]]])  # smaller span tests masking.
        span_representations = extractor(sequence_tensor, indices)
        assert list(span_representations.size()) == [2, 3, output_dim]

        width_embeddings = extractor._span_width_embedding.weight.data.numpy()
        widths_minus_one = indices[..., 1] - indices[..., 0]
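        # Spans are inclusive, so end - start is the span width minus one; this is the index the
        # extractor uses to look up its width embedding, which is why
        # num_width_embeddings=max_span_width rows suffice.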
        for element in range(indices.size(0)):
            for span in range(indices.size(1)):
                width = widths_minus_one[element, span].item()
                width_embedding = span_representations[element, span,
                                                       input_dim:]
                numpy.testing.assert_array_almost_equal(
                    width_embedding.data.numpy(), width_embeddings[width])
Example 2
    def test_attention_is_normalised_correctly(self):
        input_dim = 7
        sequence_tensor = torch.randn([2, 5, input_dim])
        extractor = SelfAttentiveSpanExtractor(input_dim=input_dim)
        assert extractor.get_output_dim() == input_dim
        assert extractor.get_input_dim() == input_dim

        # In order to test the attention, we'll make the weight which computes the logits
        # zero, so the attention distribution is uniform over the sentence. This lets
        # us check that the computed spans are just the averages of their representations.
        extractor._global_attention._module.weight.data.fill_(0.0)
        extractor._global_attention._module.bias.data.fill_(0.0)

        indices = torch.LongTensor(
            [[[1, 3], [2, 4]], [[0, 2], [3, 4]]]
        )  # smaller span tests masking.
        span_representations = extractor(sequence_tensor, indices)
        assert list(span_representations.size()) == [2, 2, input_dim]

        # First element in the batch.
        batch_element = 0
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 1:4, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span.
        mean_embeddings = sequence_tensor[batch_element, 2:5, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), mean_embeddings.data.numpy())
        # Now the second element in the batch.
        batch_element = 1
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 0:3, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span.
        mean_embeddings = sequence_tensor[batch_element, 3:5, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), mean_embeddings.data.numpy())

        # Now test the case in which we have some masked spans in our indices.
        indices_mask = torch.BoolTensor([[True, True], [True, False]])
        span_representations = extractor(sequence_tensor, indices, span_indices_mask=indices_mask)

        # First element in the batch.
        batch_element = 0
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 1:4, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span.
        mean_embeddings = sequence_tensor[batch_element, 2:5, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), mean_embeddings.data.numpy())
        # Now the second element in the batch.
        batch_element = 1
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 0:3, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span was masked, so should be completely zero.
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), numpy.zeros([input_dim]))
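        # Why the means: zeroing the attention weight and bias makes every logit 0, so the
        # softmax over each span is uniform and the attention-weighted sum collapses to a plain
        # mean, e.g. torch.softmax(torch.zeros(3), dim=0) gives tensor([0.3333, 0.3333, 0.3333]);
        # hence the `.mean(0)` comparisons above.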
Example 4
class DyGIE(Model):
    """
    TODO(dwadden) document me.

    Parameters
    ----------
    vocab : ``Vocabulary``
    embedder : ``TextFieldEmbedder``
        Used to embed the ``text`` ``TextField`` we get as input to the model.
    context_layer : ``Seq2SeqEncoder``
        This layer incorporates contextual information for each word in the document.
    feature_size: ``int``
        The embedding size for all the embedded features, such as distances or span widths.
    modules: ``TODO(dwadden)``
        A nested dictionary specifying parameters to be passed on to initialize submodules.
    max_span_width: ``int``
        The maximum width of candidate spans.
    max_trigger_span_width: ``int``
        The maximum width of candidate trigger spans.
    target_task: ``str``
        The task used to make early stopping decisions.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    module_initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the individual modules.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    display_metrics: ``List[str]``
        A list of the metrics that should be printed out during model training.
    """
    def __init__(
            self,
            vocab: Vocabulary,
            embedder: TextFieldEmbedder,
            context_layer: Seq2SeqEncoder,
            modules,  # TODO(dwadden) Add type.
            feature_size: int,
            max_span_width: int,
            max_trigger_span_width: int,
            target_task: str,
            feedforward_params: Dict[str, Union[int, float]],
            loss_weights: Dict[str, float],
            lexical_dropout: float = 0.2,
            use_attentive_span_extractor: bool = False,
            initializer: InitializerApplicator = InitializerApplicator(),
            module_initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None,
            display_metrics: List[str] = None) -> None:
        super(DyGIE, self).__init__(vocab, regularizer)

        ####################

        # Create span extractor.
        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
        self._endpoint_trigger_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_trigger_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
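        # With combination="x,y", each endpoint extractor concatenates the start and end token
        # encodings and appends a learned width embedding, so its get_output_dim() is
        # 2 * context_layer.get_output_dim() + feature_size.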

        ####################
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x

        if use_attentive_span_extractor:
            self._attentive_span_extractor = SelfAttentiveSpanExtractor(
                input_dim=context_layer.get_output_dim())
        else:
            self._attentive_span_extractor = None

        # Set parameters.
        self._embedder = embedder
        self._context_layer = context_layer
        self._loss_weights = loss_weights
        self._max_span_width = max_span_width
        self._max_trigger_span_width = max_trigger_span_width
        self._display_metrics = self._get_display_metrics(target_task)

        trigger_emb_dim = self._endpoint_trigger_span_extractor.get_output_dim()
        span_emb_dim = self._endpoint_span_extractor.get_output_dim()

        if self._attentive_span_extractor is not None:
            span_emb_dim += self._attentive_span_extractor.get_output_dim()
            trigger_emb_dim += self._attentive_span_extractor.get_output_dim()

        ####################

        # Create submodules.

        modules = Params(modules)

        # Helper function to create feedforward networks.
        def make_feedforward(input_dim):
            return FeedForward(input_dim=input_dim,
                               num_layers=feedforward_params["num_layers"],
                               hidden_dims=feedforward_params["hidden_dims"],
                               activations=torch.nn.ReLU(),
                               dropout=feedforward_params["dropout"])

        # Submodules

        self._ner = NERTagger.from_params(vocab=vocab,
                                          make_feedforward=make_feedforward,
                                          span_emb_dim=span_emb_dim,
                                          feature_size=feature_size,
                                          params=modules.pop("ner"))

        self._coref = CorefResolver.from_params(
            vocab=vocab,
            make_feedforward=make_feedforward,
            span_emb_dim=span_emb_dim,
            feature_size=feature_size,
            params=modules.pop("coref"))

        self._relation = RelationExtractor.from_params(
            vocab=vocab,
            make_feedforward=make_feedforward,
            span_emb_dim=span_emb_dim,
            feature_size=feature_size,
            params=modules.pop("relation"))

        self._events = EventExtractor.from_params(
            vocab=vocab,
            make_feedforward=make_feedforward,
            text_emb_dim=self._embedder.get_output_dim(),
            trigger_emb_dim=trigger_emb_dim,
            span_emb_dim=span_emb_dim,
            feature_size=feature_size,
            params=modules.pop("events"))

        ####################

        # Initialize text embedder and all submodules
        for module in [self._ner, self._coref, self._relation, self._events]:
            module_initializer(module)

        initializer(self)

    @staticmethod
    def _get_display_metrics(target_task):
        """
        The `target_task` is the name of the task used to make early stopping decisions. Show metrics
        related to this task.
        """
        lookup = {
            "ner": [
                f"MEAN__{name}"
                for name in ["ner_precision", "ner_recall", "ner_f1"]
            ],
            "relation": [
                f"MEAN__{name}" for name in
                ["relation_precision", "relation_recall", "relation_f1"]
            ],
            "coref": [
                "coref_precision", "coref_recall", "coref_f1",
                "coref_mention_recall"
            ],
            "events":
            [f"MEAN__{name}" for name in ["trig_class_f1", "arg_class_f1"]]
        }
        if target_task not in lookup:
            raise ValueError(
                f"Invalied value {target_task} has been given as the target task."
            )
        return lookup[target_task]

    @staticmethod
    def _debatch(x):
        # TODO(dwadden) Get rid of this when I find a better way to do it.
        return x if x is None else x.squeeze(0)

    @overrides
    def forward(self,
                text,
                trigger_spans,
                spans,
                metadata,
                ner_labels=None,
                coref_labels=None,
                relation_labels=None,
                trigger_labels=None,
                argument_labels=None):
        """
        TODO(dwadden) change this.
        """
        # In AllenNLP, AdjacencyFields are passed in as floats. This fixes it.
        if relation_labels is not None:
            relation_labels = relation_labels.long()
        if argument_labels is not None:
            argument_labels = argument_labels.long()

        # TODO(dwadden) Multi-document minibatching isn't supported yet. For now, get rid of the
        # extra dimension in the input tensors. Will return to this once the model runs.
        if len(metadata) > 1:
            raise NotImplementedError(
                "Multi-document minibatching not yet supported.")

        metadata = metadata[0]
        spans = self._debatch(spans)  # (n_sents, max_n_spans, 2)
        trigger_spans = self._debatch(
            trigger_spans)  # (n_sents, max_n_spans, 2)
        ner_labels = self._debatch(ner_labels)  # (n_sents, max_n_spans)
        coref_labels = self._debatch(coref_labels)  #  (n_sents, max_n_spans)
        relation_labels = self._debatch(
            relation_labels)  # (n_sents, max_n_spans, max_n_spans)
        trigger_labels = self._debatch(trigger_labels)  # TODO(dwadden)
        argument_labels = self._debatch(argument_labels)  # TODO(dwadden)

        # Encode using BERT, then debatch.
        # Since the data are batched, we use `num_wrapping_dims=1` to unwrap the document dimension.
        # (1, n_sents, max_sentence_length, embedding_dim)

        # TODO(dwadden) Deal with the case where the input is longer than 512.
        text_embeddings = self._embedder(text, num_wrapping_dims=1)
        # (n_sents, max_n_wordpieces, embedding_dim)
        text_embeddings = self._debatch(text_embeddings)
        # apply lexical dropout
        text_embeddings = self._lexical_dropout(text_embeddings)

        # (n_sents, max_sentence_length)
        text_mask = self._debatch(
            util.get_text_field_mask(text, num_wrapping_dims=1).float())
        sentence_lengths = text_mask.sum(dim=1).long()  # (n_sents)

        # contextualize text embeddings
        text_embeddings = self._context_layer(text_embeddings, text_mask)

        # Create spans, i.e. span_embeddings, masks and span_indices
        span_mask = (spans[:, :, 0] >= 0).float()  # (n_sents, max_n_spans)
        # SpanFields return -1 when they are used as padding. As we do some comparisons based on
        # span widths when we attend over the span representations that we generate from these
        # indices, we need them to be >= 0. This is only relevant in edge cases where the number of
        # spans we consider after the pruning stage is >= the total number of spans, because in this
        # case, it is possible we might consider a masked span.
        spans = F.relu(spans.float()).long()  # (n_sents, max_n_spans, 2)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        span_embeddings = self._endpoint_span_extractor(text_embeddings, spans)

        trigger_mask = (trigger_spans[:, :, 0] >= 0).float()
        trigger_spans = F.relu(trigger_spans.float()).long()
        trigger_embeddings = self._endpoint_trigger_span_extractor(
            text_embeddings, trigger_spans)

        # Make attended span embeddings.
        if self._attentive_span_extractor is not None:
            # Shape: (batch_size, num_spans, embedding_size)
            attended_span_embeddings = self._attentive_span_extractor(
                text_embeddings, spans)
            attended_trigger_span_embeddings = self._attentive_span_extractor(
                text_embeddings, trigger_spans)
            # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
            span_embeddings = torch.cat(
                [span_embeddings, attended_span_embeddings], -1)
            trigger_embeddings = torch.cat(
                [trigger_embeddings, attended_trigger_span_embeddings], -1)

        # Make calls out to the modules to get results.
        output_coref = {'loss': 0}
        output_ner = {'loss': 0}
        output_relation = {'loss': 0}
        output_events = {'loss': 0}

        # Prune and compute span representations for coreference module
        if self._loss_weights["coref"] > 0 or self._coref.coref_prop > 0:
            output_coref, coref_indices = self._coref.compute_representations(
                spans, span_mask, span_embeddings, sentence_lengths,
                coref_labels, metadata)

        # Propagation of global information to enhance the span embeddings
        if self._coref.coref_prop > 0:
            output_coref = self._coref.coref_propagation(output_coref)
            span_embeddings = self._coref.update_spans(output_coref,
                                                       span_embeddings,
                                                       coref_indices)

        # Make predictions and compute losses for each module
        if self._loss_weights['ner'] > 0:
            output_ner = self._ner(spans, span_mask, span_embeddings,
                                   sentence_lengths, ner_labels, metadata)

        if self._loss_weights['coref'] > 0:
            output_coref = self._coref.predict_labels(output_coref, metadata)

        if self._loss_weights['relation'] > 0:
            output_relation = self._relation(spans, span_mask, span_embeddings,
                                             sentence_lengths, relation_labels,
                                             metadata)

        if self._loss_weights['events'] > 0:
            output_events = self._events(trigger_spans, trigger_mask,
                                         trigger_embeddings, spans, span_mask,
                                         span_embeddings, text_mask,
                                         text_embeddings, sentence_lengths,
                                         trigger_labels, argument_labels,
                                         ner_labels, metadata)

        # Use `get` since there are some cases where the output dict won't have a loss - for
        # instance, when doing prediction.
        loss = (
            self._loss_weights['coref'] * output_coref.get("loss", 0) +
            self._loss_weights['ner'] * output_ner.get("loss", 0) +
            self._loss_weights['relation'] * output_relation.get("loss", 0) +
            self._loss_weights['events'] * output_events.get("loss", 0))

        # Multiply the loss by the weight multiplier for this document.
        weight = metadata.weight if metadata.weight is not None else 1.0
        loss *= torch.tensor(weight)

        output_dict = dict(coref=output_coref,
                           relation=output_relation,
                           ner=output_ner,
                           events=output_events)
        output_dict['loss'] = loss

        output_dict["metadata"] = metadata

        return output_dict

    def update_span_embeddings(self, span_embeddings, span_mask,
                               top_span_embeddings, top_span_mask,
                               top_span_indices):
        # TODO(Ulme) Speed this up by tensorizing

        new_span_embeddings = span_embeddings.clone()
        for sample_nr in range(len(top_span_mask)):
            for top_span_nr, span_nr in enumerate(top_span_indices[sample_nr]):
                if (top_span_mask[sample_nr, top_span_nr] == 0
                        or span_mask[sample_nr, span_nr] == 0):
                    break
                new_span_embeddings[sample_nr, span_nr] = \
                    top_span_embeddings[sample_nr, top_span_nr]
        return new_span_embeddings
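
    def _update_span_embeddings_vectorized(self, span_embeddings, span_mask,
                                           top_span_embeddings, top_span_mask,
                                           top_span_indices):
        # Illustrative sketch of the tensorized version suggested by the TODO in
        # `update_span_embeddings`; it is not called anywhere in this model. It assumes
        # `top_span_mask` and `top_span_indices` both have shape (batch, num_top_spans) and,
        # unlike the loop above, it skips masked entries instead of stopping at the first one.
        valid = top_span_mask.bool() & span_mask.gather(1, top_span_indices).bool()
        index = top_span_indices.unsqueeze(-1).expand(-1, -1, span_embeddings.size(-1))
        # Keep the original embedding wherever the update is masked out.
        updates = torch.where(valid.unsqueeze(-1), top_span_embeddings,
                              span_embeddings.gather(1, index))
        # Out-of-place scatter writes the kept embeddings back into their original slots and
        # leaves `span_embeddings` itself untouched.
        return span_embeddings.scatter(1, index, updates)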

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]):
        """
        Decodes the model output into human-readable form: predicted coreference clusters,
        NER labels, relations, and events are attached to a copy of the input document.

        Parameters
        ----------
        output_dict : ``Dict[str, torch.Tensor]``, required.
            The result of calling :func:`forward` on an instance or batch of instances.

        Returns
        -------
        A copy of the ``metadata`` document with predicted clusters and per-sentence NER,
        relation, and event predictions attached, rather than the raw output dictionary.
        """

        doc = copy.deepcopy(output_dict["metadata"])

        if self._loss_weights["coref"] > 0:
            # TODO(dwadden) Will need to get rid of the [0] when batch training is enabled.
            decoded_coref = self._coref.make_output_human_readable(
                output_dict["coref"])["predicted_clusters"][0]
            sentences = doc.sentences
            sentence_starts = [sent.sentence_start for sent in sentences]
            predicted_clusters = [
                document.Cluster(entry, i, sentences, sentence_starts)
                for i, entry in enumerate(decoded_coref)
            ]
            doc.predicted_clusters = predicted_clusters
            # TODO(dwadden) update the sentences with cluster information.

        if self._loss_weights["ner"] > 0:
            for predictions, sentence in zip(output_dict["ner"]["predictions"],
                                             doc):
                sentence.predicted_ner = predictions

        if self._loss_weights["relation"] > 0:
            for predictions, sentence in zip(
                    output_dict["relation"]["predictions"], doc):
                sentence.predicted_relations = predictions

        if self._loss_weights["events"] > 0:
            for predictions, sentence in zip(
                    output_dict["events"]["predictions"], doc):
                sentence.predicted_events = predictions

        return doc

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Get all metrics from all modules. For the ones that shouldn't be displayed, prefix their
        keys with an underscore.
        """
        metrics_coref = self._coref.get_metrics(reset=reset)
        metrics_ner = self._ner.get_metrics(reset=reset)
        metrics_relation = self._relation.get_metrics(reset=reset)
        metrics_events = self._events.get_metrics(reset=reset)

        # Make sure that there aren't any conflicting names.
        metric_names = (list(metrics_coref.keys()) + list(metrics_ner.keys()) +
                        list(metrics_relation.keys()) +
                        list(metrics_events.keys()))
        assert len(set(metric_names)) == len(metric_names)
        all_metrics = dict(
            list(metrics_coref.items()) + list(metrics_ner.items()) +
            list(metrics_relation.items()) + list(metrics_events.items()))

        # If no list of desired metrics given, display them all.
        if self._display_metrics is None:
            return all_metrics
        # Otherwise only display the selected ones.
        res = {}
        for k, v in all_metrics.items():
            if k in self._display_metrics:
                res[k] = v
            else:
                new_k = "_" + k
                res[new_k] = v
        return res
Example 5
class SCIIE(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    embedding_dim : ``int``
        The dimension of the word embeddings passed to ``get_embeddings``.
    feature_size : ``int``
        The embedding size for all the embedded features, such as span widths; also the hidden
        size of the bidirectional LSTM context layer.
    max_span_width : ``int``
        The maximum width of candidate spans.
    spans_per_word : ``float``, required.
        A multiplier between zero and one which controls what percentage of candidate mention
        spans we retain with respect to the number of words in the document.
    lexical_dropout : ``float``
        The probability of dropping out dimensions of the embedded text.
    mlp_dropout : ``float``
        The dropout probability applied inside the feedforward entity scorer.
    embedder_type : optional (default=``None``)
        Selects which text embedder ``get_embeddings`` constructs.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 embedding_dim: int,
                 feature_size: int,
                 max_span_width: int,
                 spans_per_word: float,
                 lexical_dropout: float = 0.2,
                 mlp_dropout: float = 0.4,
                 embedder_type=None,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(SCIIE, self).__init__(vocab, regularizer)
        self.class_num = self.vocab.get_vocab_size('labels')
        word_embeddings = get_embeddings(embedder_type, self.vocab,
                                         embedding_dim, True)
        embedding_dim = word_embeddings.get_output_dim()
        self._text_field_embedder = word_embeddings

        context_layer = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(embedding_dim,
                          feature_size,
                          batch_first=True,
                          bidirectional=True))
        self._context_layer = context_layer

        endpoint_span_extractor_input_dim = context_layer.get_output_dim()
        attentive_span_extractor_input_dim = word_embeddings.get_output_dim()

        self._endpoint_span_extractor = EndpointSpanExtractor(
            endpoint_span_extractor_input_dim,
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=attentive_span_extractor_input_dim)

        # self._span_extractor = PoolingSpanExtractor(embedding_dim,
        #                                             num_width_embeddings=max_span_width,
        #                                             span_width_embedding_dim=feature_size,
        #                                             bucket_widths=False)

        entity_feedforward = FeedForward(
            self._endpoint_span_extractor.get_output_dim() +
            self._attentive_span_extractor.get_output_dim(), 2, 150, F.relu,
            mlp_dropout)
        # entity_feedforward = FeedForward(self._span_extractor.get_output_dim(), 2, 150,
        #                                  F.relu, mlp_dropout)

        feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(), 1)))
        self._mention_pruner = Pruner(feedforward_scorer)

        self._entity_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(),
                                self.class_num - 1)))

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x

        self._metric_all = FBetaMeasure()
        self._metric_avg = NERF1Metric()

    @overrides
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)
        # span_embeddings = self._span_extractor(text_embeddings, spans, span_indices_mask=span_mask)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, span_embeddings.shape[1])

        # Shape:    (batch_size, num_spans_to_keep, embedding_size + 2 * encoding_dim + feature_size)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Shape: (batch_size, num_spans_to_keep, class_num + 1)
        ne_scores = self._compute_named_entity_scores(top_span_embeddings)

        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_named_entities = ne_scores.max(2)

        output_dict = {
            "top_spans": top_spans,
            "predicted_named_entities": predicted_named_entities
        }
        if labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices).squeeze(-1)
            negative_log_likelihood = F.cross_entropy(
                ne_scores.reshape(-1, self.class_num),
                pruned_gold_labels.reshape(-1))

            pruner_loss = F.binary_cross_entropy_with_logits(
                top_span_mention_scores.reshape(-1),
                (pruned_gold_labels.reshape(-1) != 0).float())
            loss = negative_log_likelihood + pruner_loss
            output_dict["loss"] = loss
            output_dict["pruner_loss"] = pruner_loss
            batch_size, _ = labels.shape
            all_scores = ne_scores.new_zeros(
                [batch_size * num_spans, self.class_num])
            all_scores[:, 0] = 1
            all_scores[flat_top_span_indices] = ne_scores.reshape(
                -1, self.class_num)
            all_scores = all_scores.reshape(
                [batch_size, num_spans, self.class_num])
            self._metric_all(all_scores, labels)
            self._metric_avg(all_scores, labels)
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False, prefix=""):
        metric = self._metric_all.get_metric(reset)
        metric2 = self._metric_avg.get_metric(reset)
        metric.update(metric2)
        return metric

    def _compute_named_entity_scores(
            self, span_embeddings: torch.FloatTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        span_embeddings: ``torch.FloatTensor``, required.
            Embedding representations of spans. Has shape
            (batch_size, num_spans_to_keep, encoding_dim)
        """
        # Shape: (batch_size, num_spans_to_keep, class_num)
        scores = self._entity_scorer(span_embeddings)
        # Shape: (batch_size, num_spans_to_keep, 1)
        shape = [scores.size(0), scores.size(1), 1]
        dummy_scores = scores.new_full(shape, 0)
        ne_scores = torch.cat([dummy_scores, scores], -1)
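        # The fixed zero logit at index 0 stands in for the null label (index 0 of the
        # ``labels`` vocabulary, which the pruner loss in ``forward`` treats as the negative
        # class), so no separate weight is learned for it.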
        return ne_scores
Example 6
class SrlE2e(Model):
    """

    # Parameters

    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    bert_model : `Union[str, BertModel]`, required.
        A string describing the BERT model to load or an already constructed BertModel.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    label_smoothing : `float`, optional (default = `0.0`)
        The amount of label smoothing to apply when computing the cross entropy loss.
    ignore_span_metric : `bool`, optional (default = `False`)
        If `True`, skip computing the span-based SRL evaluation metric.
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
        which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not used.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        mention_feedforward: FeedForward,
        context_layer: Seq2SeqEncoder = None,
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        max_span_width: int = 30,
        feature_size: int = 10,
        spans_per_word: float = 100,
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("span_labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric

        self._mention_feedforward = TimeDistributed(mention_feedforward)
        self._mention_scorer = TimeDistributed(
            torch.nn.Linear(mention_feedforward.get_output_dim(), 1))

        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=self.bert_model.config.hidden_size)
        self.span_representation_dim = (
            self._attentive_span_extractor.get_output_dim())
        self._context_layer = context_layer
        if context_layer is not None:
            self._endpoint_span_extractor = EndpointSpanExtractor(
                context_layer.get_output_dim(),
                combination="x,y",
                num_width_embeddings=max_span_width,
                span_width_embedding_dim=feature_size,
                bucket_widths=False,
            )
            self.span_representation_dim = (
                self._endpoint_span_extractor.get_output_dim())

        self.hidden_layer = torch.nn.Sequential(
            torch.nn.Linear(self.span_representation_dim +
                            self.bert_model.config.hidden_size,
                            self.span_representation_dim,
                            bias=False), torch.nn.ReLU())
        self.output_layer = torch.nn.Linear(self.span_representation_dim,
                                            self.num_classes - 1,
                                            bias=False)

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self._bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        initializer(self)

    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        sentence_end: torch.LongTensor,
        spans: torch.LongTensor,
        span_labels: torch.LongTensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
        start = time.time()
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            # token_type_ids=verb_indicator,
            attention_mask=mask,
        )

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be >= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        batch_size, sequence_length, _ = embedded_text_input.size()
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            bert_embeddings, spans)

        if self._context_layer is not None:
            contextualized_embeddings = self._context_layer(
                embedded_text_input, mask)
            # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
            endpoint_span_embeddings = self._endpoint_span_extractor(
                contextualized_embeddings, spans)

            # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
            # span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)
            span_embeddings = endpoint_span_embeddings
        else:
            span_embeddings = attended_span_embeddings

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * sequence_length))
        num_spans = spans.shape[1]
        num_spans_to_keep = min(num_spans_to_keep, num_spans)
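        # With the default spans_per_word=100, this floor is far larger than num_spans, so the
        # min() keeps every candidate span; pruning only takes effect when spans_per_word is
        # small enough that floor(spans_per_word * sequence_length) < num_spans.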

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)).squeeze(-1)
        # Shape: (batch_size, num_spans) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep)
        verb_index = verb_indicator.argmax(1).unsqueeze(1).unsqueeze(2).repeat(
            1, 1, embedded_text_input.shape[-1])
        verb_embeddings = torch.gather(embedded_text_input, 1, verb_index)
        assert len(verb_embeddings.shape) == 3 and verb_embeddings.shape[1] == 1
        verb_embeddings = verb_embeddings.squeeze(1)
        # print(verb_indicator.sum(1, keepdim=True) > 0)
        verb_embeddings = torch.where(
            (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                1, verb_embeddings.shape[-1]), verb_embeddings,
            torch.zeros_like(verb_embeddings))
        # print(verb_embeddings)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, spans.shape[1])
        span_embeddings = util.batched_index_select(span_embeddings,
                                                    top_span_indices,
                                                    flat_top_span_indices)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)
        top_span_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices).squeeze(-1)
        concatenated_span_embeddings = torch.cat(
            (span_embeddings, verb_embeddings.unsqueeze(1).repeat(
                1, span_embeddings.shape[1], 1)),
            dim=2)
        # print(concatenated_span_embeddings[:,:,:])
        hidden = self.hidden_layer(concatenated_span_embeddings)
        # print(hidden[1,:,:])
        # print(top_span_indices)
        # print([[span_mention_scores[i,top_span_indices[i,j]].item() for j in range(top_span_indices.shape[1])] for i in range(top_span_labels.shape[0])])
        # print(top_span_mention_scores, self.vocab.get_token_index("O", namespace="span_labels"))
        predictions = self.output_layer(hidden)
        # predictions += top_span_mention_scores.unsqueeze(-1).repeat(1, 1, self.num_classes-1)
        predictions = torch.cat(
            (torch.zeros_like(predictions[:, :, :1]), predictions), dim=-1)
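        # The prepended zero column gives label index 0 (presumably the "O" span label) a fixed
        # logit of 0 in the cross-entropy below, so only the non-null classes get learned scores.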
        # print(top_span_mention_scores.unsqueeze(-1).repeat(1, 1, self.num_classes-1))

        output_dict = {}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            loss = (self._ce_loss(predictions.view(-1, predictions.shape[-1]),
                                  top_span_labels.view(-1)) *
                    top_span_mask.float().view(-1)
                    ).sum() / top_span_mask.float().sum()
            # print(top_span_labels)
            # print(predictions.argmax(-1))
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.get_tags(
                    top_spans, predictions, mask.shape[1], top_span_mask,
                    output_dict)
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                # print('G', batch_bio_gold_tags)
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
        return output_dict

    def get_tags(self, spans, logits, sequence_length, span_mask, output_dict):
        predicted_tag_ids = logits.argmax(2)
        predicted_tags = []
        for i in range(spans.shape[0]):
            sequence = ["O" for _ in range(sequence_length)]
            for j in range(spans.shape[1]):
                if span_mask[i, j].item() == 0:
                    continue
                tag = predicted_tag_ids[i, j].item()
                if tag != self.vocab.get_token_index("O",
                                                     namespace="span_labels"):
                    start = spans[i, j, 0].item()
                    end = spans[i, j, 1].item()
                    if all([el == "O" for el in sequence[start:end + 1]]):
                        tag = self.vocab.get_token_from_index(
                            tag, namespace="span_labels")
                        sequence[start] = "B-" + tag
                        for index in range(start + 1, end + 1):
                            sequence[index] = "I-" + tag
            predicted_tags.append(
                [sequence[ind] for ind in output_dict["wordpiece_offsets"][i]])
        # print(predicted_tags)
        return predicted_tags

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        `"tags"` key to the dictionary with the result.

        NOTE: First, we decode a BIO sequence on top of the wordpieces. This is important; viterbi
        decoding produces low quality output if you decode on top of word representations directly,
        because the model gets confused by the 'missing' positions (which is sensible as it is trained
        to perform tagging on wordpieces, not words).

        Secondly, it's important that the indices we use to recover words from the wordpieces are the
        start_offsets (i.e. offsets which correspond to using the first wordpiece of words which are
        tokenized into multiple wordpieces) as otherwise, we might get an ill-formed BIO sequence
        when we select out the word tags from the wordpiece tags. This happens in the case that a word
        is split into multiple word pieces, and then we take the last tag of the word, which might
        correspond to, e.g., I-V, which would not be allowed as it is not preceded by a B tag.
        """
        all_predictions = output_dict["class_probabilities"]
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        wordpiece_tags = []
        word_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        # **************** Different ********************
        # We add in the offsets here so we can compute the un-wordpieced tags.
        for predictions, length, offsets in zip(
                predictions_list, sequence_lengths,
                output_dict["wordpiece_offsets"]):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length],
                transition_matrix,
                allowed_start_transitions=start_transitions)
            tags = [
                self.vocab.get_token_from_index(x, namespace="labels")
                for x in max_likelihood_sequence
            ]

            wordpiece_tags.append(tags)
            word_tags.append([tags[i] for i in offsets])
        output_dict["wordpiece_tags"] = wordpiece_tags
        output_dict["tags"] = word_tags
        return output_dict

    def get_metrics(self, reset: bool = False):
        if self.ignore_span_metric:
            # Return an empty dictionary if ignoring the
            # span metric
            return {}

        else:
            metric_dict = self.span_metric.get_metric(reset=reset)

            # This can be a lot of metrics, as there are 3 per class.
            # we only really care about the overall metrics, so we filter for them here.
            return {x: y for x, y in metric_dict.items() if "overall" in x}

    def get_viterbi_pairwise_potentials(self):
        """
        Generate a matrix of pairwise transition potentials for the BIO labels.
        The only constraint implemented here is that I-XXX labels must be preceded
        by either an identical I-XXX tag or a B-XXX tag. In order to achieve this
        constraint, pairs of labels which do not satisfy this constraint have a
        pairwise potential of -inf.

        # Returns

        transition_matrix : `torch.Tensor`
            A `(num_labels, num_labels)` matrix of pairwise potentials.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])

        for i, previous_label in all_labels.items():
            for j, label in all_labels.items():
                # I labels can only be preceded by themselves or
                # their corresponding B tag.
                if (i != j and label[0] == "I"
                        and previous_label != "B" + label[1:]):
                    transition_matrix[i, j] = float("-inf")
        return transition_matrix

    def get_start_transitions(self):
        """
        In the BIO sequence, we cannot start the sequence with an I-XXX tag.
        This transition sequence is passed to viterbi_decode to specify this constraint.

        # Returns

        start_transitions : `torch.Tensor`
            The pairwise potentials between a START token and
            the first token of the sequence.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)

        start_transitions = torch.zeros(num_labels)

        for i, label in all_labels.items():
            if label[0] == "I":
                start_transitions[i] = float("-inf")

        return start_transitions

    default_predictor = "semantic_role_labeling"