# Example #1
class StreusleTaggerRoberta(Model):
    """
    The ``StreusleTaggerRoberta`` embeds a sequence of text (via RoBERTa) before (optionally)
    encoding it with a ``Seq2SeqEncoder`` and passing it through a ``FeedForward``
    before using a Conditional Random Field model to predict a STREUSLE lextag for
    each token in the sequence. Decoding is constrained with the "BbIiOo_~"
    tagging scheme.

    This is functionally the same as ``StreusleTagger``, except it uses RoBERTa instead of
    a configurable ``TokenEmbedder`` (since there are some unsolved AllenNLP issues about
    properly using RoBERTa). The hidden states of all RoBERTa layers are combined with a
    learned ``ScalarMix``, and each token is represented by the mixed embedding of one of
    its wordpieces.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    roberta_type : ``str``, required
        The type of RoBERTa model to use (``base`` or ``large``). The pretrained model
        named ``f"roberta-{roberta_type}"`` is loaded.
    train_roberta : ``bool``, optional (default=``False``)
        If True, update RoBERTa weights during training. Else, freeze them.
    encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        The encoder that we will use in between embedding tokens and predicting output tags.
        Its input dim must equal RoBERTa's hidden size.
    label_namespace : ``str``, optional (default=``labels``)
        This is needed to constrain the CRF decoding.
        Unless you did something unusual, the default value should be what you want.
    feedforward : ``FeedForward``, optional, (default = None).
        An optional feedforward layer to apply after the encoder. Its input dim must equal
        the encoder output dim (RoBERTa's hidden size when there is no encoder).
    include_start_end_transitions : ``bool``, optional (default=``True``)
        Whether to include start and end transition parameters in the CRF.
    dropout : ``float``, optional (default=``None``)
        If given and non-zero, the dropout probability applied to the token embeddings
        and again to the encoder output. No dropout is used otherwise.
    use_upos_constraints : ``bool``, optional (default=``True``)
        Whether to use UPOS constraints. If True, the model should receive UPOS tags as input
        (in ``metadata``).
    use_lemma_constraints : ``bool``, optional (default=``True``)
        Whether to use lemma constraints. If True, the model should receive lemmas as input
        (in ``metadata``). If this is true, then use_upos_constraints must be true as well,
        otherwise a ``ConfigurationError`` is raised.
    train_with_constraints : ``bool``, optional (default=``True``)
        Whether to use the constraints during training, or only during testing.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 roberta_type: str,
                 train_roberta: bool = False,
                 encoder: Seq2SeqEncoder = None,
                 label_namespace: str = "labels",
                 feedforward: Optional[FeedForward] = None,
                 include_start_end_transitions: bool = True,
                 dropout: Optional[float] = None,
                 use_upos_constraints: bool = True,
                 use_lemma_constraints: bool = True,
                 train_with_constraints: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        # Reject an inconsistent configuration before loading any pretrained weights.
        if use_lemma_constraints and not use_upos_constraints:
            raise ConfigurationError(
                "If lemma constraints are applied, UPOS constraints must be applied as well.")

        self.label_namespace = label_namespace
        self._label_namespace = label_namespace

        # Ask RoBERTa to return the hidden states of every layer, so they can be scalar-mixed.
        self.roberta_config = AutoConfig.from_pretrained(f"roberta-{roberta_type}")
        self.roberta_config.output_hidden_states = True
        self.roberta = AutoModel.from_pretrained(f"roberta-{roberta_type}",
                                                 config=self.roberta_config)
        # One mixture weight per entry of RoBERTa's hidden-states tuple (presumably the
        # embedding output plus each transformer layer -- confirm for the transformers version).
        self.scalar_mix = ScalarMix(self.roberta.config.num_hidden_layers + 1)

        for parameter in self.roberta.parameters():
            parameter.requires_grad = train_roberta

        self.num_tags = self.vocab.get_vocab_size(label_namespace)
        self.train_with_constraints = train_with_constraints

        self.encoder = encoder
        if self.encoder is not None:
            encoder_output_dim = self.encoder.get_output_dim()
        else:
            encoder_output_dim = self.roberta.config.hidden_size

        self.dropout = torch.nn.Dropout(dropout) if dropout else None

        self.feedforward = feedforward
        if feedforward is not None:
            output_dim = feedforward.get_output_dim()
        else:
            output_dim = encoder_output_dim
        self.tag_projection_layer = TimeDistributed(Linear(output_dim, self.num_tags))

        labels = self.vocab.get_index_to_token_vocabulary(self._label_namespace)
        constraints = streusle_allowed_transitions(labels)

        self.use_upos_constraints = use_upos_constraints
        self.use_lemma_constraints = use_lemma_constraints

        if self.use_upos_constraints:
            # Dict with a mapping from UPOS to the set of allowed LEXCATs.
            # NOTE(review): the arguments of this call were lost from the source; confirm
            # against the helper's signature.
            self._upos_to_allowed_lexcats: Dict[str, Set[str]] = get_upos_allowed_lexcats()
            # Dict with a mapping from lemma to a dict of {UPOS: list of additionally
            # allowed LEXCATs}.
            self._lemma_to_allowed_lexcats: Dict[str, Dict[str, List[str]]] = (
                get_lemma_allowed_lexcats())
            self._upos_to_label_mask: Dict[str, torch.Tensor] = (
                self._build_upos_label_masks(labels))
            self._lemma_upos_to_label_mask: Dict[Tuple[str, str], torch.Tensor] = (
                self._build_lemma_upos_label_masks(labels))

        self.include_start_end_transitions = include_start_end_transitions
        self.crf = ConditionalRandomField(
            self.num_tags, constraints,
            include_start_end_transitions=include_start_end_transitions)

        self.metrics = {
            "accuracy": CategoricalAccuracy(),
            "accuracy3": CategoricalAccuracy(top_k=3)
        }
        if encoder is not None:
            check_dimensions_match(self.roberta.config.hidden_size, encoder.get_input_dim(),
                                   "roberta embedding dim",
                                   "encoder input dim")
        if feedforward is not None:
            check_dimensions_match(encoder_output_dim, feedforward.get_input_dim(),
                                   "encoder output dim",
                                   "feedforward input dim")
        initializer(self)

    def _build_upos_label_masks(self, labels: Dict[int, str]) -> Dict[str, torch.Tensor]:
        """
        Map every UPOS in ``ALL_UPOS`` to a ``(num_labels,)`` mask with 1 at allowed label
        indices and 0 at disallowed ones. Labels without a lexcat (e.g. ``~i``) and labels
        not starting with ``O-``/``o-`` are always allowed; ``O-``/``o-`` labels are allowed
        only if their lexcat is in ``self._upos_to_allowed_lexcats[upos]``.
        """
        upos_to_label_mask: Dict[str, torch.Tensor] = {}
        for upos in ALL_UPOS:
            # Shape: (num_labels,)
            upos_label_mask = torch.zeros(len(labels))
            for label_index, label in labels.items():
                label_parts = label.split("-")
                if len(label_parts) == 1:
                    # For ~i, etc. tags: no lexcat, always allowed.
                    upos_label_mask[label_index] = 1
                    continue
                if not label.startswith(("O-", "o-")):
                    # Label does not start with O-/o-, always allowed.
                    upos_label_mask[label_index] = 1
                elif label_parts[1] in self._upos_to_allowed_lexcats[upos]:
                    # Label starts with O-/o-, but the lexcat is in allowed
                    # lexcats for the current upos.
                    upos_label_mask[label_index] = 1
            upos_to_label_mask[upos] = upos_label_mask
        return upos_to_label_mask

    def _build_lemma_upos_label_masks(
            self, labels: Dict[int, str]) -> Dict[Tuple[str, str], torch.Tensor]:
        """
        Map (lemma, UPOS) pairs to a ``(num_labels,)`` mask with 1 at _additionally_ allowed
        label indices. A lexcat is allowed for a (lemma, UPOS) token if either this mask or
        the UPOS mask has a 1 for it. Pairs with no additional lexcats get no entry.
        """
        lemma_upos_to_label_mask: Dict[Tuple[str, str], torch.Tensor] = {}
        for lemma in SPECIAL_LEMMAS:
            upos_to_extra_lexcats = self._lemma_to_allowed_lexcats[lemma]
            for upos_tag in ALL_UPOS:
                if upos_tag not in upos_to_extra_lexcats:
                    # No additional constraints: leave the pair out (lookups treat it as all zero).
                    continue
                # Shape: (num_labels,)
                lemma_upos_label_mask = torch.zeros(len(labels))
                for label_index, label in labels.items():
                    label_parts = label.split("-")
                    # ~i, etc. tags and labels not starting with O-/o- are handled by
                    # the UPOS mask, not here.
                    if len(label_parts) == 1 or not label.startswith(("O-", "o-")):
                        continue
                    if label_parts[1] in upos_to_extra_lexcats[upos_tag]:
                        # Label starts with O-/o-, but the lexcat is additionally
                        # allowed for the current (lemma, upos).
                        lemma_upos_label_mask[label_index] = 1
                lemma_upos_to_label_mask[(lemma, upos_tag)] = lemma_upos_label_mask
        return lemma_upos_to_label_mask

    def forward(
            self,  # type: ignore
            input_ids: torch.LongTensor,
            input_mask: torch.LongTensor,
            token_indices_to_wordpiece_indices: torch.LongTensor,
            tags: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        input_ids : ``torch.LongTensor``, required
            Wordpiece ids, passed to RoBERTa as ``input_ids``.
        input_mask : ``torch.LongTensor``, required
            Wordpiece mask, passed to RoBERTa as ``attention_mask`` and to the scalar mix.
            NOTE(review): this argument's name was lost from the source; it must match the
            dataset reader's field name.
        token_indices_to_wordpiece_indices : ``torch.LongTensor``, required
            Shape ``(batch_size, num_tokens)``. For each token, the index of the wordpiece
            whose embedding represents it; negative values mark padding tokens.
        tags : ``torch.LongTensor``, optional (default = ``None``)
            A torch tensor representing the sequence of integer gold lextags of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Additional information about each example. Read keys: ``"tokens"``, and, when
            constraints are applied, ``"upos_tags"`` (plus ``"lemmas"`` with lemma constraints).

        Returns
        -------
        An output dictionary consisting of:

        mask : ``torch.LongTensor``
            The mask over (non-padding) input tokens.
        tags : ``List[List[int]]``
            The predicted tags using the Viterbi algorithm.
        tokens : ``List[List[str]]``
            The tokens from ``metadata`` (only if ``metadata`` is given).
        constrained_logits : ``torch.FloatTensor``
            The constrained logits that are the output of the ``tag_projection_layer``
            (only when constraints are applied).
        upos_tags : ``List[List[str]]``
            The UPOS tags used for the constraints (only when constraints are applied).
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
        """
        # Extract the mask over _tokens_ (as opposed to wordpieces) from
        # token_indices_to_wordpiece_indices: padding tokens have negative indices.
        mask = (token_indices_to_wordpiece_indices >= 0).long()
        # Set the negative (padding) values to 0 so they are valid indices.
        token_indices_to_wordpiece_indices = token_indices_to_wordpiece_indices * mask
        batch_size = token_indices_to_wordpiece_indices.size(0)

        roberta_output = self.roberta(input_ids=input_ids, attention_mask=input_mask)
        # output_hidden_states=True, so the third element holds all layers' hidden states.
        assert len(roberta_output) == 3
        # Tuple of len num_hidden_layers + 1; each item has shape
        # (batch_size, num_wordpieces, hidden_size).
        all_layer_embedded_wordpieces = roberta_output[2]
        # Combine into a single tensor of (batch_size, num_wordpieces, hidden_size).
        embedded_wordpieces = self.scalar_mix(all_layer_embedded_wordpieces, mask=input_mask)

        # Shape: (batch_size, 1), broadcast against the (batch_size, num_tokens) indices.
        range_vector = util.get_range_vector(
            batch_size, util.get_device_of(embedded_wordpieces)).unsqueeze(1)
        # Pick one wordpiece embedding per token.
        # Shape: (batch_size, num_tokens, hidden_size)
        embedded_text_input = embedded_wordpieces[
            range_vector, token_indices_to_wordpiece_indices]

        if self.dropout is not None:
            embedded_text_input = self.dropout(embedded_text_input)

        if self.encoder is not None:
            encoded_text = self.encoder(embedded_text_input, mask)
        else:
            encoded_text = embedded_text_input

        if self.dropout is not None:
            encoded_text = self.dropout(encoded_text)

        if self.feedforward is not None:
            encoded_text = self.feedforward(encoded_text)

        logits = self.tag_projection_layer(encoded_text)

        # Use constraints only if use_upos_constraints is true and we're either
        # (1) in evaluate mode or (2) training with constraints.
        apply_constraints = self.use_upos_constraints and (
            not self.training or self.train_with_constraints)

        # initial mask is unmasked
        batch_upos_constraint_mask = torch.ones_like(logits)
        if apply_constraints:
            # List of length (batch_size,), where each inner list holds the UPOS tags
            # of the associated token sequence.
            batch_upos_tags = [instance_metadata["upos_tags"]
                               for instance_metadata in metadata]
            # List of length (batch_size,), where each inner list holds the lemmas of the
            # associated token sequence (all None when lemma constraints are off).
            if self.use_lemma_constraints:
                # NOTE(review): metadata key reconstructed; confirm against the dataset reader.
                batch_lemmas = [instance_metadata["lemmas"]
                                for instance_metadata in metadata]
            else:
                batch_lemmas = [([None] * len(instance_metadata["upos_tags"]))
                                for instance_metadata in metadata]

            # Get a (batch_size, max_sequence_length, num_tags) mask with "1" in
            # tags that are allowed for a given UPOS, and "0" for tags that are
            # disallowed for a given UPOS.
            batch_upos_constraint_mask = self.get_upos_constraint_mask(
                batch_upos_tags=batch_upos_tags, batch_lemmas=batch_lemmas)

        constrained_logits = util.replace_masked_values(
            logits, batch_upos_constraint_mask, -1e32)

        best_paths = self.crf.viterbi_tags(constrained_logits, mask)
        # Just get the tags and ignore the score.
        predicted_tags = [path for path, _ in best_paths]

        output = {"mask": mask, "tags": predicted_tags}
        if metadata is not None:
            output["tokens"] = [instance_metadata["tokens"] for instance_metadata in metadata]

        if apply_constraints:
            output["constrained_logits"] = constrained_logits
            output["upos_tags"] = batch_upos_tags

        if tags is not None:
            # Add negative log-likelihood as loss
            log_likelihood = self.crf(constrained_logits, tags, mask)
            output["loss"] = -log_likelihood

            # Represent viterbi tags as "class probabilities" that we can
            # feed into the metrics
            class_probabilities = constrained_logits * 0.
            for i, instance_tags in enumerate(predicted_tags):
                for j, tag_id in enumerate(instance_tags):
                    class_probabilities[i, j, tag_id] = 1

            for metric in self.metrics.values():
                metric(class_probabilities, tags, mask.float())
        return output

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Converts the tag ids in ``output_dict["tags"]`` (a list of lists of tag ids)
        into tag strings from the label namespace, in place, and returns ``output_dict``.
        """
        output_dict["tags"] = [
            [self.vocab.get_token_from_index(tag, namespace=self._label_namespace)
             for tag in instance_tags]
            for instance_tags in output_dict["tags"]
        ]
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the current value of every metric, optionally resetting them."""
        return {
            metric_name: metric.get_metric(reset)
            for metric_name, metric in self.metrics.items()
        }

    def get_upos_constraint_mask(self, batch_upos_tags: List[List[str]],
                                 batch_lemmas: List[List[str]]):
        """
        Given POS tags and lemmas for a batch, return a mask of shape
        (batch_size, max_sequence_length, num_tags) with "1" in
        tags that are allowed for a given UPOS, and "0" for tags that are
        disallowed for a given UPOS.

        Parameters
        ----------
        batch_upos_tags : ``List[List[str]]``, required
            UPOS tags for a batch.
        batch_lemmas : ``List[List[str]]``, required
            Lemmas for a batch (entries may be ``None``).

        Returns
        -------
        ``Tensor``, shape (batch_size, max_sequence_length, num_tags)
            A mask over the logits, with 1 in positions where a tag is allowed for its UPOS
            (or additionally allowed for its (lemma, UPOS) pair) and 0 in positions where a
            tag is not allowed. Padding positions past the end of a sequence are all 0.
        """
        # TODO(nfliu): this is pretty inefficient, maybe there's someway to make it batched?
        max_sequence_length = max((len(upos_tags) for upos_tags in batch_upos_tags), default=0)
        # Start from 0 everywhere (a genuine 0/1 mask); padding positions stay fully masked.
        # Shape: (batch_size, max_sequence_length, num_tags)
        upos_constraint_mask = torch.zeros(
            len(batch_upos_tags),
            max_sequence_length,
            self.num_tags,
            device=next(self.tag_projection_layer.parameters()).device)
        for example_index, (example_upos_tags, example_lemmas) in enumerate(
                zip(batch_upos_tags, batch_lemmas)):
            # Shape: (max_sequence_length, num_tags); a view into upos_constraint_mask.
            example_constraint_mask = upos_constraint_mask[example_index]
            for timestep_index, (timestep_upos_tag, timestep_lemma) in enumerate(
                    zip(example_upos_tags, example_lemmas)):
                # Shape: (num_tags,)
                upos_constraint = self._upos_to_label_mask[timestep_upos_tag]
                # Pairs without additional lexcats fall back to the UPOS mask itself,
                # which leaves the OR below unchanged.
                lemma_constraint = self._lemma_upos_to_label_mask.get(
                    (timestep_lemma, timestep_upos_tag), upos_constraint)
                example_constraint_mask[timestep_index] = (
                    upos_constraint.long() | lemma_constraint.long()).float()
        return upos_constraint_mask
