Example #1
    def test_beam_search_matches_greedy(self):
        model = self.trained_model
        state = model._states["xintent"]
        beam_search = BeamSearch(model._end_index,
                                 max_steps=model._max_decoding_steps,
                                 beam_size=1)

        final_encoder_output = self.get_sample_encoded_output()
        batch_size = final_encoder_output.size()[0]
        start_predictions = final_encoder_output.new_full(
                (batch_size,), fill_value=model._start_index, dtype=torch.long)
        start_state = {"decoder_hidden": final_encoder_output}

        greedy_prediction = model.greedy_predict(
                final_encoder_output=final_encoder_output,
                target_embedder=state.embedder,
                decoder_cell=state.decoder_cell,
                output_projection_layer=state.output_projection_layer
        )
        greedy_tokens = model.decode_all(greedy_prediction)

        (beam_predictions, _) = beam_search.search(
                start_predictions, start_state, state.take_step)
        beam_prediction = beam_predictions[0]
        beam_tokens = model.decode_all(beam_prediction)

        assert beam_tokens == greedy_tokens
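
Every snippet in this listing hands `BeamSearch.search` a step callable (`state.take_step`, `self.model.take_step`, `take_step_with_timestep`) that is not reproduced here. Below is a minimal sketch of the contract such a callable has to satisfy, inferred from how the examples use it; the vocabulary size and the dummy projection are placeholders, not code from any of the models shown:

    from typing import Dict, Tuple
    import torch

    NUM_CLASSES = 100  # placeholder target vocabulary size

    def take_step(last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # `last_predictions` has shape (group_size,): the token chosen at the previous
        # timestep for every beam of every batch element. The step function returns
        # log probabilities over the target vocabulary, shape (group_size, num_classes),
        # plus an updated state whose tensors keep group_size as their first dimension
        # so BeamSearch can reorder them as beams are pruned.
        logits = torch.zeros(last_predictions.size(0), NUM_CLASSES)  # a real model projects its decoder hidden state here
        log_probabilities = torch.log_softmax(logits, dim=-1)
        return log_probabilities, state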
Example #2
    def test_greedy_decode_matches_beam_search(self):
        # pylint: disable=protected-access
        beam_search = BeamSearch(self.model._end_index, max_steps=self.model._max_decoding_steps, beam_size=1)
        training_tensors = self.dataset.as_tensor_dict()

        # Get greedy predictions from _forward_loop method of model.
        state = self.model._encode(training_tensors["source_tokens"])
        state = self.model._init_decoder_state(state)
        output_dict_greedy = self.model._forward_loop(state)
        output_dict_greedy = self.model.decode(output_dict_greedy)

        # Get greedy predictions from beam search (beam size = 1).
        state = self.model._encode(training_tensors["source_tokens"])
        state = self.model._init_decoder_state(state)
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self.model._start_index)
        all_top_k_predictions, _ = beam_search.search(
                start_predictions, state, self.model.take_step)
        output_dict_beam_search = {
                "predictions": all_top_k_predictions,
        }
        output_dict_beam_search = self.model.decode(output_dict_beam_search)

        # Predictions from model._forward_loop and beam_search should match.
        assert output_dict_greedy["predicted_tokens"] == output_dict_beam_search["predicted_tokens"]
Example #3
    def test_greedy_decode_matches_beam_search(self):

        beam_search = BeamSearch(
            self.model._end_index, max_steps=self.model._max_decoding_steps, beam_size=1
        )
        training_tensors = self.dataset.as_tensor_dict()

        # Get greedy predictions from _forward_loop method of model.
        state = self.model._encode(training_tensors["source_tokens"])
        state = self.model._init_decoder_state(state)
        output_dict_greedy = self.model._forward_loop(state)
        output_dict_greedy = self.model.make_output_human_readable(output_dict_greedy)

        # Get greedy predictions from beam search (beam size = 1).
        state = self.model._encode(training_tensors["source_tokens"])
        state = self.model._init_decoder_state(state)
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size,), fill_value=self.model._start_index, dtype=torch.long
        )
        all_top_k_predictions, _ = beam_search.search(
            start_predictions, state, self.model.take_step
        )
        output_dict_beam_search = {"predictions": all_top_k_predictions}
        output_dict_beam_search = self.model.make_output_human_readable(output_dict_beam_search)

        # Predictions from model._forward_loop and beam_search should match.
        assert output_dict_greedy["predicted_tokens"] == output_dict_beam_search["predicted_tokens"]
Example #4
    def __init__(
        self,
        vocab: Vocabulary,
        decoder_net: DecoderNet,
        max_decoding_steps: int,
        target_embedder: Embedding,
        target_namespace: str = "tokens",
        tie_output_embedding: bool = False,
        scheduled_sampling_ratio: float = 0,
        label_smoothing_ratio: Optional[float] = None,
        beam_size: int = 4,
        tensor_based_metric: Metric = None,
        token_based_metric: Metric = None,
    ) -> None:
        super().__init__(target_embedder)

        self._vocab = vocab

        # Decodes the sequence of encoded hidden states into a new sequence of hidden states.
        self._decoder_net = decoder_net
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace
        self._label_smoothing_ratio = label_smoothing_ratio

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self._vocab.get_token_index(
            START_SYMBOL, self._target_namespace)
        self._end_index = self._vocab.get_token_index(END_SYMBOL,
                                                      self._target_namespace)
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        target_vocab_size = self._vocab.get_vocab_size(self._target_namespace)

        if self.target_embedder.get_output_dim() != self._decoder_net.target_embedding_dim:
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input."
            )

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(
            self._decoder_net.get_output_dim(), target_vocab_size)

        if tie_output_embedding:
            if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape:
                raise ConfigurationError(
                    "Can't tie embeddings with output linear layer, due to shape mismatch"
                )
            self._output_projection_layer.weight = self.target_embedder.weight

        # These metrics will be updated during training and validation
        self._tensor_based_metric = tensor_based_metric
        self._token_based_metric = token_based_metric

        self._scheduled_sampling_ratio = scheduled_sampling_ratio
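
The tie_output_embedding branch above only succeeds when the decoder output dimension equals the target embedding dimension, since the projection and embedding weight matrices must have identical shapes. A tiny standalone illustration of that constraint (vocab_size and dim are made-up numbers, not values from the model above):

    from torch.nn import Embedding, Linear

    vocab_size, dim = 100, 32
    output_projection = Linear(dim, vocab_size)   # weight shape: (vocab_size, dim)
    target_embedder = Embedding(vocab_size, dim)  # weight shape: (vocab_size, dim)
    assert output_projection.weight.shape == target_embedder.weight.shape
    output_projection.weight = target_embedder.weight  # what tie_output_embedding does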
Example #5
    def __init__(
        self,
        word_embedder: TextFieldEmbedder,
        attribute_embedder: Embedding,
        content_encoder: Seq2SeqEncoder,
        vocab: Vocabulary,
        max_decoding_steps: int = 20,
        beam_size: int = None,
        scheduled_sampling_ratio: float = 0.,
    ) -> None:
        super().__init__(vocab)

        self.scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self.start_index = self.vocab.get_token_index(START_SYMBOL, 'tokens')
        self.end_index = self.vocab.get_token_index(END_SYMBOL, 'tokens')

        # TODO: not sure if we need this
        self.bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self.max_decoding_steps = max_decoding_steps
        self.beam_search = BeamSearch(self.end_index,
                                      max_steps=max_decoding_steps,
                                      beam_size=beam_size)

        # Dense embedding of source and target vocab tokens and attribute.
        self.word_embedder = word_embedder
        self.attribute_embedder = attribute_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self.content_encoder = content_encoder

        num_classes = self.vocab.get_vocab_size('tokens')

        # TODO: not sure if we need this
        self.attention = None

        # Dense embedding of vocab words in the target space.
        embedding_dim = word_embedder.get_output_dim()
        self.target_embedder = Embedding(num_classes, embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self.encoder_output_dim = self.content_encoder.get_output_dim() + embedding_dim
        self.decoder_output_dim = self.encoder_output_dim

        self.decoder_input_dim = embedding_dim

        self.decoder_cell = LSTMCell(self.decoder_input_dim,
                                     self.decoder_output_dim)

        self.output_projection_layer = Linear(self.decoder_output_dim,
                                              num_classes)
Example #6
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        """
        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model to use. Available options can be found in
            `transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`.
        vocab : `Vocabulary`, required
            Vocabulary containing source and target vocabularies.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer to be used for converting decoded sequences of ids to sequences of tokens.
        max_decoding_steps : `int`, optional (default = `140`)
            Number of decoding steps during beam search.
        beam_size : `int`, optional (default = `4`)
            Number of beams to use in beam search. The default is from the BART paper.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Encoder to be used in BART. By default, the original BART encoder is used.
        """
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(
            model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_id,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (encoder.get_input_dim() == encoder.get_output_dim() ==
                    self.bart.config.hidden_size)
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )
Example #7
    def test_different_per_node_beam_size(self):
        # per_node_beam_size = 1
        beam_search = BeamSearch(self.end_index,
                                 beam_size=3,
                                 per_node_beam_size=1)
        self._check_results(beam_search=beam_search)

        # per_node_beam_size = 2
        beam_search = BeamSearch(self.end_index,
                                 beam_size=3,
                                 per_node_beam_size=2)
        self._check_results(beam_search=beam_search)
Example #8
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 beam_size: int = 10,
                 target_names: List[str] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None) -> None:
        target_names = target_names or ["xintent", "xreact", "oreact"]

        super(Event2Mind, self).__init__(vocab)

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as a well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()

        self._states: Dict[str, StateDecoder] = {}
        for name in target_names:
            self._states[name] = StateDecoder(
                    name,
                    self,
                    num_classes,
                    target_embedding_dim,
                    self._decoder_output_dim
            )

        self._beam_search = BeamSearch(
                self._end_index,
                beam_size=beam_size,
                max_steps=max_decoding_steps
        )
Example #9
    def __init__(
        self,
        vocab: Vocabulary,
        decoder_net: DecoderNet,
        max_decoding_steps: int,
        target_embedder: Embedding,
        target_namespace: str = "tokens",
        tie_output_embedding: bool = False,
        scheduled_sampling_ratio: float = 0,
        label_smoothing_ratio: Optional[float] = None,
        beam_size: int = 4,
        tensor_based_metric: Metric = None,
        token_based_metric: Metric = None,
    ) -> None:
        super().__init__(target_embedder)

        self._vocab = vocab

        self._decoder_net = decoder_net
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace
        self._label_smoothing_ratio = label_smoothing_ratio

        self._start_index = self._vocab.get_token_index(
            START_SYMBOL, self._target_namespace)
        self._end_index = self._vocab.get_token_index(END_SYMBOL,
                                                      self._target_namespace)
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        target_vocab_size = self._vocab.get_vocab_size(self._target_namespace)

        if self.target_embedder.get_output_dim() != self._decoder_net.target_embedding_dim:
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input."
            )

        self._output_projection_layer = Linear(
            self._decoder_net.get_output_dim(), target_vocab_size)

        if tie_output_embedding:
            if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape:
                raise ConfigurationError(
                    "Can't tie embeddings with output linear layer, due to shape mismatch"
                )
            self._output_projection_layer.weight = self.target_embedder.weight

        self._tensor_based_metric = tensor_based_metric
        self._token_based_metric = token_based_metric
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
Example #10
    def setup_method(self):
        super().setup_method()
        self.end_index = transition_probabilities.size()[0] - 1
        self.beam_search = BeamSearch(self.end_index,
                                      max_steps=10,
                                      beam_size=3)

        # This is what the top k should look like for each item in the batch.
        self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5],
                                        [3, 4, 5, 5, 5]])

        # This is what the log probs should look like for each item in the batch.
        self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
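
setup_method above relies on module-level test fixtures (transition_probabilities, take_step_with_timestep) that are not reproduced in this listing. Here is a sketch of fixtures consistent with the expected values in these tests, modelling a six-token chain where token i transitions to i + 1 and token 5 is the end index; the actual test module may differ in details:

    import torch

    # Row i is the distribution over next tokens given that the last prediction was token i.
    transition_probabilities = torch.tensor(
        [[0.0, 0.4, 0.3, 0.2, 0.1, 0.0],
         [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
         [0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
         [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
         [0.2, 0.1, 0.2, 0.2, 0.2, 0.3]]
    )

    def take_step_with_timestep(last_predictions, state, timestep=None):
        # Index the transition matrix with the last predictions (shape (group_size,)) to get
        # a (group_size, 6) matrix of next-token probabilities, then take the log.
        log_probs = torch.log(transition_probabilities[last_predictions])
        return log_probs, state

Under this fixture the highest-probability path out of token 0 is 1 -> 2 -> 3 -> 4 -> 5 with probability 0.4, which matches expected_top_k and expected_log_probs above.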
Example #11
 def test_empty_sequences(self):
     initial_predictions = torch.LongTensor(
         [self.end_index - 1, self.end_index - 1])
     beam_search = BeamSearch(self.end_index, beam_size=1)
     with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
         predictions, log_probs = beam_search.search(
             initial_predictions, {}, take_step_with_timestep)
     # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
     assert list(predictions.size()) == [2, 1, 1]
     # log probs should have shape `(batch_size, beam_size)`.
     assert list(log_probs.size()) == [2, 1]
     assert (predictions == self.end_index).all()
     assert (log_probs == 0).all()
Example #12
    def setUp(self):
        super(BeamSearchTest, self).setUp()
        self.end_index = transition_probabilities.size()[0] - 1
        self.beam_search = BeamSearch(self.end_index,
                                      max_steps=10,
                                      beam_size=3)

        # This is what the top k should look like for each item in the batch.
        self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5],
                                        [3, 4, 5, 5, 5]])

        # This is what the log probs should look like for each item in the batch.
        self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))  # pylint: disable=assignment-from-no-return
Example #13
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_decoding_steps: int,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 projection_dim: int = None,
                 use_coverage: bool = False,
                 coverage_loss_weight: float = None) -> None:
        super(PointerGeneratorNetwork, self).__init__(vocab)

        self._target_namespace = target_namespace
        self._start_index = self.vocab.get_token_index(START_SYMBOL, target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, target_namespace)
        self._unk_index = self.vocab.get_token_index(DEFAULT_OOV_TOKEN, target_namespace)
        self._vocab_size = self.vocab.get_vocab_size(target_namespace)

        # Encoder
        self._source_embedder = source_embedder
        self._encoder = encoder
        self._encoder_output_dim = self._encoder.get_output_dim()

        # Decoder
        self._target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        self._num_classes = self.vocab.get_vocab_size(target_namespace)
        self._target_embedder = Embedding(self._num_classes, self._target_embedding_dim)

        self._decoder_input_dim = self._encoder_output_dim + self._target_embedding_dim
        self._decoder_output_dim = self._encoder_output_dim
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        self._projection_dim = projection_dim or self._source_embedder.get_output_dim()
        self._hidden_projection_layer = Linear(self._decoder_output_dim, self._projection_dim)
        self._output_projection_layer = Linear(self._projection_dim, self._num_classes)

        self._p_gen_layer = Linear(self._decoder_output_dim * 3 + self._decoder_input_dim, 1)
        self._attention = attention
        self._use_coverage = use_coverage
        self._coverage_loss_weight = coverage_loss_weight
        self._eps = 1e-31

        # Decoding
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size or 1)
Example #14
    def _check_results(
        self,
        batch_size: int = 5,
        expected_top_k: np.array = None,
        expected_log_probs: np.array = None,
        beam_search: BeamSearch = None,
        state: Dict[str, torch.Tensor] = None,
        take_step=take_step_with_timestep,
    ) -> None:
        expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
        expected_log_probs = (expected_log_probs if expected_log_probs
                              is not None else self.expected_log_probs)
        state = state or {}

        beam_search = beam_search or self.beam_search
        beam_size = beam_search.beam_size

        initial_predictions = torch.tensor([0] * batch_size)
        top_k, log_probs = beam_search.search(initial_predictions, state,
                                              take_step)  # type: ignore

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
        np.testing.assert_allclose(log_probs[0].numpy(),
                                   expected_log_probs,
                                   rtol=1e-6)
Example #15
 def test_greedy_search(self):
     beam_search = BeamSearch(self.end_index, beam_size=1)
     expected_top_k = np.array([[1, 2, 3, 4, 5]])
     expected_log_probs = np.log(np.array([0.4]))  # pylint: disable=assignment-from-no-return
     self._check_results(expected_top_k=expected_top_k,
                         expected_log_probs=expected_log_probs,
                         beam_search=beam_search)
Example #16
 def test_greedy_search(self):
     beam_search = BeamSearch(self.end_index, beam_size=1)
     expected_top_k = np.array([[1, 2, 3, 4, 5]])
     expected_log_probs = np.log(np.array([0.4]))
     self._check_results(expected_top_k=expected_top_k,
                         expected_log_probs=expected_log_probs,
                         beam_search=beam_search)
Example #17
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sentence_encoder: Seq2VecEncoder,
                 claim_encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_steps: int = 100,
                 beam_size: int = 5,
                 beta: float = 1.0) -> None:
        super(Seq2SeqClaimRank, self).__init__(vocab)

        self.text_field_embedder = text_field_embedder
        self.sentence_encoder = sentence_encoder
        self.claim_encoder = TimeDistributed(claim_encoder)  # Handles additional sequence dim
        self.claim_encoder_dim = claim_encoder.get_output_dim()
        self.attention = attention
        self.decoder_embedding_dim = text_field_embedder.get_output_dim()
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.beta = beta

        # self.target_embedder = torch.nn.Embedding(vocab.get_vocab_size(), decoder_embedding_dim)

        # Since we are using the sentence encoding as the initial hidden state to the decoder, the
        # decoder hidden dim must match the sentence encoder hidden dim.
        self.decoder_output_dim = sentence_encoder.get_output_dim()
        self.decoder_0_cell = torch.nn.LSTMCell(self.decoder_embedding_dim + self.claim_encoder_dim,
                                                self.decoder_output_dim)
        self.decoder_1_cell = torch.nn.LSTMCell(self.decoder_output_dim,
                                                self.decoder_output_dim)

        # When projecting out we will use attention to combine claim embeddings into a single
        # context embedding, this will be concatenated with the decoder cell output before being
        # fed to the projection layer. Hence the expected input size is:
        #   decoder output dim + claim encoder output dim
        projection_input_dim = self.decoder_output_dim + self.claim_encoder_dim
        self.output_projection_layer = torch.nn.Linear(projection_input_dim,
                                                       vocab.get_vocab_size())

        self._start_index = self.vocab.get_token_index('<s>')
        self._end_index = self.vocab.get_token_index('</s>')

        self.beam_search = BeamSearch(self._end_index, max_steps=max_steps, beam_size=beam_size)
        pad_index = vocab.get_token_index(vocab._padding_token)
        self.bleu = BLEU(exclude_indices={pad_index, self._start_index, self._end_index})
        self.avg_reconstruction_loss = Average()
        self.avg_claim_scoring_loss = Average()
Example #18
 def test_default_from_params_params(self):
     beam_search = BeamSearch.from_params(
         Params({
             "beam_size": 2,
             "end_index": 7
         }))
     assert beam_search.beam_size == 2
     assert beam_search._end_index == 7
Example #19
 def test_p_val(self, p):
     with pytest.raises(ConfigurationError):
         initial_predictions = torch.tensor([0] * 5)
         take_step = take_step_with_timestep
         beam_size = 3
         top_p, log_probs = BeamSearch.top_p_sampling(
             self.end_index, p=p,
             beam_size=beam_size).search(initial_predictions, {}, take_step)
Example #20
 def test_params_no_sampling(self):
     beam_search = BeamSearch.from_params(
         Params({
             "beam_size": 2,
             "end_index": 7
         }))
     assert beam_search.beam_size == 2
     assert beam_search._end_index == 7
     assert beam_search.sampler is None
Example #21
 def test_catch_bad_config(self):
     """
     If `per_node_beam_size` (which defaults to `beam_size`) is larger than
     the size of the target vocabulary, `BeamSearch.search` should raise
     a ConfigurationError.
     """
     beam_search = BeamSearch(self.end_index, beam_size=20)
     with pytest.raises(ConfigurationError):
         self._check_results(beam_search=beam_search)
Example #22
 def test_early_stopping(self):
     """
     Checks case where beam search will reach `max_steps` before finding end tokens.
     """
     beam_search = BeamSearch(self.end_index, beam_size=3, max_steps=3)
     expected_top_k = np.array([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
     expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))  # pylint: disable=assignment-from-no-return
     self._check_results(expected_top_k=expected_top_k,
                         expected_log_probs=expected_log_probs,
                         beam_search=beam_search)
Example #23
    def __init__(
        self,
        vocabulary: Vocabulary,
        image_feature_size: int,
        embedding_size: int,
        hidden_size: int,
        attention_projection_size: int,
        max_caption_length: int = 20,
        beam_size: int = 1,
    ) -> None:
        super().__init__()
        self._vocabulary = vocabulary

        self.image_feature_size = image_feature_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.attention_projection_size = attention_projection_size

        # Short hand variable names for convenience
        vocab_size = vocabulary.get_vocab_size()
        self._pad_index = vocabulary.get_token_index("@@UNKNOWN@@")
        self._boundary_index = vocabulary.get_token_index("@@BOUNDARY@@")

        self._embedding_layer = nn.Embedding(
            vocab_size, embedding_size, padding_idx=self._pad_index
        )

        self._updown_cell = UpDownCell(
            image_feature_size, embedding_size, hidden_size, attention_projection_size
        )

        self._output_layer = nn.Linear(hidden_size, vocab_size)
        self._log_softmax = nn.LogSoftmax(dim=1)

        # We use beam search to find the most likely caption during inference.
        self._beam_size = beam_size
        self._beam_search = BeamSearch(
            self._boundary_index,
            max_steps=max_caption_length,
            beam_size=beam_size,
            per_node_beam_size=beam_size // 2,
        )
        self._max_caption_length = max_caption_length
Example #24
 def test_params_p_sampling(self):
     beam_search = BeamSearch.from_params(
         Params({
             "type": "top_p_sampling",
             "beam_size": 2,
             "end_index": 7,
             "p": 0.4,
         }))
     assert beam_search.beam_size == 2
     assert beam_search._end_index == 7
     assert beam_search.sampler is not None
Example #25
 def test_params_sampling(self):
     beam_search = BeamSearch.from_params(
         Params({
             "sampler": {
                 "type": "top-k",
                 "k": 4,
             },
             "beam_size": 2,
             "end_index": 7,
         }))
     assert beam_search.beam_size == 2
     assert beam_search._end_index == 7
     assert beam_search.sampler is not None
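
For comparison, the object built via from_params in this example can also be constructed directly; a minimal sketch assuming AllenNLP 2.x, where the samplers are defined alongside BeamSearch:

    from allennlp.nn.beam_search import BeamSearch, TopKSampler

    beam_search = BeamSearch(end_index=7, beam_size=2, sampler=TopKSampler(k=4))
    assert beam_search.sampler is not None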
Example #26
    def test_k_val(self, k_val):
        with pytest.raises(ValueError):
            initial_predictions = torch.tensor([0] * 5)
            take_step = take_step_with_timestep
            beam_size = 3
            k_sampler = TopKSampler(k=k_val, with_replacement=True)

            top_k, log_probs = BeamSearch(self.end_index,
                                          beam_size=beam_size,
                                          max_steps=10,
                                          sampler=k_sampler).search(
                                              initial_predictions, {},
                                              take_step)
Example #27
    def test_top_k_search(self):
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep
        k_sampler = TopKSampler(k=5, with_replacement=True)

        top_k, log_probs = BeamSearch(self.end_index,
                                      beam_size=beam_size,
                                      max_steps=10,
                                      sampler=k_sampler).search(
                                          initial_predictions, {}, take_step)

        beam_size = beam_size or 1
        batch_size = 5

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]

        assert ((0 <= top_k) & (top_k <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
Example #28
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(
            model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_id,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (encoder.get_input_dim() == encoder.get_output_dim() ==
                    self.bart.config.hidden_size)
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )
Example #29
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model_path,
                 beam_size=5,
                 max_decoding_steps=140,
                 indexer=None):
        super().__init__(vocab)
        self.plm = MT5ForConditionalGeneration.from_pretrained(pretrained_model_path)
        self._indexer = indexer or PretrainedTransformerIndexer(pretrained_model_path, namespace="tokens")
        self._start_id = self.plm.config.decoder_start_token_id
        self._decoder_start_id = self.plm.config.decoder_start_token_id
        self._end_id = self.plm.config.eos_token_id
        self._pad_id = self.plm.config.pad_token_id

        self._beam_search = BeamSearch(
            self._end_id, max_steps=max_decoding_steps, beam_size=beam_size or 1
        )
        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})
Example #30
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 beam_size: int = 10,
                 target_names: List[str] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None) -> None:
        super().__init__(vocab)
        target_names = target_names or ["xintent", "xreact", "oreact"]

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as a well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()

        self._states = ModuleDict()
        for name in target_names:
            self._states[name] = StateDecoder(
                    num_classes,
                    target_embedding_dim,
                    self._decoder_output_dim
            )

        self._beam_search = BeamSearch(
                self._end_index,
                beam_size=beam_size,
                max_steps=max_decoding_steps
        )
Example #31
    def __init__(self, vocab: Vocabulary, max_timesteps: int = 50, encoder_size: int = 14, encoder_dim: int = 512,
                 embedding_dim: int = 64, attention_dim: int = 64, decoder_dim: int = 64, beam_size: int = 3, teacher_forcing: bool = True) -> None:
        super().__init__(vocab)
        
        self._max_timesteps = max_timesteps
        
        self._vocab_size = self.vocab.get_vocab_size()
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        # POSSIBLE CHANGE LATER
        self._pad_index = self.vocab.get_token_index('@@PADDING@@')
        
        self._encoder_size = encoder_size
        self._encoder_dim = encoder_dim
        self._embedding_dim = embedding_dim
        self._attention_dim = attention_dim
        self._decoder_dim = decoder_dim
        
        self._beam_size = beam_size
        self._teacher_forcing = teacher_forcing

        self._init_h = nn.Linear(self._encoder_dim, self._decoder_dim)
        self._init_c = nn.Linear(self._encoder_dim, self._decoder_dim)
        
        self._resnet = torchvision.models.resnet18()
        modules = list(self._resnet.children())[:-2]
        self._encoder = nn.Sequential(
            *modules,
            nn.AdaptiveAvgPool2d(self._encoder_size)
        )

        self._decoder = ImageCaptioningDecoder(self._vocab_size, self._encoder_dim, self._embedding_dim, self._attention_dim, self._decoder_dim)
        
        self.beam_search = BeamSearch(self._end_index, self._max_timesteps, self._beam_size)
        
        self._bleu = BLEU(exclude_indices={self._start_index, self._end_index, self._pad_index})
        self._exprate = Exprate(self._end_index)
Example #32
    def __init__(self,
                 vocab: Vocabulary,
                 encoder: Encoder,
                 decoder: CaptioningDecoder,
                 max_timesteps: int = 75,
                 teacher_forcing: bool = True,
                 scheduled_sampling_ratio: float = 1,
                 beam_size: int = 10) -> None:
        super().__init__(vocab)

        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        self._pad_index = self.vocab.get_token_index('@@PADDING@@')

        self._max_timesteps = max_timesteps
        self._teacher_forcing = teacher_forcing
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._beam_size = beam_size

        self._encoder = encoder
        self._decoder = decoder

        self._init_h = nn.Linear(self._encoder.get_output_dim(),
                                 self._decoder.get_input_dim())
        self._init_c = nn.Linear(self._encoder.get_output_dim(),
                                 self._decoder.get_input_dim())

        self.beam_search = BeamSearch(self._end_index, self._max_timesteps,
                                      self._beam_size)

        self._bleu = BLEU(exclude_indices={
            self._start_index, self._end_index, self._pad_index
        })
        self._exprate = Exprate(self._end_index, self.vocab)

        self._attention_weights = None
Example #33
    def test_top_k_search(self):
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep

        top_k, log_probs = BeamSearch.top_k_sampling(
            self.end_index, k=1,
            beam_size=beam_size).search(initial_predictions, {}, take_step)

        beam_size = beam_size or 1
        batch_size = 5

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]

        assert ((0 <= top_k) & (top_k <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
Example #34
    def _check_results(self,
                       batch_size: int = 5,
                       expected_top_k: np.array = None,
                       expected_log_probs: np.array = None,
                       beam_search: BeamSearch = None,
                       state: Dict[str, torch.Tensor] = None) -> None:
        expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
        expected_log_probs = expected_log_probs if expected_log_probs is not None else self.expected_log_probs
        state = state or {}

        beam_search = beam_search or self.beam_search
        beam_size = beam_search.beam_size

        initial_predictions = torch.tensor([0] * batch_size)  # pylint: disable=not-callable
        top_k, log_probs = beam_search.search(initial_predictions, state, take_step)  # type: ignore

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
        np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs)
Example #35
class Event2Mind(Model):
    """
    This ``Event2Mind`` class is a :class:`Model` which takes an event
    sequence, encodes it, and then uses the encoded representation to decode
    several mental state sequences.

    It is based on `the paper by Rashkin et al.
    <https://www.semanticscholar.org/paper/Event2Mind/b89f8a9b2192a8f2018eead6b135ed30a1f2144d>`_

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (``tokens``) or the target tokens can have a different namespace, in which case it needs to
        be specified as ``target_namespace``.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences.
    embedding_dropout: float, required
        The amount of dropout to apply after the source tokens have been embedded.
    encoder : ``Seq2VecEncoder``, required
        The encoder of the "encoder/decoder" model.
    max_decoding_steps : int, required
        Length of decoded sequences.
    beam_size : int, optional (default = 10)
        The width of the beam search.
    target_names : ``List[str]``, optional (default = ['xintent', 'xreact', 'oreact'])
        Names of the target fields matching those in the ``Instance`` objects.
    target_namespace : str, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : int, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 beam_size: int = 10,
                 target_names: List[str] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None) -> None:
        super().__init__(vocab)
        target_names = target_names or ["xintent", "xreact", "oreact"]

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as a well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()

        self._states = ModuleDict()
        for name in target_names:
            self._states[name] = StateDecoder(
                    num_classes,
                    target_embedding_dim,
                    self._decoder_output_dim
            )

        self._beam_search = BeamSearch(
                self._end_index,
                beam_size=beam_size,
                max_steps=max_decoding_steps
        )

    def _update_recall(self,
                       all_top_k_predictions: torch.Tensor,
                       target_tokens: Dict[str, torch.LongTensor],
                       target_recall: UnigramRecall) -> None:
        targets = target_tokens["tokens"]
        target_mask = get_text_field_mask(target_tokens)
        # See comment in _get_loss.
        # TODO(brendanr): Do we need contiguous here?
        relevant_targets = targets[:, 1:].contiguous()
        relevant_mask = target_mask[:, 1:].contiguous()
        target_recall(
                all_top_k_predictions,
                relevant_targets,
                relevant_mask,
                self._end_index
        )

    def _get_num_decoding_steps(self,
                                target_tokens: Optional[Dict[str, torch.LongTensor]]) -> int:
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end
            # symbol.  Either way, we don't have to process it. (To be clear,
            # we do still output and compare against the end symbol, but there
            # is no need to take the end symbol as input to the decoder.)
            return target_sequence_length - 1
        else:
            return self._max_decoding_steps

    @overrides
    def forward(self,  # type: ignore
                source: Dict[str, torch.LongTensor],
                **target_tokens: Dict[str, Dict[str, torch.LongTensor]]) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the target sequences.

        Parameters
        ----------
        source : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the source
            ``TextField``. This will be passed through a ``TextFieldEmbedder``
            and then through an encoder.
        target_tokens : ``Dict[str, Dict[str, torch.LongTensor]]``:
            Dictionary from name to output of ``Textfield.as_array()`` applied
            on target ``TextField``. We assume that the target tokens are also
            represented as a ``TextField``.
        """
        # (batch_size, input_sequence_length, embedding_dim)
        embedded_input = self._embedding_dropout(self._source_embedder(source))
        source_mask = get_text_field_mask(source)
        # (batch_size, encoder_output_dim)
        final_encoder_output = self._encoder(embedded_input, source_mask)
        output_dict = {}

        # Perform greedy search so we can get the loss.
        if target_tokens:
            if target_tokens.keys() != self._states.keys():
                target_only = target_tokens.keys() - self._states.keys()
                states_only = self._states.keys() - target_tokens.keys()
                raise Exception("Mismatch between target_tokens and self._states. Keys in " +
                                f"targets only: {target_only} Keys in states only: {states_only}")
            total_loss = 0
            for name, state in self._states.items():
                loss = self.greedy_search(
                        final_encoder_output=final_encoder_output,
                        target_tokens=target_tokens[name],
                        target_embedder=state.embedder,
                        decoder_cell=state.decoder_cell,
                        output_projection_layer=state.output_projection_layer
                )
                total_loss += loss
                output_dict[f"{name}_loss"] = loss

            # Use mean loss (instead of the sum of the losses) to be comparable to the paper.
            output_dict["loss"] = total_loss / len(self._states)

        # Perform beam search to obtain the predictions.
        if not self.training:
            batch_size = final_encoder_output.size()[0]
            for name, state in self._states.items():
                start_predictions = final_encoder_output.new_full(
                        (batch_size,), fill_value=self._start_index, dtype=torch.long)
                start_state = {"decoder_hidden": final_encoder_output}

                # (batch_size, 10, num_decoding_steps)
                all_top_k_predictions, log_probabilities = self._beam_search.search(
                        start_predictions, start_state, state.take_step)

                if target_tokens:
                    self._update_recall(all_top_k_predictions, target_tokens[name], state.recall)
                output_dict[f"{name}_top_k_predictions"] = all_top_k_predictions
                output_dict[f"{name}_top_k_log_probabilities"] = log_probabilities

        return output_dict

    def greedy_search(self,
                      final_encoder_output: torch.LongTensor,
                      target_tokens: Dict[str, torch.LongTensor],
                      target_embedder: Embedding,
                      decoder_cell: GRUCell,
                      output_projection_layer: Linear) -> torch.FloatTensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the cross entropy between this sequence and ``target_tokens``.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()`` applied on some target ``TextField``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._get_num_decoding_steps(target_tokens)
        targets = target_tokens["tokens"]
        decoder_hidden = final_encoder_output
        step_logits = []
        for timestep in range(num_decoding_steps):
            # See https://github.com/allenai/allennlp/issues/1134.
            input_choices = targets[:, timestep]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
        # (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        target_mask = get_text_field_mask(target_tokens)
        return self._get_loss(logits, targets, target_mask)

    def greedy_predict(self,
                       final_encoder_output: torch.LongTensor,
                       target_embedder: Embedding,
                       decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [final_encoder_output.new_full(
                (batch_size,), fill_value=self._start_index, dtype=torch.long
        )]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.FloatTensor:
        """
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        relevant_targets = targets[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
        loss = sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
        return loss

    def decode_all(self, predicted_indices: torch.Tensor) -> List[List[str]]:
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                                for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, List[List[str]]]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds fields for the tokens to the ``output_dict``.
        """
        for name in self._states:
            top_k_predicted_indices = output_dict[f"{name}_top_k_predictions"][0]
            output_dict[f"{name}_top_k_predicted_tokens"] = [self.decode_all(top_k_predicted_indices)]

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics = {}
        # Recall@10 needs beam search which doesn't happen during training.
        if not self.training:
            for name, state in self._states.items():
                all_metrics[name] = state.recall.get_metric(reset=reset)
        return all_metrics
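
The forward method above drives self._beam_search.search with state.take_step, which is not shown in this listing. Below is a sketch of such a step function, assuming it mirrors the per-step logic of greedy_predict (embed the last prediction, advance the GRU cell, project to the vocabulary, take the log-softmax); the actual StateDecoder implementation may be organized differently:

    from typing import Dict, Tuple
    import torch
    import torch.nn.functional as F

    def make_take_step(embedder, decoder_cell, output_projection_layer):
        """Build a step callable with the signature BeamSearch.search expects."""
        def take_step(last_predictions: torch.Tensor,
                      state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
            decoder_hidden = state["decoder_hidden"]
            decoder_input = embedder(last_predictions)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            output_projections = output_projection_layer(decoder_hidden)
            log_probabilities = F.log_softmax(output_projections, dim=-1)
            return log_probabilities, {"decoder_hidden": decoder_hidden}
        return take_step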