Example No. 1
    def forward(  # type: ignore
        self, tokens: TextFieldTensors, target_ids: TextFieldTensors = None
    ) -> Dict[str, torch.Tensor]:

        # Shape: (batch_size, num_tokens, embedding_dim)
        embeddings = self._text_field_embedder(tokens)
        batch_size = embeddings.size(0)

        # Shape: (batch_size, num_tokens, encoding_dim)
        if self._contextualizer:
            mask = util.get_text_field_mask(tokens)
            contextual_embeddings = self._contextualizer(embeddings, mask)
            final_embeddings = util.get_final_encoder_states(contextual_embeddings, mask)
        else:
            final_embeddings = embeddings[:, -1]

        target_logits = self._language_model_head(self._dropout(final_embeddings))

        vocab_size = target_logits.size(-1)
        probs = torch.nn.functional.softmax(target_logits, dim=-1)
        k = min(vocab_size, 5)  # min here largely because tests use small vocab
        top_probs, top_indices = probs.topk(k=k, dim=-1)

        output_dict = {"probabilities": top_probs, "top_indices": top_indices}

        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(tokens)

        if target_ids is not None:
            targets = util.get_token_ids_from_text_field_tensors(target_ids).view(batch_size)
            target_logits = target_logits.view(batch_size, vocab_size)
            loss = torch.nn.functional.cross_entropy(target_logits, targets)
            self._perplexity(loss)
            output_dict["loss"] = loss

        return output_dict
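
The snippets in this collection all receive their text input as TextFieldTensors, a nested dict of {indexer_name: {argument_name: tensor}}. Below is a minimal, standalone sketch (not part of the example above) of what such a dict can look like and what the two util helpers return; it assumes AllenNLP >= 1.0 is installed and uses the single-id indexer's conventional key names.

import torch
from allennlp.nn import util

# TextFieldTensors: {indexer_name: {argument_name: tensor}}; index 0 is padding here.
tokens = {"tokens": {"tokens": torch.tensor([[2, 5, 7, 0], [4, 9, 0, 0]])}}

token_ids = util.get_token_ids_from_text_field_tensors(tokens)
mask = util.get_text_field_mask(tokens)

print(token_ids.shape)  # torch.Size([2, 4])
print(mask)             # tensor([[ True,  True,  True, False],
                        #         [ True,  True, False, False]])
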
Example No. 2
    def forward(
        self,
        text: TextFieldTensors,
        masked_text: Optional[TextFieldTensors] = None,
        masked_positions: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:  # type: ignore
        if len(text) != 1:
            raise ValueError(
                "PretrainedTransformerBackbone is only compatible with using a single TokenIndexer"
            )
        mask = util.get_text_field_mask(text)
        encoded_text = self.embedder(text)

        outputs = {
            "encoded_text": encoded_text,
            "encoded_text_mask": mask,
            "token_ids": util.get_token_ids_from_text_field_tensors(text),
        }

        if masked_text is not None and masked_positions is not None:
            masked_text_mask = util.get_text_field_mask(masked_text)
            encoded_masked_text = self.embedder(masked_text)
            outputs["masked_positions"] = masked_positions
            outputs["encoded_masked_text"] = encoded_masked_text
            outputs["encoded_masked_text_mask"] = masked_text_mask
        return outputs
Example No. 3
    def copy_reference_policy(self,
                              timestep: int,
                              last_predictions: torch.LongTensor,
                              state: Dict[str, torch.Tensor],
                              target_tokens: Dict[str, torch.LongTensor],
                              ) -> Tuple[torch.FloatTensor, Dict[str, torch.Tensor]]:
        targets = util.get_token_ids_from_text_field_tensors(target_tokens)
        seq_len = targets.size(1)
        
        batch_size = last_predictions.shape[0]
        if seq_len > timestep + 1:  # + 1 because timestep is an index, indexed at 0.
            # As we might be overriding the next predicted token,
            # we have to use the value corresponding to the {t+1}^{th}
            # timestep.
            target_at_timesteps = targets[:, timestep + 1]
        else:
            # We have overshot the seq_len, so just repeat the
            # last token which is either _end_token or _pad_token.
            target_at_timesteps = targets[:, -1]

        # TODO: Add support to allow other types of reference policies.
        # target_logits: (batch_size, num_classes).
        # This tensor has 0 at targets and (near) -inf at other places.
        target_logits = (target_at_timesteps.new_zeros((batch_size, self._num_classes)) + 1e-45) \
                            .scatter_(dim=1,
                                      index=target_at_timesteps.unsqueeze(1),
                                      value=1.0).log()
        return target_logits, state
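
The scatter-and-log expression above is compact; here is a standalone sketch of the same trick with toy values (plain PyTorch, nothing from the surrounding model): build a distribution that puts all mass on the target token, then take its log to get near one-hot logits.

import torch

batch_size, num_classes = 2, 5
target_at_timestep = torch.tensor([3, 1])

target_logits = (
    torch.zeros(batch_size, num_classes)
    .fill_(1e-45)                                       # tiny probability mass everywhere
    .scatter_(1, target_at_timestep.unsqueeze(1), 1.0)  # probability 1.0 at the target
    .log()                                              # -> 0 at target, ~-103 elsewhere
)
print(target_logits.argmax(dim=-1))  # tensor([3, 1])
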
Example No. 4
    def _evaluate(self, tokens, eval_mask):
        transformer_input = self._adapt_for_transformer(tokens)
        logits, *_ = self.model(**transformer_input)
        token_ids = util.get_token_ids_from_text_field_tensors(tokens)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        token_log_likelihood = log_probs[:, :-1].gather(-1, token_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
        suffix_log_likelihood = (eval_mask[:, 1:] * token_log_likelihood).sum(-1)
        return token_log_likelihood, suffix_log_likelihood
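
A standalone sketch of the shift-and-gather pattern used in _evaluate: logits at position t score the token at position t + 1, hence the [:, :-1] / [:, 1:] alignment. Toy tensors stand in for the transformer output.

import torch

batch_size, seq_len, vocab_size = 2, 4, 10
token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
logits = torch.randn(batch_size, seq_len, vocab_size)

log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
# shape: (batch_size, seq_len - 1)
token_log_likelihood = log_probs[:, :-1].gather(-1, token_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

# Mask out padding (here everything is "real") and sum for a sequence-level score.
eval_mask = torch.ones(batch_size, seq_len)
suffix_log_likelihood = (eval_mask[:, 1:] * token_log_likelihood).sum(-1)
print(token_log_likelihood.shape, suffix_log_likelihood.shape)
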
Example No. 5
    def forward(  # type: ignore
            self, text: TextFieldTensors) -> TaskOutput:

        mask = get_text_field_mask(text)
        contextual_embeddings = self.backbone.forward(text, mask)

        token_ids = get_token_ids_from_text_field_tensors(text)
        assert isinstance(contextual_embeddings, torch.Tensor)

        # Use token_ids to compute targets: the targets are the next token ids
        # with respect to each position in the sequence,
        # e.g. token_ids [[1, 3, 5, 7], ...] -> forward_targets [[3, 5, 7, 0], ...]
        forward_targets = torch.zeros_like(token_ids)
        forward_targets[:, 0:-1] = token_ids[:, 1:]

        if self.bidirectional:
            backward_targets = torch.zeros_like(token_ids)
            backward_targets[:, 1:] = token_ids[:, 0:-1]
        else:
            backward_targets = None

        # add dropout
        contextual_embeddings_with_dropout = self._dropout(
            contextual_embeddings)

        # compute softmax loss
        try:
            forward_loss, backward_loss = self._compute_loss(
                contextual_embeddings_with_dropout, forward_targets,
                backward_targets)
        except IndexError:
            raise IndexError(
                "Word token out of vocabulary boundaries, please check your vocab is correctly set"
                " or created before starting training.")

        num_targets = torch.sum((forward_targets > 0).long())

        if num_targets > 0:
            if self.bidirectional:
                average_loss = (0.5 * (forward_loss + backward_loss) /
                                num_targets.float())
            else:
                average_loss = forward_loss / num_targets.float()
        else:
            average_loss = torch.tensor(0.0).to(forward_targets.device)

        for metric in self.metrics.values():
            metric(average_loss)

        return TaskOutput(
            logits=None,
            probs=None,
            loss=average_loss,
            **{
                "lm_embeddings": contextual_embeddings,
                "mask": mask
            },
        )
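
A standalone sketch of the target-shifting logic above, with toy token ids: forward targets are the ids shifted left by one (padding id 0 fills the last slot), backward targets are shifted right, and positions whose target is 0 are excluded from the loss count.

import torch

token_ids = torch.tensor([[1, 3, 5, 7],
                          [2, 4, 0, 0]])

forward_targets = torch.zeros_like(token_ids)
forward_targets[:, :-1] = token_ids[:, 1:]   # [[3, 5, 7, 0], [4, 0, 0, 0]]

backward_targets = torch.zeros_like(token_ids)
backward_targets[:, 1:] = token_ids[:, :-1]  # [[0, 1, 3, 5], [0, 2, 4, 0]]

# Positions with a 0 target are treated as padding and excluded from the loss.
num_targets = (forward_targets > 0).sum()
print(forward_targets, backward_targets, num_targets.item())
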
Example No. 6
    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors,
            labels: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : `TextFieldTensors`
            From a `TextField`
        labels : `torch.IntTensor`, optional (default = `None`)
            From a `MultiLabelField`

        # Returns

        An output dictionary consisting of:

            - `logits` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                unnormalized log probabilities of the label.
            - `probs` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                probabilities of the label.
            - `loss` : (`torch.FloatTensor`, optional) :
                A scalar loss to be optimised.
        """

        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        if self._feedforward is not None:
            embedded_text = self._feedforward(embedded_text)

        logits = self._classification_layer(embedded_text)
        probs = torch.sigmoid(logits)

        output_dict = {"logits": logits, "probs": probs}
        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)
        if labels is not None:
            loss = self._loss(logits,
                              labels.float().view(-1, self._num_labels))
            output_dict["loss"] = loss
            # TODO (John): This shouldn't be necessary as __call__ of the metrics detaches these
            # tensors anyways?
            cloned_logits, cloned_labels = logits.clone(), labels.clone()
            self._micro_f1(cloned_logits, cloned_labels)
            self._macro_f1(cloned_logits, cloned_labels)

        return output_dict
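
A standalone sketch of the multi-label head above: per-label sigmoid probabilities with a binary cross-entropy loss on the raw logits. The encoder output is replaced by a random tensor, and BCEWithLogitsLoss is an assumption about what self._loss would be; the original model may configure it differently.

import torch

batch_size, hidden_dim, num_labels = 4, 16, 3
encoded = torch.randn(batch_size, hidden_dim)
classification_layer = torch.nn.Linear(hidden_dim, num_labels)
loss_fn = torch.nn.BCEWithLogitsLoss()  # assumption: a reasonable stand-in for self._loss

logits = classification_layer(encoded)  # (batch_size, num_labels)
probs = torch.sigmoid(logits)           # independent per-label probabilities
labels = torch.randint(0, 2, (batch_size, num_labels)).float()

loss = loss_fn(logits, labels.view(-1, num_labels))
print(probs.shape, loss.item())
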
Example No. 7
    def rollin_policy(self,
                      timestep: int,
                      last_predictions: torch.LongTensor,
                      target_tokens: Dict[str, torch.Tensor] = None,
                      rollin_mode: str = None) -> torch.LongTensor:
        """ Roll-in policy to use.
            This takes in targets, timestep and last_predictions, and decide
            which to use for taking next step i.e., generating next token.
            What to do is decided by rolling mode. Options are
                - teacher_forcing,
                - learned,
                - mixed,

            By default the mode is mixed with scheduled_sampling_ratio=0.0. This 
            defaults to teacher_forcing. You can also explicitly run with teacher_forcing
            mode.

        Arguments:
            timestep {int} -- Current timestep decides which target token to use.
                              In case of teacher_forcing this is usually {t-1}^{th} timestep
                              for predicting t^{th} token.
            last_predictions {torch.LongTensor} -- {t-1}^th token predicted by the model.

        Keyword Arguments:
            targets {torch.LongTensor} -- Targets value if it is available. This will be
                                           available in training mode but not in inference mode. (default: {None})
            rollin_mode {str} -- Rollin mode. Options are
                                  teacher_forcing, learned, scheduled-sampling (default: {'teacher_forcing'})
        Returns:
            torch.LongTensor -- The method returns input token for predicting next token.
        """
        rollin_mode = rollin_mode or self._rollin_mode

        # For first timestep, you are passing start token, so don't do anything smart.
        if (timestep == 0 or
           # If no targets, no way to do teacher_forcing, so use your own predictions.
           target_tokens is None  or
           rollin_mode == 'learned'):
            # shape: (batch_size,)
            return last_predictions

        targets = util.get_token_ids_from_text_field_tensors(target_tokens)
        if rollin_mode == 'teacher_forcing':
            # shape: (batch_size,)
            input_choices = targets[:, timestep]
        elif rollin_mode == 'mixed':
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # During training, use the model's own predictions with probability
                # self._scheduled_sampling_ratio (gold tokens are used otherwise).
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]
        else:
            raise ConfigurationError(f"invalid configuration for rollin policy: {rollin_mode}")
        return input_choices
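
A standalone sketch of the choice made in the 'mixed' branch above: during training, the model's own last predictions are fed back with probability scheduled_sampling_ratio, otherwise the gold target at the current timestep is used. Toy tensors only.

import torch

scheduled_sampling_ratio = 0.25
timestep = 2
targets = torch.tensor([[9, 3, 5, 7],
                        [9, 4, 6, 8]])      # gold sequences
last_predictions = torch.tensor([2, 4])     # model's previous outputs

if torch.rand(1).item() < scheduled_sampling_ratio:
    input_choices = last_predictions        # use the model's own predictions
else:
    input_choices = targets[:, timestep]    # teacher forcing
print(input_choices)
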
Example No. 8
    def forward(
        self,
        sentence1: TextFieldTensors,
        sentence2: TextFieldTensors,
        label: torch.IntTensor = None,
    ) -> Dict[str, torch.Tensor]:
        embedded_sentence1 = self._text_field_embedder(sentence1)
        embedded_sentence2 = self._text_field_embedder(sentence2)
        sentence1_mask = get_text_field_mask(sentence1)
        sentence2_mask = get_text_field_mask(sentence2)

        if self._seq2seq_encoder:
            embedded_sentence1 = self._seq2seq_encoder(embedded_sentence1,
                                                       mask=sentence1_mask)
            embedded_sentence2 = self._seq2seq_encoder(embedded_sentence2,
                                                       mask=sentence2_mask)

        embedded_sentence1 = self._seq2vec_encoder(embedded_sentence1,
                                                   mask=sentence1_mask)
        embedded_sentence2 = self._seq2vec_encoder(embedded_sentence2,
                                                   mask=sentence2_mask)
        pair_vec = self._pair_vec_to_vec(embedded_sentence1,
                                         embedded_sentence2)
        if self._dropout:
            pair_vec = self._dropout(pair_vec)

        logits = self._classification_layer(pair_vec)
        probs = torch.softmax(logits, dim=-1)

        output_dict = {
            "logits": logits,
            "probs": probs,
            "sentence1_token_ids": get_token_ids_from_text_field_tensors(sentence1),
            "sentence2_token_ids": get_token_ids_from_text_field_tensors(sentence2),
        }
        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)
        return output_dict
Example No. 9
    def forward(self,
                text: TextFieldTensors) -> Dict[str, Any]:  # type: ignore

        mask = get_text_field_mask(text)
        contextual_embeddings = self.backbone.forward(text, mask)

        token_ids = get_token_ids_from_text_field_tensors(text)
        assert isinstance(contextual_embeddings, torch.Tensor)

        # Use token_ids to compute targets: the targets are the next token ids
        # with respect to each position in the sequence,
        # e.g. token_ids [[1, 3, 5, 7], ...] -> forward_targets [[3, 5, 7, 0], ...]
        forward_targets = torch.zeros_like(token_ids)
        forward_targets[:, 0:-1] = token_ids[:, 1:]

        if self.bidirectional:
            backward_targets = torch.zeros_like(token_ids)
            backward_targets[:, 1:] = token_ids[:, 0:-1]
        else:
            backward_targets = None

        # add dropout
        contextual_embeddings_with_dropout = self._dropout(
            contextual_embeddings)

        # compute softmax loss
        try:
            forward_loss, backward_loss = self._compute_loss(
                contextual_embeddings_with_dropout, forward_targets,
                backward_targets)
        except IndexError:
            raise IndexError(
                "Word token out of vocabulary boundaries, please check your vocab is correctly set"
                " or created before starting training.")

        num_targets = torch.sum((forward_targets > 0).long())

        if num_targets > 0:
            if self.bidirectional:
                average_loss = (0.5 * (forward_loss + backward_loss) /
                                num_targets.float())
            else:
                average_loss = forward_loss / num_targets.float()
        else:
            average_loss = torch.tensor(0.0)

        for metric in self._metrics.get_dict(is_train=self.training).values():
            # Perplexity needs the value to be on the cpu
            metric(average_loss.to("cpu"))

        return dict(
            loss=average_loss,
            lm_embeddings=contextual_embeddings,
            mask=mask,
        )
Example No. 10
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        label: torch.IntTensor = None,
        metadata: MetadataField = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : `TextFieldTensors`
            From a `TextField`
        label : `torch.IntTensor`, optional (default = `None`)
            From a `LabelField`

        # Returns

        An output dictionary consisting of:

            - `logits` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                unnormalized log probabilities of the label.
            - `probs` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                probabilities of the label.
            - `loss` : (`torch.FloatTensor`, optional) :
                A scalar loss to be optimised.
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens)

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        if self._feedforward is not None:
            embedded_text = self._feedforward(embedded_text)

        logits = self._classification_layer(embedded_text)
        probs = torch.nn.functional.softmax(logits, dim=-1)

        output_dict = {"logits": logits, "probs": probs}
        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)
        if label is not None:
            loss = self._loss(logits, label.long().view(-1))
            output_dict["loss"] = loss
            self._accuracy(logits, label)

        return output_dict
Example No. 11
    def forward(self, text: TextFieldTensors) -> Dict[str, torch.Tensor]:  # type: ignore
        bert_output = self._embed(text)

        outputs = {
            "encoded_text": bert_output['orig_embeddings'],
            "encoded_text_mask": bert_output['orig_mask'],
            "wordpiece_encoded_text": bert_output['wordpiece_embeddings'],
            "wordpiece_encoded_text_mask": bert_output['wordpiece_mask'],
            "token_ids": util.get_token_ids_from_text_field_tensors(text),
        }

        self._extend_with_masked_text(outputs, text)
        return outputs
Example No. 12
    def _compute_rollin_loss_batch(self,
            rollin_output_dict: Dict[str, torch.Tensor],
            state: Dict[str, torch.Tensor],
            target_tokens: Dict[str, torch.Tensor]) -> torch.FloatTensor:

        logits = rollin_output_dict['logits']
        targets = util.get_token_ids_from_text_field_tensors(target_tokens)
        # shape: (batch_size, num_decoding_steps, num_classes)
        best_logits = logits[:, 0, :, :].squeeze(1)
        target_masks = util.get_text_field_mask(target_tokens)

        # Compute loss.
        loss_batch = self._get_cross_entropy_loss(best_logits, targets, target_masks)
        return loss_batch
Example No. 13
    def compute_sentence_probs(self,
                               sequences_dict: Dict[str, torch.LongTensor],
                              ) -> torch.FloatTensor:
        """ Given a batch of tokens, compute the per-token log probability of sequences
            given the trained model.

        Arguments:
            sequences_dict {Dict[str, torch.LongTensor]} -- The sequences that needs to be scored.

        Returns:
            seq_probs {torch.FloatTensor} -- Probabilities of the sequence.
            seq_lens {torch.LongTensor} -- Length of the non padded sequence.
            per_step_seq_probs {torch.LongTensor} -- Probability of per prediction in a sequence
        """
        state = {}
        sequences = util.get_token_ids_from_text_field_tensors(sequences_dict)

        batch_size = sequences.size(0)
        seq_len = sequences.size(1)
        start_predictions = self._get_start_predictions(state,
                                                        sequences_dict,
                                                        batch_size)
        
        # We are now computing the probability of the given sequence, so we use
        # rollin_mode=teacher_forcing because we want to select tokens from the
        # sequence whose probability we need to compute.
        rollin_output_dict = self.rollin(state={},
                                            start_predictions=start_predictions,
                                            rollin_steps=seq_len - 1,
                                            target_tokens=sequences_dict,
                                            rollin_mode='teacher_forcing',
                                        )

        step_log_probs = F.log_softmax(rollin_output_dict['logits'].squeeze(1), dim=-1)
        per_step_seq_probs = torch.gather(step_log_probs, 2,
                                          sequences[:,1:].unsqueeze(2)) \
                                            .squeeze(2)

        sequence_mask = util.get_text_field_mask(sequences_dict)
        per_step_seq_probs_summed = torch.sum(per_step_seq_probs * sequence_mask[:, 1:], dim=-1)
        non_batch_dims = tuple(range(1, len(sequence_mask.shape)))

        # shape : (batch_size,)
        sequence_mask_sum = sequence_mask[:, 1:].sum(dim=non_batch_dims)

        # (seq_probs, seq_lens, per_step_seq_probs)
        return torch.exp(per_step_seq_probs_summed/sequence_mask_sum), \
                sequence_mask_sum, \
                torch.exp(per_step_seq_probs)
Example No. 14
    def forward(
        self, text: TextFieldTensors
    ) -> Dict[str, torch.Tensor]:  # type: ignore
        if len(text) != 1:
            raise ValueError(
                "PretrainedTransformerBackbone is only compatible with using a single TokenIndexer"
            )
        text_inputs = next(iter(text.values()))
        mask = util.get_text_field_mask(text)
        encoded_text = self._embedder(**text_inputs)
        outputs = {"encoded_text": encoded_text, "encoded_text_mask": mask}
        if self._output_token_strings:
            outputs["token_ids"] = util.get_token_ids_from_text_field_tensors(text)
        return outputs
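
A sketch of why the single-indexer check and the **text_inputs unpacking above work. The key names shown ("tokens" for the indexer; "token_ids", "mask", "type_ids" for its outputs) follow the usual pretrained-transformer indexer convention but should be treated as an assumption, and the embedder is replaced by a stub.

import torch

text = {
    "tokens": {  # exactly one indexer, hence the len(text) != 1 check
        "token_ids": torch.tensor([[101, 2023, 2003, 102]]),
        "mask": torch.tensor([[True, True, True, True]]),
        "type_ids": torch.zeros(1, 4, dtype=torch.long),
    }
}

def stub_embedder(token_ids, mask, type_ids=None):
    # A real transformer embedder would run the model here; we just return random features.
    return torch.randn(*token_ids.shape, 8)

text_inputs = next(iter(text.values()))
encoded_text = stub_embedder(**text_inputs)  # argument names line up with the dict keys
print(encoded_text.shape)  # torch.Size([1, 4, 8])
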
Example No. 15
    def forward(  # type: ignore
            self, text: TextFieldTensors) -> TaskOutput:

        mask = get_text_field_mask(text)
        contextual_embeddings = self.backbone.forward(text, mask)
        # NOTE: @dvsrepo, AllenNLP 1.0 adds a second nesting level for features that I'm not sure
        # I fully understand. In any case, they provide a function to recover the token ids here
        # (the function's docstring clarifies the real spaghetti inside the indexer code references, :-)
        token_ids = get_token_ids_from_text_field_tensors(text)
        assert isinstance(contextual_embeddings, torch.Tensor)

        # Use token_ids to compute targets: the targets are the next token ids
        # with respect to each position in the sequence,
        # e.g. token_ids [[1, 3, 5, 7], ...] -> forward_targets [[3, 5, 7, 0], ...]
        forward_targets = torch.zeros_like(token_ids)
        forward_targets[:, 0:-1] = token_ids[:, 1:]

        # add dropout
        contextual_embeddings_with_dropout = self._dropout(
            contextual_embeddings)

        # compute softmax loss
        try:
            forward_loss = self._compute_loss(
                contextual_embeddings_with_dropout, forward_targets)
        except IndexError:
            raise IndexError(
                "Word token out of vocabulary boundaries, please check your vocab is correctly set"
                " or created before starting training.")

        num_targets = torch.sum((forward_targets > 0).long())
        if num_targets > 0:
            average_loss = forward_loss / num_targets.float()
        else:
            average_loss = torch.tensor(0.0).to(forward_targets.device)

        for metric in self.metrics.values():
            metric(average_loss)

        return TaskOutput(logits=None,
                          probs=None,
                          loss=average_loss,
                          **{
                              "lm_embeddings": contextual_embeddings,
                              "mask": mask
                          })
Example No. 16
    def forward(
        self,
        encoder_out: Dict[str, torch.LongTensor],
        target_tokens: TextFieldTensors = None,
    ) -> Dict[str, torch.Tensor]:
        state = encoder_out
        decoder_init_state = self._decoder_net.init_decoder_state(state)
        state.update(decoder_init_state)

        if target_tokens:
            state_forward_loss = (state if self.training else
                                  {k: v.clone()
                                   for k, v in state.items()})
            output_dict = self._forward_loss(state_forward_loss, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

            if target_tokens:
                targets = util.get_token_ids_from_text_field_tensors(
                    target_tokens)
                if self._tensor_based_metric is not None:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]

                    self._tensor_based_metric(  # type: ignore
                        best_predictions, targets)

                if self._token_based_metric is not None:
                    output_dict = self.post_process(output_dict)
                    predicted_tokens = output_dict["predicted_tokens"]

                    self._token_based_metric(  # type: ignore
                        predicted_tokens,
                        self.indices_to_tokens(targets[:, 1:]),
                    )

        return output_dict
Example No. 17
    def forward(
        self,
        transactions: TextFieldTensors,
        label: Optional[torch.Tensor] = None,
        amounts: Optional[TextFieldTensors] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        emb_out = self.get_transaction_embeddings(transactions)

        output_dict = self.forward_on_transaction_embeddings(
            transaction_embeddings=emb_out["transaction_embeddings"],
            mask=emb_out["mask"],
            label=label,
            amounts=amounts,
        )

        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(transactions)

        return output_dict
Example No. 18
    def _get_start_predictions(self,
                               state: Dict[str, torch.Tensor],
                               target_tokens: Dict[str, torch.LongTensor] = None,
                               generation_batch_size: int = None) -> torch.LongTensor:

        if self._seq2seq_mode:
            source_mask = state["source_mask"]
            batch_size = source_mask.size()[0]
        elif target_tokens:
            targets = util.get_token_ids_from_text_field_tensors(target_tokens)
            batch_size = targets.size(0)
        else:
            batch_size = generation_batch_size

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        return torch.full((batch_size,),
                          self._start_index,
                          dtype=torch.long,
                          device=self.current_device)
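
A standalone sketch of the start-prediction initialization above, using torch.full; the start_index and device values are illustrative.

import torch

batch_size, start_index = 4, 1
device = torch.device("cpu")

start_predictions = torch.full((batch_size,), start_index,
                               dtype=torch.long, device=device)
print(start_predictions)  # tensor([1, 1, 1, 1])
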
Example No. 19
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        mask_positions: torch.LongTensor,
        target_ids: TextFieldTensors = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : `TextFieldTensors`
            The output of `TextField.as_tensor()` for a batch of sentences.
        mask_positions : `torch.LongTensor`
            The positions in `tokens` that correspond to [MASK] tokens that we should try to fill
            in.  Shape should be (batch_size, num_masks).
        target_ids : `TextFieldTensors`
            This is a list of token ids that correspond to the mask positions we're trying to fill.
            It is the output of a `TextField`, purely for convenience, so we can handle wordpiece
            tokenizers and such without having to do crazy things in the dataset reader.  We assume
            that there is exactly one entry in the dictionary, and that it has a shape identical to
            `mask_positions` - one target token per mask position.
        """

        targets = None
        if target_ids is not None:
            targets = util.get_token_ids_from_text_field_tensors(target_ids)
        mask_positions = mask_positions.squeeze(-1)
        batch_size, num_masks = mask_positions.size()
        if targets is not None and targets.size() != mask_positions.size():
            raise ValueError(
                f"Number of targets ({targets.size()}) and number of masks "
                f"({mask_positions.size()}) are not equal")

        # Shape: (batch_size, num_tokens, embedding_dim)
        embeddings = self._text_field_embedder(tokens)

        # Shape: (batch_size, num_tokens, encoding_dim)
        if self._contextualizer:
            mask = util.get_text_field_mask(tokens)
            contextual_embeddings = self._contextualizer(embeddings, mask)
        else:
            contextual_embeddings = embeddings

        # Does advanced indexing to get the embeddings of just the mask positions, which is what
        # we're trying to predict.
        batch_index = torch.arange(0, batch_size).long().unsqueeze(1)
        mask_embeddings = contextual_embeddings[batch_index, mask_positions]

        target_logits = self._language_model_head(
            self._dropout(mask_embeddings))

        vocab_size = target_logits.size(-1)
        probs = torch.nn.functional.softmax(target_logits, dim=-1)
        k = min(vocab_size,
                5)  # min here largely because tests use small vocab
        top_probs, top_indices = probs.topk(k=k, dim=-1)

        output_dict = {"probabilities": top_probs, "top_indices": top_indices}

        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)

        if targets is not None:
            target_logits = target_logits.view(batch_size * num_masks,
                                               vocab_size)
            targets = targets.view(batch_size * num_masks)
            loss = torch.nn.functional.cross_entropy(target_logits, targets)
            self._perplexity(loss)
            output_dict["loss"] = loss

        return output_dict
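
A standalone sketch of the advanced indexing used above to pull out the embeddings at the [MASK] positions: a column of batch indices broadcast against the (batch_size, num_masks) position tensor selects one embedding per mask.

import torch

batch_size, num_tokens, embedding_dim = 2, 5, 8
contextual_embeddings = torch.randn(batch_size, num_tokens, embedding_dim)
mask_positions = torch.tensor([[1, 3],
                               [0, 4]])  # (batch_size, num_masks)

batch_index = torch.arange(0, batch_size).long().unsqueeze(1)  # (batch_size, 1)
mask_embeddings = contextual_embeddings[batch_index, mask_positions]
print(mask_embeddings.shape)  # torch.Size([2, 2, 8]) == (batch_size, num_masks, embedding_dim)
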
Example No. 20
    def rollout(self,
                state: Dict[str, torch.Tensor],
                start_predictions: torch.LongTensor,
                rollout_steps: int,
                beam_size: int = None,
                per_node_beam_size: int = None,
                target_tokens: Dict[str, torch.LongTensor] = None,
                sampled: bool = True,
                truncate_at_end_all: bool = True,
                # shape (prediction_prefixes): (batch_size, prefix_length)
                prediction_prefixes: torch.LongTensor = None,
                target_prefixes: torch.LongTensor = None,
                rollout_mixing_func: RolloutMixingProbFuncType = None,
                reference_policy_type: str = "copy",
                rollout_mode: str = None,
               ):
        state['rollout_params'] = {}
        if reference_policy_type == 'oracle':
            reference_policy = partial(self.oracle_reference_policy,
                                        token_to_idx=self._vocab._token_to_index['target_tokens'],
                                        idx_to_token=self._vocab._index_to_token['target_tokens'],
                                       )
            num_steps_to_take = rollout_steps
            state['rollout_params']['rollout_prefixes'] = prediction_prefixes
        else:
            reference_policy = partial(self.copy_reference_policy,
                                        target_tokens=target_tokens)           
            num_steps_to_take = rollout_steps

        rollout_policy = partial(self.rollout_policy,
                                    rollout_mode=rollout_mode,
                                    rollout_mixing_func=rollout_mixing_func,
                                    reference_policy=reference_policy,
                                )
        rolling_policy=partial(self.take_step,
                               rollout_policy=rollout_policy)

        # shape (step_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        # shape (logits): (batch_size, beam_size, num_decoding_steps, num_classes)
        step_predictions, log_probabilities, logits = \
                    self._beam_search.search(start_predictions,
                                                state,
                                                rolling_policy,
                                                max_steps=num_steps_to_take,
                                                beam_size=beam_size,
                                                per_node_beam_size=per_node_beam_size,
                                                sampled=sampled,
                                                truncate_at_end_all=truncate_at_end_all)

        logits = torch.cat(logits, dim=2)
        
        # Concatenate the start tokens to the predictions. They are not
        # added to the predictions by default.
        batch_size, beam_size, _ = step_predictions.shape

        start_prediction_length = start_predictions.size(0)
        step_predictions = torch.cat([start_predictions.unsqueeze(1) \
                                        .expand(batch_size, beam_size) \
                                        .reshape(batch_size, beam_size, 1),
                                        step_predictions],
                                        dim=-1)

        # There might be some predictions which might have been made by
        # rollin policy. If passed, concatenate them here.
        if prediction_prefixes is not None:
            prefixes_length = prediction_prefixes.size(1)
            step_predictions = torch.cat([prediction_prefixes.unsqueeze(1)\
                                            .expand(batch_size, beam_size, prefixes_length), 
                                         step_predictions],
                                         dim=-1)

        step_prediction_masks = self._get_mask(step_predictions \
                                                .reshape(batch_size * beam_size, -1)) \
                                        .reshape(batch_size, beam_size, -1)

        output_dict = {
            "predictions": step_predictions,
            "prediction_masks": step_prediction_masks,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }

        step_targets = None
        step_target_masks = None
        if target_tokens is not None:
            step_targets = util.get_token_ids_from_text_field_tensors(target_tokens)
            if target_prefixes is not None:
                prefixes_length = target_prefixes.size(1)
                step_targets = torch.cat([target_prefixes, step_targets], dim=-1)

            step_target_masks = util.get_text_field_mask({'tokens': {'tokens': step_targets}})
            
            output_dict.update({
                "targets": step_targets,
                "target_masks": step_target_masks,
            })
        return output_dict
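
A standalone sketch of how the start tokens are prepended to the beam-search predictions above: expand the (batch_size,) start ids across the beam dimension and concatenate along time. Toy tensors only.

import torch

batch_size, beam_size, num_steps = 2, 3, 4
start_predictions = torch.tensor([1, 1])                    # (batch_size,)
step_predictions = torch.randint(2, 10, (batch_size, beam_size, num_steps))

expanded_start = (start_predictions.unsqueeze(1)            # (batch_size, 1)
                  .expand(batch_size, beam_size)            # (batch_size, beam_size)
                  .reshape(batch_size, beam_size, 1))
full_predictions = torch.cat([expanded_start, step_predictions], dim=-1)
print(full_predictions.shape)  # torch.Size([2, 3, 5])
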
Example No. 21
    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        answer_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question_with_context : Dict[str, torch.LongTensor]
            From a ``TextField``. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including [CLS] and
            [SEP]) has type id 1.
        context_span : ``torch.IntTensor``
            From a ``SpanField``. This marks the span of word pieces in ``question`` from which answers can come.
        answer_span : ``torch.IntTensor``, optional
            From a ``SpanField``. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output directory.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the ``metadata`` list should be the
            batch size, and each dictionary should have the keys ``id``, ``question``, ``context``,
            ``question_tokens``, ``context_tokens``, and ``answers``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        best_span_scores : torch.FloatTensor
            The score for each of the best spans.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._text_field_embedder(question_with_context)
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)

        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context),
            dtype=torch.bool)
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start:end + 1] = True

        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       possible_answer_mask,
                                                       -1e32)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     possible_answer_mask,
                                                     -1e32)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }

        # Compute the loss for training.
        if answer_span is not None:
            span_start = answer_span[:, 0]
            span_end = answer_span[:, 1]
            span_mask = span_start != -1
            self._span_accuracy(best_spans, answer_span,
                                span_mask.unsqueeze(-1).expand_as(best_spans))

            start_loss = cross_entropy(span_start_logits,
                                       span_start,
                                       ignore_index=-1)
            if torch.any(start_loss > 1e9):
                logger.critical("Start loss too high (%r)", start_loss)
                logger.critical("span_start_logits: %r", span_start_logits)
                logger.critical("span_start: %r", span_start)
                assert False

            end_loss = cross_entropy(span_end_logits,
                                     span_end,
                                     ignore_index=-1)
            if torch.any(end_loss > 1e9):
                logger.critical("End loss too high (%r)", end_loss)
                logger.critical("span_end_logits: %r", span_end_logits)
                logger.critical("span_end: %r", span_end)
                assert False

            loss = (start_loss + end_loss) / 2

            self._span_start_accuracy(span_start_logits, span_start, span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)

            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            best_spans = best_spans.detach().cpu().numpy()

            output_dict["best_span_str"] = []
            context_tokens = []
            for metadata_entry, best_span in zip(metadata, best_spans):
                context_tokens_for_question = metadata_entry["context_tokens"]
                context_tokens.append(context_tokens_for_question)

                best_span -= 1 + len(metadata_entry["question_tokens"]) + 2
                assert np.all(best_span >= 0)

                predicted_start, predicted_end = tuple(best_span)

                while (predicted_start >= 0
                       and context_tokens_for_question[predicted_start].idx is
                       None):
                    predicted_start -= 1
                if predicted_start < 0:
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                        f"'{best_span[0]}' to an offset in the original text.")
                    character_start = 0
                else:
                    character_start = context_tokens_for_question[
                        predicted_start].idx

                while (predicted_end < len(context_tokens_for_question) and
                       context_tokens_for_question[predicted_end].idx is None):
                    predicted_end += 1
                if predicted_end >= len(context_tokens_for_question):
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                        f"'{best_span[1]}' to an offset in the original text.")
                    character_end = len(metadata_entry["context"])
                else:
                    end_token = context_tokens_for_question[predicted_end]
                    character_end = end_token.idx + len(
                        sanitize_wordpiece(end_token.text))

                best_span_string = metadata_entry["context"][
                    character_start:character_end]
                output_dict["best_span_str"].append(best_span_string)

                answers = metadata_entry.get("answers")
                if len(answers) > 0:
                    self._per_instance_metrics(best_span_string, answers)
            output_dict["context_tokens"] = context_tokens
        return output_dict
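
A standalone sketch of constrained span selection in the spirit of get_best_span above: score every (start, end) pair as the sum of its start and end logits, invalidate spans with end < start, and take the argmax. The real AllenNLP helper may differ in details; toy logits only.

import torch

span_start_logits = torch.tensor([[0.1, 2.0, 0.3, 0.0]])
span_end_logits = torch.tensor([[0.0, 0.5, 1.5, 0.2]])
batch_size, passage_length = span_start_logits.shape

# (batch_size, passage_length, passage_length): score[b, i, j] = start[i] + end[j]
span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
valid = torch.triu(torch.ones(passage_length, passage_length, dtype=torch.bool))
span_scores = span_scores.masked_fill(~valid, float("-inf"))

best_flat = span_scores.view(batch_size, -1).argmax(dim=-1)
best_spans = torch.stack([torch.div(best_flat, passage_length, rounding_mode="floor"),
                          best_flat % passage_length], dim=-1)
print(best_spans)  # tensor([[1, 2]])
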
Example No. 22
    def forward(self,  # type: ignore
                encoder_out: Dict[str, torch.LongTensor] = {},
                target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make a forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.

        source_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        output_dict: Dict[str, torch.Tensor] = {}
        state: Dict[str, torch.Tensor] = {}
        decoder_init_state: Dict[str, torch.Tensor] = {}

        state.update(copy.copy(encoder_out))
        # In Seq2Seq setting, we will encode the source sequence,
        # and init the state object with encoder output and decoder
        # cell will use these encoder outputs for attention/initing
        # the decoder states.
        if self._seq2seq_mode:
            decoder_init_state = \
                        self._decoder_net.init_decoder_state(state)
            state.update(decoder_init_state)

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        start_predictions: torch.LongTensor = \
                self._get_start_predictions(state,
                                        target_tokens,
                                        self._generation_batch_size)
        
        # In case we have target_tokens, roll-in and roll-out
        # only till those many steps, otherwise we roll-out for
        # `self._max_decoding_steps`.
        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets: torch.LongTensor = \
                    util.get_token_ids_from_text_field_tensors(target_tokens)

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps: int = target_sequence_length - 1
        else:
            num_decoding_steps: int = self._max_decoding_steps

        if target_tokens:
            decoder_output_dict, rollin_dict, rollout_dict_iter = \
                                        self._forward_loop(
                                                state=state,
                                                start_predictions=start_predictions,
                                                num_decoding_steps=num_decoding_steps,
                                                target_tokens=target_tokens)

            output_dict.update(decoder_output_dict)
            predictions = decoder_output_dict['predictions']
            predicted_tokens = self._decode_tokens(predictions,
                                                    vocab_namespace=self._target_namespace,
                                                    truncate=True)
            output_dict["decoded_predictions"] = predicted_tokens

            decoded_targets = self._decode_tokens(targets,
                                    vocab_namespace=self._target_namespace,
                                    truncate=True)
            output_dict["decoded_targets"] = decoded_targets

            output_dict.update(self._loss_criterion(
                                            rollin_output_dict=rollin_dict, 
                                            rollout_output_dict_iter=rollout_dict_iter, 
                                            state=state, 
                                            target_tokens=target_tokens))

            mle_loss_output = self._mle_loss(
                                    rollin_output_dict=rollin_dict, 
                                    rollout_output_dict_iter=rollout_dict_iter, 
                                    state=state, 
                                    target_tokens=target_tokens)

            mle_loss = mle_loss_output['loss']
            self._perplexity(mle_loss)

        if not self.training:
            # While validating or testing we need to roll out the learned policy and the output
            # of this rollout is used to compute the secondary metrics
            # like BLEU.
            state: Dict[str, torch.Tensor] = {}
            state.update(copy.copy(encoder_out))
            state.update(decoder_init_state)

            rollout_output_dict = self.rollout(state,
                                        start_predictions,
                                        rollout_steps=num_decoding_steps,
                                        rollout_mode='learned',
                                        sampled=self._sample_rollouts,
                                        beam_size=self._eval_beam_size,
                                        # TODO #6 (Kushal): Add a reason why truncate_at_end_all is False here.
                                        truncate_at_end_all=False)

            output_dict.update(rollout_output_dict)

            predictions = decoder_output_dict['predictions']
            predicted_tokens = self._decode_tokens(predictions,
                                                vocab_namespace=self._target_namespace,
                                                truncate=True)
            output_dict["decoded_predictions"] = predicted_tokens
            decoded_predictions = [predictions[0] \
                                    for predictions in output_dict["decoded_predictions"]]


            # shape (predictions): (batch_size, beam_size, num_decoding_steps)
            predictions = rollout_output_dict['predictions']

            # shape (best_predictions): (batch_size, num_decoding_steps)
            best_predictions = predictions[:, 0, :]

            if target_tokens:
                targets = util.get_token_ids_from_text_field_tensors(target_tokens)
                target_mask = util.get_text_field_mask(target_tokens)
                decoded_targets = self._decode_tokens(targets,
                                        vocab_namespace=self._target_namespace,
                                        truncate=True)

                # TODO #3 (Kushal): Maybe abstract out these losses and use loss_metric like AllenNLP uses.
                if self._bleu and target_tokens:
                    self._bleu(best_predictions, targets)

                if  self._hamming and target_tokens:
                    self._hamming(best_predictions, targets, target_mask)

                if self._tensor_based_metric is not None:
                    self._tensor_based_metric(  # type: ignore
                        predictions=best_predictions,
                        gold_targets=targets,
                    )
                if self._tensor_based_metric_mask is not None:
                    self._tensor_based_metric_mask(  # type: ignore
                        predictions=best_predictions,
                        gold_targets=targets,
                        mask=~target_mask,
                    )

                if self._token_based_metric is not None:
                    self._token_based_metric(  # type: ignore
                            predictions=decoded_predictions, 
                            gold_targets=decoded_targets,
                        )
        return output_dict
Example No. 23
    def rollin_parallel(self, 
                        state: Dict[str, torch.Tensor],
                        start_predictions: torch.LongTensor,
                        rollin_steps: int,
                        target_tokens: Dict[str, torch.LongTensor] = None,
                        beam_size: int = 1,
                        per_node_beam_size: int = None,
                        sampled: bool = False,
                        truncate_at_end_all: bool = False,
                        rollin_mode: str = None,
                    ):
        assert self._decoder_net.decodes_parallel, \
            "Rollin Parallel is only applicable for transformer-style decoders " + \
            "that decode the whole sequence in parallel."
        
        assert not rollin_mode or rollin_mode == "learned", \
            "Parallel Decoding only works when following " + \
            "teacher forcing rollin policy (rollin_mode='learned')."

        assert self._scheduled_sampling_ratio == 0, \
            "For learned rollin mode, scheduled sampling ratio should always be 0."

        self.training_iteration += 1

        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = util.get_token_ids_from_text_field_tensors(target_tokens)

        # Prepare embeddings for targets. They will be used as gold embeddings during decoder training
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self.target_embedder(targets)

        # shape: (batch_size, max_target_batch_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        _, decoder_output = self._decoder_net(
            previous_state=state,
            previous_steps_predictions=target_embedding[:, :-1, :],
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_mask=target_mask[:, :-1],
        )

        # shape: (group_size, max_target_sequence_length, num_classes)
        logits = self._output_projection_layer(decoder_output)

        # Unsqueeze logit to add beam size dimension.
        logits = logits.unsqueeze(dim=1)

        log_probabilities, step_predictions = torch.max(logits, dim=-1)

        return {
            "predictions": step_predictions,
            "logits": logits,
            "class_log_probabilities": log_probabilities,
        }
Example No. 24
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        frame_indicator: torch.Tensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
        frame_tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        frame_indicator: torch.LongTensor, required.
            An integer ``SequenceFeatureField`` representation of the position of the frame
            in the sentence. This should have shape (batch_size, num_tokens). Similar to verb_indicator,
            but handles the BERT wordpiece tokenizer by considering only the first subtoken as the frame.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        frame_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the gold frames
            of shape ``(batch_size, num_tokens)``
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            Metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
        input_ids = util.get_token_ids_from_text_field_tensors(tokens)
        bert_embeddings, _ = self.transformer(
            input_ids=input_ids,
            token_type_ids=verb_indicator,
            attention_mask=mask,
            return_dict=False,
        )
        # Apply dropout to the contextual BERT embeddings.
        embedded_text_input = self.embedding_dropout(bert_embeddings)
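        # Select the embedding at the frame position for each instance (the first wordpiece
        # of the frame, as marked by frame_indicator); assuming one marked position per
        # instance, this has shape (batch_size, embedding_dim).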
        frame_embeddings = embedded_text_input[frame_indicator == 1]
        # get sizes
        batch_size, sequence_length, _ = embedded_text_input.size()
        # outputs
        logits = self.tag_projection_layer(embedded_text_input)
        frame_logits = self.frame_projection_layer(frame_embeddings)

        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])

        frame_probabilities = F.softmax(frame_logits, dim=-1)
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict = {
            "logits": logits,
            "frame_logits": frame_logits,
            "class_probabilities": class_probabilities,
            "frame_probabilities": frame_probabilities,
            "mask": mask,
        }
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        lemmas = [l for x in metadata for l in x["lemmas"]]
        output_dict["words"] = list(words)
        output_dict["lemma"] = list(lemmas)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            # compute role loss
            role_loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
            # compute frame loss
            frame_tags_filtered = frame_tags[frame_indicator == 1]
            frame_loss = self.frame_criterion(frame_logits,
                                              frame_tags_filtered)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                batch_bio_predicted_tags = self.make_output_human_readable(
                    output_dict).pop("tags")
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            self.f1_frame_metric(frame_logits, frame_tags_filtered)
            output_dict["frame_loss"] = frame_loss
            output_dict["role_loss"] = role_loss
            output_dict["loss"] = (role_loss + frame_loss) / 2
        return output_dict
Example #25
    def forward(  # type: ignore
            self,
            tokens: TextFieldTensors = None,
            label: torch.IntTensor = None,
            **metadata) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        tokens : `TextFieldTensors`
            From a `TextField`
        label : `torch.IntTensor`, optional (default = `None`)
            From a `LabelField`

        # Returns

        An output dictionary consisting of:

            - `logits` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                unnormalized log probabilities of the label.
            - `probs` (`torch.FloatTensor`) :
                A tensor of shape `(batch_size, num_labels)` representing
                probabilities of the label.
            - `loss` : (`torch.FloatTensor`, optional) :
                A scalar loss to be optimised.
        """
        if tokens is None:
            tokens = metadata.pop("sentence")

        token_embeddings = self._text_field_embedder(tokens)

        mask = get_text_field_mask(tokens)

        text_embeddings = self._seq2vec_encoder(token_embeddings, mask=mask)

        if self._dropout:
            text_embeddings = self._dropout(text_embeddings)

        if self._feedforward is not None:
            text_embeddings = self._feedforward(text_embeddings)

        logits = self._classification_layer(text_embeddings)
        output_dict = {"logits": logits}
        if self._num_labels > 1:
            probs = torch.nn.functional.softmax(logits, dim=-1)
            output_dict["probs"] = probs

        for key in ["idx", "pair_id"]:
            output_dict[key] = metadata.get(key, [None] * len(logits))
        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)

        if label is not None:
            if self._num_labels > 1:
                loss = self._loss(logits, label.long().view(-1))
                output_dict["loss"] = loss

                assert self._accuracy is not None
                self._accuracy(logits, label)

                # Shape: (batch_size,)
                predictions = logits.argmax(dim=-1)
                # Shape: (batch_size,)
                references = label

            else:
                # Shape: (batch_size,)
                predictions = logits.squeeze(-1)
                # Shape: (batch_size,)
                references = label
                loss = self._loss(logits.squeeze(-1), label)
                output_dict["loss"] = loss

            for metric in self._metrics:
                metric(predictions, references)

        return output_dict
Example #26
    def forward(  # type: ignore
        self, tokens: TextFieldTensors, target_ids: TextFieldTensors = None
    ) -> Dict[str, torch.Tensor]:
        """
        Run a forward pass of the model, returning an output tensor dictionary with
        the following fields:

        - `"probabilities"`: a tensor of shape `(batch_size, n_best)` representing
          the probabilities of the predicted tokens, where `n_best`
          is either `self._n_best` or `beam_size` if using beam search.
        - `"top_indices"`: a tensor of shape `(batch_size, n_best, num_predicted_tokens)`
          containing the IDs of the predicted tokens, where `num_predicted_tokens` is just
          1 unless using beam search, in which case it depends on the parameters of the beam search.
        - `"token_ids"`: a tensor of shape `(batch_size, num_input_tokens)` containing the IDs
          of the input tokens.
        - `"loss"` (optional): the loss of the batch, only given if `target_ids` is not `None`.

        """
        output_dict = {
            "token_ids": util.get_token_ids_from_text_field_tensors(tokens),
        }

        # Shape: (batch_size, vocab_size)
        target_logits = self._next_token_scores(tokens)

        # Compute loss.
        if target_ids is not None:
            batch_size, vocab_size = target_logits.size()
            tmp = util.get_token_ids_from_text_field_tensors(target_ids)
            # In some scenarios, target_ids might be a top-k list of token ids (e.g. sorted by probability).
            # In that case we keep only one target token per instance, assuming the first token
            # in each row is the most desirable one (e.g. the one with the highest probability).
            tmp = tmp[:, 0] if len(tmp.shape) == 2 else tmp
            assert len(tmp.shape) <= 2
            targets = tmp.view(batch_size)
            loss = torch.nn.functional.cross_entropy(target_logits, targets)
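            # Perplexity is the exponential of the average cross-entropy, so the metric can
            # be updated directly from the batch loss.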
            self._perplexity(loss)
            output_dict["loss"] = loss

        if self._beam_search_generator is not None:
            # Dummy start predictions.
            # Shape: (batch_size,)
            start_predictions = torch.zeros(
                target_logits.size()[0], device=target_logits.device, dtype=torch.int
            )

            state = self._beam_search_generator.get_step_state(tokens)

            # Put this in here to avoid having to re-compute on the first step of beam search.
            state["start_target_logits"] = target_logits

            # Shape (top_indices): (batch_size, beam_size, num_predicted_tokens)
            # Shape (top_log_probs): (batch_size, beam_size)
            top_indices, top_log_probs = self._beam_search_generator.search(
                start_predictions, state, self._beam_search_step
            )

            # Shape: (batch_size, beam_size)
            top_probs = top_log_probs.exp()
        else:
            # Shape: (batch_size, vocab_size)
            probs = torch.nn.functional.softmax(target_logits, dim=-1)

            # Shape (both): (batch_size, n_best)
            # min here largely because tests use small vocab
            top_probs, top_indices = probs.topk(k=min(target_logits.size(-1), self._n_best), dim=-1)

            # Shape: (batch_size, n_best, 1)
            top_indices = top_indices.unsqueeze(-1)

        output_dict["top_indices"] = top_indices
        output_dict["probabilities"] = top_probs

        return output_dict
Example #27
    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        cls_index: torch.LongTensor = None,
        answer_span: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        question_with_context : `Dict[str, torch.LongTensor]`
            From a `TextField`. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including
            `[CLS]` and `[SEP]`) has type id 1.

        context_span : `torch.IntTensor`
            From a `SpanField`. This marks the span of word pieces in `question` from which answers can come.

        cls_index : `torch.LongTensor`, optional
            A tensor of shape `(batch_size,)` that provides the index of the `[CLS]` token
            in the `question_with_context` for each instance.

            This is needed because the `[CLS]` token is used to indicate that the question
            is impossible.

            If this is `None`, it's assumed that the `[CLS]` token is at index 0 for each instance
            in the batch.

        answer_span : `torch.IntTensor`, optional
            From a `SpanField`. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output dictionary.

        metadata : `List[Dict[str, Any]]`, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the `metadata` list should be the
            batch size, and each dictionary should have the keys `id`, `question`, `context`,
            `question_tokens`, `context_tokens`, and `answers`.

        # Returns

        `Dict[str, torch.Tensor]` :
            An output dictionary with the following fields:

            - span_start_logits (`torch.FloatTensor`) :
              A tensor of shape `(batch_size, passage_length)` representing unnormalized log
              probabilities of the span start position.
            - span_end_logits (`torch.FloatTensor`) :
              A tensor of shape `(batch_size, passage_length)` representing unnormalized log
              probabilities of the span end position (inclusive).
            - best_span_scores (`torch.FloatTensor`) :
              The score for each of the best spans.
            - loss (`torch.FloatTensor`, optional) :
              A scalar loss to be optimised, evaluated against `answer_span`.
            - best_span (`torch.IntTensor`, optional) :
              Provided when not in train mode and sufficient metadata given for the instance.
              The result of a constrained inference over `span_start_logits` and
              `span_end_logits` to find the most probable span.  Shape is `(batch_size, 2)`
              and each offset is a token index, unless the best span for an instance
              was predicted to be the `[CLS]` token, in which case the span will be (-1, -1).
            - best_span_str (`List[str]`, optional) :
              Provided when not in train mode and sufficient metadata given for the instance.
              This is the string from the original passage that the model thinks is the best answer
              to the question.

        """
        embedded_question = self._text_field_embedder(question_with_context)
        # shape: (batch_size, sequence_length, 2)
        logits = self._linear_layer(embedded_question)
        # shape: (batch_size, sequence_length, 1)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        # shape: (batch_size, sequence_length)
        span_start_logits = span_start_logits.squeeze(-1)
        # shape: (batch_size, sequence_length)
        span_end_logits = span_end_logits.squeeze(-1)

        # Create a mask for `question_with_context` to mask out tokens that are not part
        # of the context.
        # shape: (batch_size, sequence_length)
        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context),
            dtype=torch.bool)
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start:end + 1] = True
            # Also unmask the [CLS] token since that token is used to indicate that
            # the question is impossible.
            possible_answer_mask[
                i, 0 if cls_index is None else cls_index[i]] = True

        # Replace the masked values with a very negative constant since we're in log-space.
        # shape: (batch_size, sequence_length)
        span_start_logits = replace_masked_values_with_big_negative_number(
            span_start_logits, possible_answer_mask)
        # shape: (batch_size, sequence_length)
        span_end_logits = replace_masked_values_with_big_negative_number(
            span_end_logits, possible_answer_mask)

        # Now calculate the best span.
        # shape: (batch_size, 2)
        best_spans = get_best_span(span_start_logits, span_end_logits)

        # Sum the span start score with the span end score to get an overall score for the span.
        # shape: (batch_size,)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span_scores": best_span_scores,
        }

        # Compute the loss.
        if answer_span is not None:
            output_dict["loss"] = self._evaluate_span(best_spans,
                                                      span_start_logits,
                                                      span_end_logits,
                                                      answer_span)

        # Gather the string of the best span and compute the EM and F1 against the gold span,
        # if given.
        if not self.training and metadata is not None:
            (
                output_dict["best_span_str"],
                output_dict["best_span"],
            ) = self._collect_best_span_strings(best_spans, context_span,
                                                metadata, cls_index)

        return output_dict
Example #28
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,  # batch * words
        options: TextFieldTensors,  # batch * num_options * words
        labels: torch.IntTensor = None  # batch * num_options
    ) -> Dict[str, torch.Tensor]:
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).long()

        embedded_options = self._text_field_embedder(
            options, num_wrapping_dims=1)  # options_mask.dim() - 2
        options_mask = get_text_field_mask(options).long()

        if self._dropout:
            embedded_text = self._dropout(embedded_text)
            embedded_options = self._dropout(embedded_options)
        """
        This isn't exactly a 'hack', but it's definitely not the most efficient way to do it.
        Our matcher expects a single (query, document) pair, but we have (query, [d_0, ..., d_n]).
        To get around this, we expand the query embeddings to create these pairs, and then
        flatten both into the 3D tensor [batch*num_options, words, dim] expected by the matcher. 
        The expansion does this:

        [
            (q_0, [d_{0,0}, ..., d_{0,n}]), 
            (q_1, [d_{1,0}, ..., d_{1,n}])
        ]
        =>
        [
            [ (q_0, d_{0,0}), ..., (q_0, d_{0,n}) ],
            [ (q_1, d_{1,0}), ..., (q_1, d_{1,n}) ]
        ]

        which we then flatten along the batch dimension. It would likely be more efficient
        to rewrite the matrix multiplications in the relevance matchers, but this is a more general solution.
        """

        embedded_text = embedded_text.unsqueeze(1).expand(
            -1, embedded_options.size(1), -1,
            -1)  # [batch, num_options, words, dim]
        mask = mask.unsqueeze(1).expand(-1, embedded_options.size(1), -1)

        scores = self._relevance_matcher(embedded_text, embedded_options, mask,
                                         options_mask).squeeze(-1)
        probs = torch.sigmoid(scores)

        output_dict = {"logits": scores, "probs": probs}
        output_dict["token_ids"] = util.get_token_ids_from_text_field_tensors(
            tokens)
        if labels is not None:
            label_mask = (labels != -1)

            self._mrr(probs, labels, label_mask)
            self._ndcg(probs, labels, label_mask)

            probs = probs.view(-1)
            labels = labels.view(-1)
            label_mask = label_mask.view(-1)

            self._auc(probs, labels.ge(0.5).long(), label_mask)

            loss = self._loss(probs, labels)
            output_dict["loss"] = loss.masked_fill(~label_mask,
                                                   0).sum() / label_mask.sum()

        return output_dict
Example #29
    def attack_from_json(
        self,
        inputs: JsonDict,
        input_field_to_attack: str = "tokens",
        grad_input_field: str = "grad_input_1",
        ignore_tokens: List[str] = None,
        target: JsonDict = None,
    ) -> JsonDict:
        """
        Replaces one token at a time from the input until the model's prediction changes.
        ``input_field_to_attack`` names the input field whose tokens we will flip, for example
        ``tokens``.  ``grad_input_field`` is, for example, ``grad_input_1``, which is a key into
        a gradients dictionary.

        The method computes the gradient w.r.t. the tokens, finds the token with the maximum
        gradient (by L2 norm), and replaces it with another token based on the first-order Taylor
        approximation of the loss.  This process is iteratively repeated until the prediction
        changes.  Once a token is replaced, it is not flipped again.

        # Parameters

        inputs : ``JsonDict``
            The model inputs, the same as what is passed to a ``Predictor``.
        input_field_to_attack : ``str``, optional (default='tokens')
            The field that has the tokens that we're going to be flipping.  This must be a
            ``TextField``.
        grad_input_field : ``str``, optional (default='grad_input_1')
            If there is more than one field that gets embedded in your model (e.g., a question and
            a passage, or a premise and a hypothesis), this tells us the key to use to get the
            correct gradients.  This selects from the output of :func:`Predictor.get_gradients`.
        ignore_tokens : ``List[str]``, optional (default=DEFAULT_IGNORE_TOKENS)
            These tokens will not be flipped.  The default list includes some simple punctuation,
            OOV and padding tokens, and common control tokens for BERT, etc.
        target : ``JsonDict``, optional (default=None)
            If given, this will be a `targeted` hotflip attack, where instead of just trying to
            change a model's prediction from what it currently predicts, we try to change it to
            a `specific` target value.  This is a ``JsonDict`` because it needs to specify the
            field name and target value.  For example, for a masked LM, this would be something
            like ``{"words": ["she"]}``, because ``"words"`` is the field name, there is one mask
            token (hence the list of length one), and we want to change the prediction from
            whatever it was to ``"she"``.
        """
        if self.embedding_matrix is None:
            self.initialize()
        ignore_tokens = DEFAULT_IGNORE_TOKENS if ignore_tokens is None else ignore_tokens

        # If `target` is `None`, we move away from the current prediction, otherwise we move
        # _towards_ the target.
        sign = -1 if target is None else 1
        instance = self.predictor._json_to_instance(inputs)
        if target is None:
            output_dict = self.predictor._model.forward_on_instance(instance)
        else:
            output_dict = target

        # This now holds the predictions that we want to change (either away from or towards,
        # depending on whether `target` was passed).  We'll use this in the loop below to check for
        # when we've met our stopping criterion.
        original_instances = self.predictor.predictions_to_labeled_instances(
            instance, output_dict)

        # This is just for ease of access in the UI, so we know the original tokens.  It's not used
        # in the logic below.
        original_text_field: TextField = original_instances[0][  # type: ignore
            input_field_to_attack]
        original_tokens = deepcopy(original_text_field.tokens)

        final_tokens = []
        # `original_instances` is a list because there might be several different predictions that
        # we're trying to attack (e.g., all of the NER tags for an input sentence).  We attack them
        # one at a time.
        for instance in original_instances:
            # Gets a list of the fields that we want to check to see if they change.
            fields_to_compare = utils.get_fields_to_compare(
                inputs, instance, input_field_to_attack)

            # We'll be modifying the tokens in this text field below, and grabbing the modified
            # list after the `while` loop.
            text_field: TextField = instance[
                input_field_to_attack]  # type: ignore

            # Because we can save computation by getting grads and outputs at the same time, we do
            # them together at the end of the loop, even though we use grads at the beginning and
            # outputs at the end.  This is our initial gradient for the beginning of the loop.  The
            # output can be ignored here.
            grads, outputs = self.predictor.get_gradients([instance])

            # Ignore any token that is in the ignore_tokens list by marking it as already flipped.
            flipped: List[int] = []
            for index, token in enumerate(text_field.tokens):
                if token.text in ignore_tokens:
                    flipped.append(index)
            if "clusters" in outputs:
                # Coref unfortunately needs a special case here.  We don't want to flip words in
                # the same predicted coref cluster, but we can't really specify a list of tokens,
                # because, e.g., "he" could show up in several different clusters.
                # TODO(mattg): perhaps there's a way to get `predictions_to_labeled_instances` to
                # return the set of tokens that shouldn't be changed for each instance?  E.g., you
                # could imagine setting a field on the `Token` object, that we could then read
                # here...
                for cluster in outputs["clusters"]:
                    for mention in cluster:
                        for index in range(mention[0], mention[1] + 1):
                            flipped.append(index)

            while True:
                # Compute the squared L2 norm of the gradient at each token position.
                grad = grads[grad_input_field][0]
                grads_magnitude = [g.dot(g) for g in grad]

                # only flip a token once
                for index in flipped:
                    grads_magnitude[index] = -1

                # We flip the token with highest gradient norm.
                index_of_token_to_flip = numpy.argmax(grads_magnitude)
                if grads_magnitude[index_of_token_to_flip] == -1:
                    # If we've already flipped all of the tokens, we give up.
                    break
                flipped.append(index_of_token_to_flip)

                text_field_tensors = text_field.as_tensor(
                    text_field.get_padding_lengths())
                input_tokens = util.get_token_ids_from_text_field_tensors(
                    text_field_tensors)
                original_id_of_token_to_flip = input_tokens[
                    index_of_token_to_flip]

                # Get new token using taylor approximation.
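                # The replacement is chosen by scoring candidate embeddings with the gradient
                # (a first-order Taylor approximation of the change in loss); `sign` controls
                # whether we move away from the current prediction or towards a given target.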
                new_id = self._first_order_taylor(
                    grad[index_of_token_to_flip], original_id_of_token_to_flip,
                    sign)

                # Flip token.  We need to tell the instance to re-index itself, so the text field
                # will actually update.
                new_token = Token(self.vocab._index_to_token[self.namespace]
                                  [new_id])  # type: ignore
                text_field.tokens[index_of_token_to_flip] = new_token
                instance.indexed = False

                # Get model predictions on instance, and then label the instances
                grads, outputs = self.predictor.get_gradients(
                    [instance])  # predictions
                for key, output in outputs.items():
                    if isinstance(output, torch.Tensor):
                        outputs[key] = output.detach().cpu().numpy().squeeze()
                    elif isinstance(output, list):
                        outputs[key] = output[0]

                # TODO(mattg): taking the first result here seems brittle, if we're in a case where
                # there are multiple predictions.
                labeled_instance = self.predictor.predictions_to_labeled_instances(
                    instance, outputs)[0]

                # If we've met our stopping criterion, we stop.
                has_changed = utils.instance_has_changed(
                    labeled_instance, fields_to_compare)
                if target is None and has_changed:
                    # With no target, we just want to change the prediction.
                    break
                if target is not None and not has_changed:
                    # With a given target, we want to *match* the target, which we check by
                    # `not has_changed`.
                    break

            final_tokens.append(text_field.tokens)

        return sanitize({
            "final": final_tokens,
            "original": original_tokens,
            "outputs": outputs
        })
Example #30
    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : TextFieldTensors, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: torch.LongTensor, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = None)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : torch.FloatTensor
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
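        # The verb indicator is passed as BERT's token_type_ids so the encoder knows which
        # predicate the semantic roles should be predicted for.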
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            token_type_ids=verb_indicator,
            attention_mask=mask,
        )

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        batch_size, sequence_length, _ = embedded_text_input.size()
        logits = self.tag_projection_layer(embedded_text_input)

        reshaped_log_probs = logits.view(-1, self.num_classes)
        class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(
            [batch_size, sequence_length, self.num_classes])
        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities
        }
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            loss = sequence_cross_entropy_with_logits(
                logits, tags, mask, label_smoothing=self._label_smoothing)
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.make_output_human_readable(
                    output_dict).pop("tags")
                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
        return output_dict