# Example #2
class StreusleTaggerLinear(Model):
    """
    The ``StreusleTaggerLinear`` embeds a sequence of text before (optionally)
    encoding it with a ``Seq2SeqEncoder`` and passing it through a ``FeedForward``
    before using a linear layer to predict a STREUSLE lextag for
    each token in the sequence.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the tokens ``TextField`` we get as input to the model.
    encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        The encoder that we will use in between embedding tokens and predicting output tags.
        Its input dim must equal the text field embedder's output dim.
    label_namespace : ``str``, optional (default=``labels``)
        This is needed to constrain the model.
        Unless you did something unusual, the default value should be what you want.
    feedforward : ``FeedForward``, optional, (default = None).
        An optional feedforward layer to apply after the encoder. Its input dim must equal
        the encoder output dim (the embedder output dim when there is no encoder).
    dropout : ``float``, optional (default=``None``)
        If given and non-zero, the dropout probability applied to the token embeddings
        and again to the encoder output. No dropout is used otherwise.
    use_upos_constraints : ``bool``, optional (default=``True``)
        Whether to use UPOS constraints. If True, the model should receive UPOS tags as input
        (in ``metadata``).
    use_lemma_constraints : ``bool``, optional (default=``True``)
        Whether to use lemma constraints. If True, the model should receive lemmas as input
        (in ``metadata``). If this is true, then use_upos_constraints must be true as well,
        otherwise a ``ConfigurationError`` is raised.
    train_with_constraints : ``bool``, optional (default=``True``)
        Whether to use the constraints during training, or only during testing.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder = None,
                 label_namespace: str = "labels",
                 feedforward: Optional[FeedForward] = None,
                 dropout: Optional[float] = None,
                 use_upos_constraints: bool = True,
                 use_lemma_constraints: bool = True,
                 train_with_constraints: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super().__init__(vocab, regularizer)

        # Reject an inconsistent configuration before building anything.
        if use_lemma_constraints and not use_upos_constraints:
            raise ConfigurationError(
                "If lemma constraints are applied, UPOS constraints must be applied as well.")

        self.label_namespace = label_namespace
        self._label_namespace = label_namespace
        self.text_field_embedder = text_field_embedder
        self.num_tags = self.vocab.get_vocab_size(label_namespace)
        self.train_with_constraints = train_with_constraints

        self.encoder = encoder
        if self.encoder is not None:
            encoder_output_dim = self.encoder.get_output_dim()
        else:
            encoder_output_dim = self.text_field_embedder.get_output_dim()

        self.dropout = torch.nn.Dropout(dropout) if dropout else None

        self.feedforward = feedforward
        if feedforward is not None:
            output_dim = feedforward.get_output_dim()
        else:
            output_dim = encoder_output_dim
        self.tag_projection_layer = TimeDistributed(Linear(output_dim, self.num_tags))

        labels = self.vocab.get_index_to_token_vocabulary(self._label_namespace)
        self.use_upos_constraints = use_upos_constraints
        self.use_lemma_constraints = use_lemma_constraints

        if self.use_upos_constraints:
            # Dict with a mapping from UPOS to the set of allowed LEXCATs.
            # NOTE(review): the arguments of this call were lost from the source; confirm
            # against the helper's signature.
            self._upos_to_allowed_lexcats: Dict[str, Set[str]] = get_upos_allowed_lexcats()
            # Dict with a mapping from lemma to a dict of {UPOS: list of additionally
            # allowed LEXCATs}.
            self._lemma_to_allowed_lexcats: Dict[str, Dict[str, List[str]]] = (
                get_lemma_allowed_lexcats())
            self._upos_to_label_mask: Dict[str, torch.Tensor] = (
                self._build_upos_label_masks(labels))
            self._lemma_upos_to_label_mask: Dict[Tuple[str, str], torch.Tensor] = (
                self._build_lemma_upos_label_masks(labels))

        self.accuracy_metrics = {
            "accuracy": CategoricalAccuracy(),
            "accuracy3": CategoricalAccuracy(top_k=3)
        }
        if encoder is not None:
            check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                                   "text field embedding dim",
                                   "encoder input dim")
        if feedforward is not None:
            check_dimensions_match(encoder_output_dim, feedforward.get_input_dim(),
                                   "encoder output dim",
                                   "feedforward input dim")
        initializer(self)

    def _build_upos_label_masks(self, labels: Dict[int, str]) -> Dict[str, torch.Tensor]:
        """
        Map every UPOS in ``ALL_UPOS`` to a ``(num_labels,)`` mask with 1 at allowed label
        indices and 0 at disallowed ones. Labels without a lexcat (e.g. ``~i``) and labels
        not starting with ``O-``/``o-`` are always allowed; ``O-``/``o-`` labels are allowed
        only if their lexcat is in ``self._upos_to_allowed_lexcats[upos]``.
        """
        upos_to_label_mask: Dict[str, torch.Tensor] = {}
        for upos in ALL_UPOS:
            # Shape: (num_labels,)
            upos_label_mask = torch.zeros(len(labels))
            for label_index, label in labels.items():
                label_parts = label.split("-")
                if len(label_parts) == 1:
                    # For ~i, etc. tags: no lexcat, always allowed.
                    upos_label_mask[label_index] = 1
                    continue
                if not label.startswith(("O-", "o-")):
                    # Label does not start with O-/o-, always allowed.
                    upos_label_mask[label_index] = 1
                elif label_parts[1] in self._upos_to_allowed_lexcats[upos]:
                    # Label starts with O-/o-, but the lexcat is in allowed
                    # lexcats for the current upos.
                    upos_label_mask[label_index] = 1
            upos_to_label_mask[upos] = upos_label_mask
        return upos_to_label_mask

    def _build_lemma_upos_label_masks(
            self, labels: Dict[int, str]) -> Dict[Tuple[str, str], torch.Tensor]:
        """
        Map (lemma, UPOS) pairs to a ``(num_labels,)`` mask with 1 at _additionally_ allowed
        label indices. A lexcat is allowed for a (lemma, UPOS) token if either this mask or
        the UPOS mask has a 1 for it. Pairs with no additional lexcats get no entry.
        """
        lemma_upos_to_label_mask: Dict[Tuple[str, str], torch.Tensor] = {}
        for lemma in SPECIAL_LEMMAS:
            upos_to_extra_lexcats = self._lemma_to_allowed_lexcats[lemma]
            for upos_tag in ALL_UPOS:
                if upos_tag not in upos_to_extra_lexcats:
                    # No additional constraints: leave the pair out (lookups treat it as all zero).
                    continue
                # Shape: (num_labels,)
                lemma_upos_label_mask = torch.zeros(len(labels))
                for label_index, label in labels.items():
                    label_parts = label.split("-")
                    # ~i, etc. tags and labels not starting with O-/o- are handled by
                    # the UPOS mask, not here.
                    if len(label_parts) == 1 or not label.startswith(("O-", "o-")):
                        continue
                    if label_parts[1] in upos_to_extra_lexcats[upos_tag]:
                        # Label starts with O-/o-, but the lexcat is additionally
                        # allowed for the current (lemma, upos).
                        lemma_upos_label_mask[label_index] = 1
                lemma_upos_to_label_mask[(lemma, upos_tag)] = lemma_upos_label_mask
        return lemma_upos_to_label_mask

    def forward(
            self,  # type: ignore
            tokens: Dict[str, torch.LongTensor],
            tags: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors.  At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence.  The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        tags : ``torch.LongTensor``, optional (default = ``None``)
            A torch tensor representing the sequence of integer gold lextags of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Additional information about each example. Read keys: ``"tokens"``, and, when
            constraints are applied, ``"upos_tags"`` (plus ``"lemmas"`` with lemma constraints).

        Returns
        -------
        An output dictionary consisting of:

        mask : ``torch.LongTensor``
            The text field mask for the input tokens
        class_probabilities : ``torch.FloatTensor``
            Softmax over the (constrained) logits, shape ``(batch_size, num_tokens, num_tags)``.
            ``decode`` turns these into predicted ``tags``.
        tokens : ``List[List[str]]``
            The tokens from ``metadata`` (only if ``metadata`` is given).
        constrained_logits : ``torch.FloatTensor``
            The constrained logits that are the output of the ``tag_projection_layer``
            (only when constraints are applied).
        upos_tags : ``List[List[str]]``
            The UPOS tags used for the constraints (only when constraints are applied).
        gold_tags : ``torch.LongTensor``, optional
            The gold ``tags``, if provided.
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. Only computed if gold label ``tags`` are provided.
        """
        embedded_text_input = self.text_field_embedder(tokens)
        batch_size, sequence_length, _ = embedded_text_input.size()
        mask = util.get_text_field_mask(tokens)

        if self.dropout is not None:
            embedded_text_input = self.dropout(embedded_text_input)

        if self.encoder is not None:
            encoded_text = self.encoder(embedded_text_input, mask)
        else:
            encoded_text = embedded_text_input

        if self.dropout is not None:
            encoded_text = self.dropout(encoded_text)

        if self.feedforward is not None:
            encoded_text = self.feedforward(encoded_text)

        logits = self.tag_projection_layer(encoded_text)

        # Use constraints only if use_upos_constraints is true and we're either
        # (1) in evaluate mode or (2) training with constraints.
        apply_constraints = self.use_upos_constraints and (
            not self.training or self.train_with_constraints)

        # initial mask is unmasked
        batch_upos_constraint_mask = torch.ones_like(logits)
        if apply_constraints:
            # List of length (batch_size,), where each inner list holds the UPOS tags
            # of the associated token sequence.
            batch_upos_tags = [instance_metadata["upos_tags"]
                               for instance_metadata in metadata]
            # List of length (batch_size,), where each inner list holds the lemmas of the
            # associated token sequence (all None when lemma constraints are off).
            if self.use_lemma_constraints:
                # NOTE(review): metadata key reconstructed; confirm against the dataset reader.
                batch_lemmas = [instance_metadata["lemmas"]
                                for instance_metadata in metadata]
            else:
                batch_lemmas = [([None] * len(instance_metadata["upos_tags"]))
                                for instance_metadata in metadata]

            # Get a (batch_size, max_sequence_length, num_tags) mask with "1" in
            # tags that are allowed for a given UPOS, and "0" for tags that are
            # disallowed for a given UPOS.
            batch_upos_constraint_mask = self.get_upos_constraint_mask(
                batch_upos_tags=batch_upos_tags, batch_lemmas=batch_lemmas)

        logits = util.replace_masked_values(logits, batch_upos_constraint_mask, -1e32)
        class_probabilities = F.softmax(logits, dim=-1).view(
            [batch_size, sequence_length, self.num_tags])

        output = {"mask": mask, "class_probabilities": class_probabilities}
        if metadata is not None:
            output["tokens"] = [instance_metadata["tokens"] for instance_metadata in metadata]

        if apply_constraints:
            output["constrained_logits"] = logits
            output["upos_tags"] = batch_upos_tags

        if tags is not None:
            # Add gold tags if they exist
            output["gold_tags"] = tags
            loss = util.sequence_cross_entropy_with_logits(logits, tags, mask)
            output["loss"] = loss
            for metric in self.accuracy_metrics.values():
                metric(class_probabilities, tags, mask.float())
        return output

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a position-wise argmax over ``output_dict["class_probabilities"]`` and stores
        the resulting tag strings, truncated to each sequence's length, under ``"tags"``.
        If ``"gold_tags"`` is present, its ids are converted to tag strings the same way.
        The ``"mask"`` entry is popped from ``output_dict``, which is returned.
        """
        mask = output_dict.pop("mask")
        lengths = util.get_lengths_from_binary_sequence_mask(mask)
        # Do a position-wise argmax over the class probabilities to recover the tags.
        all_predictions = output_dict["class_probabilities"].cpu().data.numpy()
        if all_predictions.ndim == 3:
            predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
        else:
            predictions_list = [all_predictions]

        all_tags = []
        for predictions, length in zip(predictions_list, lengths):
            argmax_indices = numpy.argmax(predictions, axis=-1)
            all_tags.append([
                self.vocab.get_token_from_index(int(x), namespace=self._label_namespace)
                for x in argmax_indices[:length]
            ])
        output_dict["tags"] = all_tags

        # Converts the gold tag ids to the actual tags. Checking ``gold_tags`` (not the
        # leftover loop variable) avoids zip(None, ...) when no gold tags were provided.
        gold_tags = output_dict.pop("gold_tags", None)
        if gold_tags is not None:
            # TODO (nfliu): figure out why this is sometimes a tensor and sometimes a list.
            if torch.is_tensor(gold_tags):
                gold_tags = gold_tags.cpu().detach().numpy()
            output_dict["gold_tags"] = [
                [self.vocab.get_token_from_index(int(gold_tag), namespace=self._label_namespace)
                 for gold_tag in instance_gold_tags[:length]]
                for instance_gold_tags, length in zip(gold_tags, lengths)
            ]
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the current value of every accuracy metric, optionally resetting them."""
        return {
            metric_name: metric.get_metric(reset)
            for metric_name, metric in self.accuracy_metrics.items()
        }

    def get_upos_constraint_mask(self, batch_upos_tags: List[List[str]],
                                 batch_lemmas: List[List[str]]):
        """
        Given POS tags and lemmas for a batch, return a mask of shape
        (batch_size, max_sequence_length, num_tags) with "1" in
        tags that are allowed for a given UPOS, and "0" for tags that are
        disallowed for a given UPOS.

        Parameters
        ----------
        batch_upos_tags : ``List[List[str]]``, required
            UPOS tags for a batch.
        batch_lemmas : ``List[List[str]]``, required
            Lemmas for a batch (entries may be ``None``).

        Returns
        -------
        ``Tensor``, shape (batch_size, max_sequence_length, num_tags)
            A mask over the logits, with 1 in positions where a tag is allowed for its UPOS
            (or additionally allowed for its (lemma, UPOS) pair) and 0 in positions where a
            tag is not allowed. Padding positions past the end of a sequence are all 0.
        """
        # TODO(nfliu): this is pretty inefficient, maybe there's someway to make it batched?
        max_sequence_length = max((len(upos_tags) for upos_tags in batch_upos_tags), default=0)
        # Start from 0 everywhere (a genuine 0/1 mask); padding positions stay fully masked.
        # Shape: (batch_size, max_sequence_length, num_tags)
        upos_constraint_mask = torch.zeros(
            len(batch_upos_tags),
            max_sequence_length,
            self.num_tags,
            device=next(self.tag_projection_layer.parameters()).device)
        for example_index, (example_upos_tags, example_lemmas) in enumerate(
                zip(batch_upos_tags, batch_lemmas)):
            # Shape: (max_sequence_length, num_tags); a view into upos_constraint_mask.
            example_constraint_mask = upos_constraint_mask[example_index]
            for timestep_index, (timestep_upos_tag, timestep_lemma) in enumerate(
                    zip(example_upos_tags, example_lemmas)):
                # Shape: (num_tags,)
                upos_constraint = self._upos_to_label_mask[timestep_upos_tag]
                # Pairs without additional lexcats fall back to the UPOS mask itself,
                # which leaves the OR below unchanged.
                lemma_constraint = self._lemma_upos_to_label_mask.get(
                    (timestep_lemma, timestep_upos_tag), upos_constraint)
                example_constraint_mask[timestep_index] = (
                    upos_constraint.long() | lemma_constraint.long()).float()
        return upos_constraint_mask