Example #1
def test_metrics():
    metrics = Metrics(
        accuracy={"type": "categorical_accuracy"},
        f1={
            "type": "span_f1",
            "vocabulary": Vocabulary.empty(),
        },
    )

    # Check that training and validation metrics are different instances
    assert (metrics.get_dict()["accuracy"]
            is not metrics.get_dict(is_train=False)["accuracy"])
    # Check if we share the same vocab
    assert (metrics.get_dict()["f1"]._label_vocabulary is
            metrics.get_dict(is_train=False)["f1"]._label_vocabulary)
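
The instance split asserted above is not cosmetic: AllenNLP metrics are stateful and accumulate counts across calls, so a single shared instance would mix training and validation statistics. A minimal sketch of the failure mode this avoids, assuming AllenNLP's `CategoricalAccuracy` (the tensors are illustrative):

import torch
from allennlp.training.metrics import CategoricalAccuracy

shared = CategoricalAccuracy()
logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])

# "training" batch: both predictions correct
shared(logits, torch.tensor([1, 0]))
# "validation" batch fed to the same instance: both wrong
shared(logits, torch.tensor([0, 1]))

# the accumulated accuracy now mixes both phases (0.5 here), which is
# why `Metrics.get_dict` hands out separate instances per phase
print(shared.get_metric(reset=True))  # 0.5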
Example #2
class ClassificationHead(TaskHead):
    """Base abstract class for classification problems

    Parameters
    ----------
    labels
        A list of labels for your classification task
    multilabel
        Is this a multi-label classification task? Default: False
    label_weights
        A list of weights for each label. The weights must be in the same order as the `labels`.
        You can also provide a dictionary that maps the label to its weight. Default: None.
    """

    task_name = TaskName.text_classification
    _LOGGER = logging.getLogger(__name__)

    def __init__(
        self,
        backbone: ModelBackbone,
        labels: List[str],
        multilabel: bool = False,
        label_weights: Optional[Union[List[float], Dict[str, float]]] = None,
    ):
        super().__init__(backbone)
        vocabulary.set_labels(self.backbone.vocab, labels)

        # label related configurations
        self._multilabel = multilabel

        # metrics and loss
        if isinstance(label_weights, list):
            label_weights = torch.tensor(label_weights, dtype=torch.float32)
        elif isinstance(label_weights, dict):
            label_weights = torch.tensor(
                [label_weights[label] for label in labels],
                dtype=torch.float32)
        if self._multilabel:
            self._loss = torch.nn.BCEWithLogitsLoss(weight=label_weights)
            self._metrics = Metrics(
                micro={
                    "type": "fbeta_multi_label",
                    "average": "micro"
                },
                macro={
                    "type": "fbeta_multi_label",
                    "average": "macro"
                },
                per_label={
                    "type": "fbeta_multi_label",
                    "labels": [i for i in range(len(labels))],
                },
            )
        else:
            self._loss = torch.nn.CrossEntropyLoss(weight=label_weights)
            self._metrics = Metrics(
                accuracy={"type": "categorical_accuracy"},
                micro={
                    "type": "fbeta",
                    "average": "micro"
                },
                macro={
                    "type": "fbeta",
                    "average": "macro"
                },
                per_label={
                    "type": "fbeta",
                    "labels": [i for i in range(len(labels))]
                },
            )

    def _add_label(
        self,
        instance: Instance,
        label: Union[List[str], List[int], str, int],
        to_field: str = "label",
    ) -> Instance:
        """Adds the label field for classification into the instance data

        Helper function for the child's `self.featurize` method.

        Parameters
        ----------
        instance
            Add a label field to this instance
        label
            The label data
        to_field
            Name of the field in the instance

        Returns
        -------
        instance
            If `label` is not None, return `instance` with a label field added.
            Otherwise return just the given `instance`.

        Raises
        ------
        FeaturizeError
            If the label is an empty string or does not match the type:
            - (str, int) for single label
            - (list, np.array) for multi label
        """
        # "if not label:" fails for ndarrays, this is why we explicitly check for None
        if label is None:
            return instance

        field = None
        # check if multilabel and if adequate type
        if self._multilabel and isinstance(label, (list, numpy.ndarray)):
            label = label.tolist() if isinstance(label,
                                                 numpy.ndarray) else label
            field = MultiLabelField(
                label, label_namespace=vocabulary.LABELS_NAMESPACE)
        # check if not multilabel and adequate type + check for empty strings
        if not self._multilabel and isinstance(label, (str, int)) and label:
            field = LabelField(label,
                               label_namespace=vocabulary.LABELS_NAMESPACE)
        if not field:
            # We have label info but we cannot build the label field --> discard the instance
            raise FeaturizeError(
                f"Cannot create label field for `label={label}`!")

        instance.add_field(to_field, field)

        return instance

    def _make_forward_output(
            self, logits: torch.Tensor,
            label: Optional[torch.IntTensor]) -> Dict[str, Any]:
        """Returns a dict with the logits and optionally the loss

        Helper function for the child's `self.forward` method.
        """
        if label is not None:
            return {
                "loss": self._compute_metrics_and_return_loss(logits, label),
                "logits": logits,
            }

        return {"logits": logits}

    def _compute_metrics_and_return_loss(self, logits: torch.Tensor,
                                         label: torch.IntTensor) -> float:
        """Helper function for the `self._make_forward_output` method."""
        for metric in self._metrics.get_dict(is_train=self.training).values():
            metric(logits, label)

        if self._multilabel:
            # cast long to float for BCEWithLogitsLoss
            # see https://discuss.pytorch.org/t/nn-bcewithlogitsloss-cant-accept-one-hot-target/59980
            return self._loss(
                logits.view(-1, self.num_labels),
                label.view(-1, self.num_labels).type_as(logits),
            )

        return self._loss(logits, label.long())

    def _compute_labels_and_probabilities(
        self,
        single_forward_output: Dict[str, numpy.ndarray],
    ) -> Tuple[List[str], List[float]]:
        """Computes the probabilities based on the logits and looks up the labels

        This is a helper function for the `self._make_task_prediction` method of the child classes.

        Parameters
        ----------
        single_forward_output
            A single (not batched) output from the head's forward method

        Returns
        -------
        (labels, probabilities)
        """
        logits = torch.from_numpy(single_forward_output["logits"])

        if self._multilabel:
            probabilities = logits.sigmoid()
        else:
            probabilities = torch.nn.functional.softmax(logits, dim=0)

        labels, all_probabilities = (
            self._add_and_sort_labels_and_probabilities(probabilities)
            if self.num_labels > 0 else ([], []))

        return labels, all_probabilities

    def _add_and_sort_labels_and_probabilities(
            self,
            probabilities: torch.Tensor) -> Tuple[List[str], List[float]]:
        """Returns the labels and probabilities sorted by the probability (descending)

        Helper function for the `self._compute_labels_and_probabilities` method. The returned list of
        probabilities can be longer than the input, since we include all labels defined in the head.

        Parameters
        ----------
        probabilities
            Probabilities of the model's prediction for one instance

        Returns
        -------
        labels, probabilities
        """
        all_classes_probs = torch.zeros(
            self.num_labels,  # this can be >= probabilities.size()[0]
            device=probabilities.get_device()
            if probabilities.get_device() > -1 else None,
        )
        all_classes_probs[:probabilities.size()[0]] = probabilities
        sorted_indexes_by_prob = torch.argsort(all_classes_probs,
                                               descending=True).tolist()

        labels = [
            vocabulary.label_for_index(self.backbone.vocab, idx)
            for idx in sorted_indexes_by_prob
        ]
        probabilities = [
            float(all_classes_probs[idx]) for idx in sorted_indexes_by_prob
        ]

        return labels, probabilities

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Get the metrics of our classifier, see :func:`~allennlp_2.models.Model.get_metrics`.

        Parameters
        ----------
        reset
            Reset the metrics after obtaining them?

        Returns
        -------
        A dictionary with all metric names and values.
        """
        metrics, final_metrics = self._metrics.get_dict(
            is_train=self.training), {}
        for name, metric in metrics.items():
            if name == "accuracy":
                final_metrics.update({"accuracy": metric.get_metric(reset)})
            elif name in ["macro", "micro"]:
                final_metrics.update({
                    f"{name}/{key}": value
                    for key, value in metric.get_metric(reset).items()
                })
            elif name == "per_label":
                for key, values in metric.get_metric(reset).items():
                    for i, value in enumerate(values):
                        label = vocabulary.label_for_index(
                            self.backbone.vocab, i)
                        # sanitize label using same patterns as tensorboardX to avoid summary writer warnings
                        label = helpers.sanitize_metric_name(label)
                        final_metrics.update({f"_{key}/{label}": value})

        return final_metrics

    def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        raise NotImplementedError

    def featurize(self, *args, **kwargs) -> Instance:
        raise NotImplementedError

    def _make_task_prediction(self, single_forward_output: Dict[str,
                                                                numpy.ndarray],
                              instance: Instance) -> TaskPrediction:
        raise NotImplementedError

    def _compute_attributions(
        self,
        single_forward_output: Dict[str, numpy.ndarray],
        instance: Instance,
        **kwargs,
    ) -> List[Union[Attribution, List[Attribution]]]:
        raise NotImplementedError
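
The weight handling at the top of `ClassificationHead.__init__` is worth isolating: a dict of per-label weights is re-ordered to match the index order of `labels` before being handed to the loss. A standalone sketch with illustrative labels and weights:

import torch

labels = ["positive", "negative", "neutral"]
label_weights = {"negative": 2.0, "positive": 1.0, "neutral": 0.5}

# re-order the dict values so that weights[i] belongs to labels[i]
weights = torch.tensor([label_weights[label] for label in labels],
                       dtype=torch.float32)  # tensor([1.0, 2.0, 0.5])

loss = torch.nn.CrossEntropyLoss(weight=weights)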
Example #3
class LanguageModelling(TaskHead):
    """
    Task head for next-token language modelling, i.e., a model to predict the next token
    in a sequence of tokens.
    """

    task_name = TaskName.language_modelling
    _LOGGER = logging.getLogger(__name__)

    def __init__(
        self,
        backbone: ModelBackbone,
        dropout: Optional[float] = None,
        bidirectional: bool = False,
    ) -> None:
        super(LanguageModelling, self).__init__(backbone)

        self._empty_prediction = LanguageModellingPrediction(
            lm_embeddings=numpy.array([]), mask=numpy.array([]))

        self.bidirectional = bidirectional

        if not backbone.featurizer.has_word_features:
            raise ConfigurationError(
                "`LanguageModelling` defines a word-level next token language model. "
                "Please check your `features` configuration to enable at least `words` features."
            )

        if backbone.encoder.is_bidirectional() is not bidirectional:
            raise ConfigurationError(
                "Bidirectionality of contextualizer must match bidirectionality of "
                "language model. "
                f"Contextualizer bidirectional: {backbone.encoder.is_bidirectional()}, "
                f"language model bidirectional: {bidirectional}")

        if self.bidirectional:
            self._forward_dim = backbone.encoder.get_output_dim() // 2
        else:
            self._forward_dim = backbone.encoder.get_output_dim()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = lambda x: x

        self._metrics = Metrics(perplexity={"type": "perplexity"})

        self._loss = SoftmaxLoss(
            num_words=vocabulary.words_vocab_size(self.backbone.vocab),
            embedding_dim=self._forward_dim,
        )

    def on_vocab_update(self):
        num_words = vocabulary.words_vocab_size(self.backbone.vocab)
        if len(self._loss.softmax_b) != num_words:
            self._loss = SoftmaxLoss(
                num_words=num_words,
                embedding_dim=self._forward_dim,
            )

    def featurize(self, text: str) -> Optional[Instance]:
        instance = self.backbone.featurizer(text,
                                            to_field="text",
                                            aggregate=True)

        return instance

    def forward(self,
                text: TextFieldTensors) -> Dict[str, Any]:  # type: ignore

        mask = get_text_field_mask(text)
        contextual_embeddings = self.backbone.forward(text, mask)

        token_ids = get_token_ids_from_text_field_tensors(text)
        assert isinstance(contextual_embeddings, torch.Tensor)

        # Use token_ids to compute targets:
        # targets are the next token ids with respect to each position in the sequence,
        # e.g. token_ids [[1, 3, 5, 7], ...] -> forward_targets [[3, 5, 7, 0], ...]
        forward_targets = torch.zeros_like(token_ids)
        forward_targets[:, 0:-1] = token_ids[:, 1:]

        if self.bidirectional:
            backward_targets = torch.zeros_like(token_ids)
            backward_targets[:, 1:] = token_ids[:, 0:-1]
        else:
            backward_targets = None

        # add dropout
        contextual_embeddings_with_dropout = self._dropout(
            contextual_embeddings)

        # compute softmax loss
        try:
            forward_loss, backward_loss = self._compute_loss(
                contextual_embeddings_with_dropout, forward_targets,
                backward_targets)
        except IndexError:
            raise IndexError(
                "Word token out of vocabulary boundaries, please check that your vocab is correctly set"
                " or created before starting training.")

        num_targets = torch.sum((forward_targets > 0).long())

        if num_targets > 0:
            if self.bidirectional:
                average_loss = (0.5 * (forward_loss + backward_loss) /
                                num_targets.float())
            else:
                average_loss = forward_loss / num_targets.float()
        else:
            average_loss = torch.tensor(0.0)

        for metric in self._metrics.get_dict(is_train=self.training).values():
            # Perplexity needs the value to be on the cpu
            metric(average_loss.to("cpu"))

        return dict(
            loss=average_loss,
            lm_embeddings=contextual_embeddings,
            mask=mask,
        )

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            metric_name: metric.get_metric(reset)
            for metric_name, metric in self._metrics.get_dict(
                is_train=self.training).items()
        }

    def _compute_loss(
        self,
        lm_embeddings: torch.Tensor,
        forward_targets: torch.Tensor,
        backward_targets: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # If bidirectional, lm_embeddings is shape (batch_size, timesteps, dim * 2)
        # If unidirectional, lm_embeddings is shape (batch_size, timesteps, dim)
        # forward_targets, backward_targets (None in the unidirectional case) are
        # shape (batch_size, timesteps) masked with 0
        if self.bidirectional:
            forward_embeddings, backward_embeddings = lm_embeddings.chunk(
                2, -1)
            backward_loss = self._loss_helper(backward_embeddings,
                                              backward_targets)
        else:
            forward_embeddings = lm_embeddings
            backward_loss = None

        forward_loss = self._loss_helper(forward_embeddings, forward_targets)
        return forward_loss, backward_loss

    def _loss_helper(
        self,
        direction_embeddings: torch.Tensor,
        direction_targets: torch.Tensor,
    ) -> torch.Tensor:
        mask = direction_targets > 0
        # we need to subtract 1 to undo the padding id since the softmax
        # does not include a padding dimension

        # shape (batch_size * timesteps, )
        non_masked_targets = direction_targets.masked_select(mask) - 1

        # shape (batch_size * timesteps, embedding_dim)
        non_masked_embeddings = direction_embeddings.masked_select(
            mask.unsqueeze(-1)).view(-1, self._forward_dim)

        return self._loss(non_masked_embeddings, non_masked_targets)

    def _make_task_prediction(
        self,
        single_forward_output: Dict[str, numpy.ndarray],
        instance: Instance,
    ) -> LanguageModellingPrediction:
        task_prediction = LanguageModellingPrediction(
            lm_embeddings=single_forward_output["lm_embeddings"],
            mask=single_forward_output["mask"],
        )
        if "loss" in single_forward_output:
            task_prediction.loss = float(single_forward_output["loss"])

        return task_prediction
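
The target construction in `LanguageModelling.forward` is easiest to see on a concrete tensor: targets are the token ids shifted by one position, with 0 (the padding id) filling the vacated slot, and `_loss_helper` later masks those positions out via `targets > 0`. A minimal sketch with illustrative values:

import torch

token_ids = torch.tensor([[1, 3, 5, 7]])

forward_targets = torch.zeros_like(token_ids)
forward_targets[:, 0:-1] = token_ids[:, 1:]   # tensor([[3, 5, 7, 0]])

backward_targets = torch.zeros_like(token_ids)
backward_targets[:, 1:] = token_ids[:, 0:-1]  # tensor([[0, 1, 3, 5]])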
Example #4
class TokenClassification(TaskHead):
    """Task head for token classification (NER)

    Parameters
    ----------
    backbone
        The model backbone
    labels
        List of span labels. Span labels are converted to tag labels internally, using the
        configured `label_encoding` for that.
    label_encoding
        The format of the tags. Supported encodings are: ['BIO', 'BIOUL']
    top_k
        Number of best Viterbi paths to keep for each prediction. Default: 1
    dropout
        Dropout rate applied to the encoder output. Default: 0.0
    feedforward
        An optional feedforward configuration applied to the encoder output before the tag projection
    """

    _LOGGER = logging.getLogger(__name__)

    task_name = TaskName.token_classification

    def __init__(
        self,
        backbone: ModelBackbone,
        labels: List[str],
        label_encoding: Optional[str] = "BIOUL",
        top_k: int = 1,
        dropout: Optional[float] = 0.0,
        feedforward: Optional[FeedForwardConfiguration] = None,
    ) -> None:
        super().__init__(backbone)

        self._empty_prediction = TokenClassificationPrediction(tags=[[]],
                                                               entities=[[]],
                                                               scores=[])

        if label_encoding not in ["BIOUL", "BIO"]:
            raise WrongValueError(
                f"Label encoding {label_encoding} not supported. Allowed values are {['BIOUL', 'BIO']}"
            )

        self._span_labels = labels
        self._label_encoding = label_encoding

        vocabulary.set_labels(
            self.backbone.vocab,
            # Convert span labels to tag labels if necessary
            # We just check if "O" is in the label list, a necessary tag for IOB/BIOUL schemes,
            # an unlikely label for spans
            span_labels_to_tag_labels(labels, self._label_encoding),
        )

        self.top_k = top_k
        self.dropout = torch.nn.Dropout(dropout)
        self._feedforward: FeedForward = (
            None if not feedforward else feedforward.input_dim(
                backbone.encoder.get_output_dim()).compile())
        # output layers
        self._classifier_input_dim = (self._feedforward.get_output_dim()
                                      if self._feedforward else
                                      backbone.encoder.get_output_dim())
        # we want this linear applied to each token in the sequence
        self._label_projection_layer = TimeDistributed(
            torch.nn.Linear(self._classifier_input_dim, self.num_labels))
        constraints = allowed_transitions(
            self._label_encoding,
            vocabulary.get_index_to_labels_dictionary(self.backbone.vocab),
        )

        self._crf = ConditionalRandomField(self.num_labels,
                                           constraints,
                                           include_start_end_transitions=True)

        # There is no top_k option for the f1 metric; it always takes only the first choice into account.
        # If you want to use top_k in the accuracy, you have to change the way we convert the CRF output to logits!
        self._metrics = Metrics(
            accuracy={"type": "categorical_accuracy"},
            f1={
                "type": "span_f1",
                "vocabulary": self.backbone.vocab,
                "tag_namespace": vocabulary.LABELS_NAMESPACE,
                "label_encoding": self._label_encoding,
            },
        )

    @property
    def span_labels(self) -> List[str]:
        return self._span_labels

    def _loss(self, logits: torch.Tensor, labels: torch.Tensor,
              mask: torch.Tensor):
        """loss is calculated as -log_likelihood from crf"""
        return -1 * self._crf(logits, labels, mask)

    def featurize(
        self,
        text: Union[str, List[str]],
        entities: Optional[List[dict]] = None,
        tags: Optional[Union[List[str], List[int]]] = None,
    ) -> Instance:
        """
        Parameters
        ----------
        text
            Can be either a simple str or a list of str,
            in which case it will be treated as a list of pretokenized tokens
        entities
            A list of span labels

            Span labels are dictionaries that contain:

            'start': int, char index of the start of the span
            'end': int, char index of the end of the span (exclusive)
            'label': str, label of the span

            They are used with the `spacy.gold.biluo_tags_from_offsets` method.
        tags
            A list of tags in the BIOUL or BIO format.
        """
        if isinstance(text, str):
            doc = self.backbone.tokenizer.nlp(text)
            tokens = [spacy_to_allennlp_token(token) for token in doc]
            tags = (tags_from_offsets(doc, entities, self._label_encoding)
                    if entities is not None else [])
            # discard misaligned examples for now
            if "-" in tags:
                raise FeaturizeError(
                    f"Could not align spans with tokens for following example: '{text}' {entities}"
                )
        # text is already pre-tokenized
        else:
            tokens = [Token(t) for t in text]

        instance = self.backbone.featurizer(tokens,
                                            to_field="text",
                                            tokenize=False,
                                            aggregate=True)

        if self.training:
            try:
                instance.add_field(
                    "tags",
                    SequenceLabelField(
                        tags,
                        sequence_field=cast(TextField, instance["text"]),
                        label_namespace=vocabulary.LABELS_NAMESPACE,
                    ),
                )
            except Exception as exception:
                raise FeaturizeError(
                    f"Could not create SequenceLabelField for {(tokens, tags)}"
                ) from exception

        instance.add_field("raw_text", MetadataField(text))

        return instance

    def forward(  # type: ignore
        self,
        text: TextFieldTensors,
        raw_text: List[Union[str, List[str]]],
        tags: torch.IntTensor = None,
    ) -> Dict:

        mask = get_text_field_mask(text)
        embedded_text = self.dropout(self.backbone.forward(text, mask))

        if self._feedforward is not None:
            embedded_text = self._feedforward(embedded_text)

        logits = self._label_projection_layer(embedded_text)
        # `self._crf.viterbi_tags` can return invalid tag sequences when logits are nan
        viterbi_logits = torch.where(torch.isnan(logits),
                                     torch.zeros_like(logits), logits)
        # dims are: batch, top_k, (tag_sequence, viterbi_score)
        viterbi_paths: List[List[Tuple[List[int],
                                       float]]] = self._crf.viterbi_tags(
                                           viterbi_logits,
                                           mask,
                                           top_k=self.top_k)
        # We just keep the best path for every instance
        predicted_tags: List[List[int]] = [
            paths[0][0] for paths in viterbi_paths
        ]
        class_probabilities = torch.zeros_like(logits)

        for i, instance_tags in enumerate(predicted_tags):
            for j, tag_id in enumerate(instance_tags):
                class_probabilities[i, j, tag_id] = 1

        output = dict(
            viterbi_paths=viterbi_paths,
            raw_text=raw_text,
        )

        if tags is not None:
            output["loss"] = self._loss(logits, tags, mask)
            for metric in self._metrics.get_dict(
                    is_train=self.training).values():
                metric(class_probabilities, tags, mask)

        return output

    def _make_task_prediction(
        self,
        single_forward_output: Dict,
        instance: Instance,
    ) -> TokenClassificationPrediction:
        # The dims are: top_k, tags
        tags: List[List[str]] = self._make_tags(
            single_forward_output["viterbi_paths"])
        # construct a spacy Doc
        pre_tokenized = not isinstance(single_forward_output["raw_text"], str)
        if pre_tokenized:
            # compose doc from tokens
            doc = Doc(Vocab(), words=single_forward_output["raw_text"])
        else:
            doc = self.backbone.tokenizer.nlp(
                single_forward_output["raw_text"])

        return TokenClassificationPrediction(
            tags=tags,
            scores=[
                score for tags, score in single_forward_output["viterbi_paths"]
            ],
            entities=self._make_entities(doc, tags, pre_tokenized),
        )

    def _make_tags(
            self, viterbi_paths: List[Tuple[List[int],
                                            float]]) -> List[List[str]]:
        """Makes the 'tags' key of the task prediction"""
        return [[
            vocabulary.label_for_index(self.backbone.vocab, idx)
            for idx in tags
        ] for tags, score in viterbi_paths]

    def _make_entities(
        self,
        doc: Doc,
        k_tags: List[List[str]],
        pre_tokenized: bool,
    ) -> List[List[Entity]]:
        """Makes the 'entities' key of the task prediction. Computes offsets with respect to char and token id"""
        return [[
            Entity(**entity)
            for entity in offsets_from_tags(doc,
                                            tags,
                                            self._label_encoding,
                                            only_token_spans=pre_tokenized)
        ] for tags in k_tags]

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        for name, metric in self._metrics.get_dict(
                is_train=self.training).items():
            metric_value = metric.get_metric(reset=reset)
            try:
                metrics.update(metric_value)
            # AllenNLP's CategoricalAccuracy returns a plain float, not a dict as the Metric API expects
            except TypeError:
                metrics[name] = metric_value

        return metrics
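
To make the `entities` format of `TokenClassification.featurize` concrete, here is a hypothetical call; `head` stands for an already constructed `TokenClassification` instance, and the offsets are character indices with an exclusive end, as described in the docstring:

text = "Alice moved to Berlin"
entities = [
    {"start": 0, "end": 5, "label": "PER"},    # "Alice"
    {"start": 15, "end": 21, "label": "LOC"},  # "Berlin"
]
instance = head.featurize(text, entities=entities)

# internally the spans are aligned with the tokens and become BIOUL tags,
# e.g. ["U-PER", "O", "O", "U-LOC"] for the four tokens above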