def generate(self, code: torch.Tensor, label: torch.Tensor,
                 max_length: int, beam_size: int,
                 lp_alpha: float) -> torch.Tensor:
        start_index = self.vocab.get_token_index('<s>')
        end_index = self.vocab.get_token_index('</s>')
        beam_search = BeamSearch(end_index=end_index,
                                 max_steps=max_length,
                                 beam_size=beam_size,
                                 per_node_beam_size=3)
        batch_size = code.shape[0]
        start_predictions = (
            torch.empty(batch_size).to(label).fill_(start_index))
        zero_state = code.new_zeros(batch_size,
                                    self._generator._module.hidden_size)
        start_state = {
            'h': zero_state,
            'c': zero_state,
            'code': code,
            'label': label,
            'length': label.new_ones(batch_size),
            'length_alpha': (code.new_empty(batch_size).fill_(lp_alpha))
        }

        all_predictions, last_log_probs = beam_search.search(
            start_predictions=start_predictions,
            start_state=start_state,
            step=self.beam_search_step)
        return all_predictions
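The ``generate`` method above relies on the contract of AllenNLP's ``BeamSearch.search``: a start-token tensor of shape ``(batch_size,)``, a state dict of tensors, and a step function mapping ``(last_predictions, state)`` to log-probabilities for the next token. Below is a minimal, self-contained sketch of that contract, with an invented transition table standing in for a real decoder step (assuming the two-argument step signature used throughout these examples):

import torch
import torch.nn.functional as F
from allennlp.nn.beam_search import BeamSearch

num_classes, end_index = 6, 5
logits = torch.randn(num_classes, num_classes)
logits[:, end_index] -= 5.0          # make early termination unlikely in this toy example
transitions = F.log_softmax(logits, dim=-1)

def toy_step(last_predictions, state):
    # shape: (group_size, num_classes) log-probabilities for the next token.
    return transitions[last_predictions], state

beam_search = BeamSearch(end_index, max_steps=10, beam_size=3)
start_predictions = torch.zeros(2, dtype=torch.long)   # batch_size = 2
top_k, log_probs = beam_search.search(start_predictions, {}, toy_step)
# top_k: (batch_size, beam_size, num_decoded_steps); log_probs: (batch_size, beam_size)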
Example #2
    def test_beam_search_matches_greedy(self):
        model = self.trained_model
        state = model._states["xintent"]
        beam_search = BeamSearch(model._end_index,
                                 max_steps=model._max_decoding_steps,
                                 beam_size=1)

        final_encoder_output = self.get_sample_encoded_output()
        batch_size = final_encoder_output.size()[0]
        start_predictions = final_encoder_output.new_full(
                (batch_size,), fill_value=model._start_index, dtype=torch.long)
        start_state = {"decoder_hidden": final_encoder_output}

        greedy_prediction = model.greedy_predict(
                final_encoder_output=final_encoder_output,
                target_embedder=state.embedder,
                decoder_cell=state.decoder_cell,
                output_projection_layer=state.output_projection_layer
        )
        greedy_tokens = model.decode_all(greedy_prediction)

        (beam_predictions, _) = beam_search.search(
                start_predictions, start_state, state.take_step)
        beam_prediction = beam_predictions[0]
        beam_tokens = model.decode_all(beam_prediction)

        assert beam_tokens == greedy_tokens
Example #3
    def test_greedy_decode_matches_beam_search(self):

        beam_search = BeamSearch(
            self.model._end_index, max_steps=self.model._max_decoding_steps, beam_size=1
        )
        training_tensors = self.dataset.as_tensor_dict()

        # Get greedy predictions from _forward_loop method of model.
        state = self.model._encode(training_tensors["source_tokens"])
        state = self.model._init_decoder_state(state)
        output_dict_greedy = self.model._forward_loop(state)
        output_dict_greedy = self.model.decode(output_dict_greedy)

        # Get greedy predictions from beam search (beam size = 1).
        state = self.model._encode(training_tensors["source_tokens"])
        state = self.model._init_decoder_state(state)
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size,), fill_value=self.model._start_index
        )
        all_top_k_predictions, _ = beam_search.search(
            start_predictions, state, self.model.take_step
        )
        output_dict_beam_search = {"predictions": all_top_k_predictions}
        output_dict_beam_search = self.model.decode(output_dict_beam_search)

        # Predictions from model._forward_loop and beam_search should match.
        assert output_dict_greedy["predicted_tokens"] == output_dict_beam_search["predicted_tokens"]
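Both tests above check that beam search with ``beam_size=1`` reproduces greedy decoding. The same property can be demonstrated with a toy deterministic step function (a sketch with an invented transition table; the end token is pushed out of the argmax so both decoders run for the full ``max_steps``):

import torch
import torch.nn.functional as F
from allennlp.nn.beam_search import BeamSearch

num_classes, end_index, max_steps = 6, 5, 4
logits = torch.randn(num_classes, num_classes)
logits[:, end_index] = -1e9                      # keep the end token out of the argmax
transitions = F.log_softmax(logits, dim=-1)

def step(last_predictions, state):
    return transitions[last_predictions], state

start = torch.zeros(3, dtype=torch.long)         # batch_size = 3
beam = BeamSearch(end_index, max_steps=max_steps, beam_size=1)
beam_predictions, _ = beam.search(start, {}, step)

# Greedy reference: repeatedly take the argmax from the same transition table.
greedy, last = [], start
for _ in range(max_steps):
    last = transitions[last].argmax(dim=-1)
    greedy.append(last)
assert torch.equal(beam_predictions[:, 0, :], torch.stack(greedy, dim=-1))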
Example #4
    def _check_results(
        self,
        batch_size: int = 5,
        expected_top_k: np.ndarray = None,
        expected_log_probs: np.ndarray = None,
        beam_search: BeamSearch = None,
        state: Dict[str, torch.Tensor] = None,
        take_step=take_step_with_timestep,
    ) -> None:
        expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
        expected_log_probs = (expected_log_probs if expected_log_probs
                              is not None else self.expected_log_probs)
        state = state or {}

        beam_search = beam_search or self.beam_search
        beam_size = beam_search.beam_size

        initial_predictions = torch.tensor([0] * batch_size)
        top_k, log_probs = beam_search.search(initial_predictions, state,
                                              take_step)  # type: ignore

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
        np.testing.assert_allclose(log_probs[0].numpy(),
                                   expected_log_probs,
                                   rtol=1e-6)
    def test_beam_search_matches_greedy(self):
        model = self.trained_model
        state = model._states["xintent"]
        beam_search = BeamSearch(model._end_index,
                                 max_steps=model._max_decoding_steps,
                                 beam_size=1)

        final_encoder_output = self.get_sample_encoded_output()
        batch_size = final_encoder_output.size()[0]
        start_predictions = final_encoder_output.new_full(
            (batch_size, ), fill_value=model._start_index, dtype=torch.long)
        start_state = {"decoder_hidden": final_encoder_output}

        greedy_prediction = model.greedy_predict(
            final_encoder_output=final_encoder_output,
            target_embedder=state.embedder,
            decoder_cell=state.decoder_cell,
            output_projection_layer=state.output_projection_layer,
        )
        greedy_tokens = model.decode_all(greedy_prediction)

        (beam_predictions, _) = beam_search.search(start_predictions,
                                                   start_state,
                                                   state.take_step)
        beam_prediction = beam_predictions[0]
        beam_tokens = model.decode_all(beam_prediction)

        assert beam_tokens == greedy_tokens
    def test_greedy_decode_matches_beam_search(self):
        # pylint: disable=protected-access
        beam_search = BeamSearch(self.model._end_index, max_steps=self.model._max_decoding_steps, beam_size=1)
        training_tensors = self.dataset.as_tensor_dict()

        # Get greedy predictions from _forward_loop method of model.
        state = self.model._encode(training_tensors["source_tokens"])
        state = self.model._init_decoder_state(state)
        output_dict_greedy = self.model._forward_loop(state)
        output_dict_greedy = self.model.decode(output_dict_greedy)

        # Get greedy predictions from beam search (beam size = 1).
        state = self.model._encode(training_tensors["source_tokens"])
        state = self.model._init_decoder_state(state)
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self.model._start_index)
        all_top_k_predictions, _ = beam_search.search(
                start_predictions, state, self.model.take_step)
        output_dict_beam_search = {
                "predictions": all_top_k_predictions,
        }
        output_dict_beam_search = self.model.decode(output_dict_beam_search)

        # Predictions from model._forward_loop and beam_search should match.
        assert output_dict_greedy["predicted_tokens"] == output_dict_beam_search["predicted_tokens"]
Example #7
    def test_empty_sequences(self):
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        beam_search = BeamSearch(self.end_index, beam_size=1)
        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = beam_search.search(
                initial_predictions, {}, take_step_with_timestep)
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log_probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()
Example #8
    def _check_results(self,
                       batch_size: int = 5,
                       expected_top_k: np.ndarray = None,
                       expected_log_probs: np.ndarray = None,
                       beam_search: BeamSearch = None,
                       state: Dict[str, torch.Tensor] = None) -> None:
        expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
        expected_log_probs = expected_log_probs if expected_log_probs is not None else self.expected_log_probs
        state = state or {}

        beam_search = beam_search or self.beam_search
        beam_size = beam_search.beam_size

        initial_predictions = torch.tensor([0] * batch_size)  # pylint: disable=not-callable
        top_k, log_probs = beam_search.search(initial_predictions, state, take_step)  # type: ignore

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
        np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs)
class MASS(Model):
    """
    This ``MASS`` class is a :class:`Model` which takes a sequence, encodes it, and then
    uses the encoded representations to decode another sequence.  You can use this as the basis for
    a neural machine translation system, an abstractive summarization system, or any other common
    seq2seq problem.  The model here is simple, but should be a decent starting place for
    implementing recent models for these tasks.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    transformer_encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 transformer_encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 use_bleu: bool = True,
                 use_fp16: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(MASS, self).__init__(vocab)
        self._target_namespace = target_namespace
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._mask_index = self.vocab.get_token_index('[MASK]',
                                                      self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index, self._mask_index
            })
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = transformer_encoder

        # Dense embedding of vocab words in the target space. The target embedder is tied
        # to the source embedder.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        self._target_embedder = source_embedder

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        self._decoder_input_dim = target_embedding_dim

        self._decoder = TransformerDecoder(
            use_fp16,
            self._target_embedder,
            decoder_layers=6,
            dropout=0.1,
            decoder_embed_dim=self._encoder_output_dim,
            decoder_ffn_embed_dim=target_embedding_dim,
            decoder_attention_heads=4,
            decoder_output_dim=self._decoder_output_dim,
            max_target_positions=512,
            attention_dropout=0.1,
        )
        initializer(self)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        last_predictions_dict = {'tokens': last_predictions.unsqueeze(1)}
        # shape: (group_size, 1, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions_dict, state)

        # shape: (group_size, num_classes) after squeezing out the length-1 step dimension.
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)
        class_log_probabilities = class_log_probabilities.squeeze(1)

        return class_log_probabilities, state

    @overrides
    def forward(
        self,  # type: ignore
        encoder_tokens: Dict[str, torch.LongTensor],
        decoder_tokens: Dict[str, torch.LongTensor] = None,
        target_tokens: Dict[str, torch.LongTensor] = None,
        positions: torch.LongTensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make a forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        encoder_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(encoder_tokens)

        if decoder_tokens:
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state,
                                             decoder_tokens=decoder_tokens,
                                             target_tokens=target_tokens,
                                             positions=positions)
        else:
            output_dict = {}

        if not self.training:
            if not decoder_tokens:
                predictions = self._forward_beam_search(state)
                output_dict.update(predictions)
                if target_tokens and self._bleu:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    self._bleu(best_predictions, target_tokens["tokens"])
            else:
                if target_tokens and self._bleu:
                    best_predictions = output_dict["predictions"]
                    self._bleu(best_predictions, target_tokens["tokens"])

        return output_dict

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        return {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,
        }

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _forward_loop(
        self,
        state: Dict[str, torch.Tensor],
        decoder_tokens: Dict[str, torch.LongTensor] = None,
        target_tokens: Dict[str, torch.LongTensor] = None,
        positions: torch.LongTensor = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        decoder_token_mask = util.get_text_field_mask(decoder_tokens)
        decoder_padding_mask = (decoder_token_mask == 0)
        # shape: (batch_size, num_decoding_steps, num_classes)
        logits, state = self._prepare_output_projections(
            decoder_tokens,
            state,
            decoder_padding_mask=decoder_padding_mask,
            positions=positions)
        # shape: (batch_size, num_decoding_steps, num_classes)
        class_probabilities = F.softmax(logits, dim=-1)

        # shape (predictions): (batch_size, num_decoding_steps)
        _, predictions = torch.max(class_probabilities, -1)

        output_dict = {"predictions": predictions}

        if target_tokens is not None:
            # Compute loss.
            target_mask = decoder_token_mask
            targets = target_tokens["tokens"]
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size(0)
        start_tokens = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._mask_index)
        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_tokens, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: Dict[str, torch.LongTensor],
                                    state: Dict[str, torch.Tensor],
                                    decoder_padding_mask: torch.LongTensor = None,
                                    positions: torch.LongTensor = None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode the current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_out = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        encoder_padding_mask = (state["source_mask"] == 0)

        # Token indices (not embeddings) of the previous output steps.
        prev_output_tokens = last_predictions
        decoder_output = self._decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            encoder_padding_mask=encoder_padding_mask,
            decoder_padding_mask=decoder_padding_mask,
            positions=positions)

        return decoder_output, state

    @staticmethod
    def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size ``(batch_size,
        num_decoding_steps, num_classes)``, target indices of size ``(batch_size,
        num_decoding_steps)`` and a mask of the same size, and computes cross entropy loss
        while taking the mask into account.

        Unlike the generic seq2seq loss, ``targets`` here come from a separate ``target_tokens``
        field that is already aligned with the decoder inputs (and hence with ``logits``), so
        the targets are not shifted inside this method.
        """
        return util.sequence_cross_entropy_with_logits(logits, targets,
                                                       target_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
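When ``use_bleu`` is set, the ``MASS`` model above drives AllenNLP's corpus-level BLEU metric with the padding, start, end, and mask indices excluded. A small sketch of that metric on made-up index tensors (the indices and sequences are invented for illustration):

import torch
from allennlp.training.metrics import BLEU

pad_index, start_index, end_index = 0, 2, 3
bleu = BLEU(exclude_indices={pad_index, start_index, end_index})

predictions = torch.tensor([[2, 7, 8, 9, 10, 11, 3, 0]])    # one predicted index sequence
gold_targets = torch.tensor([[2, 7, 8, 9, 10, 12, 3, 0]])   # one reference index sequence
bleu(predictions, gold_targets)
print(bleu.get_metric(reset=True))                          # -> {'BLEU': <score between 0 and 1>}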
class CustomAutoRegressiveSeqDecoder(SeqDecoder):
    def __init__(
        self,
        vocab: Vocabulary,
        decoder_net: DecoderNet,
        max_decoding_steps: int,
        target_embedder: Embedding,
        target_namespace: str = "tokens",
        tie_output_embedding: bool = False,
        scheduled_sampling_ratio: float = 0,
        label_smoothing_ratio: Optional[float] = None,
        beam_size: int = 4,
        tensor_based_metric: Metric = None,
        token_based_metric: Metric = None,
    ) -> None:
        super().__init__(target_embedder)

        self._vocab = vocab

        # Decodes the sequence of encoded hidden states into a new sequence of hidden states.
        self._decoder_net = decoder_net
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace
        self._label_smoothing_ratio = label_smoothing_ratio

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self._vocab.get_token_index(
            START_SYMBOL, self._target_namespace)
        self._end_index = self._vocab.get_token_index(END_SYMBOL,
                                                      self._target_namespace)
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        target_vocab_size = self._vocab.get_vocab_size(self._target_namespace)

        if self.target_embedder.get_output_dim(
        ) != self._decoder_net.target_embedding_dim:
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input."
            )

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(
            self._decoder_net.get_output_dim(), target_vocab_size)

        if tie_output_embedding:
            if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape:
                raise ConfigurationError(
                    "Can't tie embeddings with output linear layer, due to shape mismatch"
                )
            self._output_projection_layer.weight = self.target_embedder.weight

        # These metrics will be updated during training and validation
        self._tensor_based_metric = tensor_based_metric
        self._token_based_metric = token_based_metric

        self._scheduled_sampling_ratio = scheduled_sampling_ratio

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _forward_loss(
            self, state: Dict[str, torch.Tensor],
            target_tokens: Dict[str,
                                torch.LongTensor]) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = target_tokens["tokens"]

        # Prepare embeddings for targets. They will be used as gold embeddings during decoder training
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self.target_embedder(targets)

        # shape: (batch_size, max_target_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        if self._scheduled_sampling_ratio == 0 and self._decoder_net.decodes_parallel:
            _, decoder_output = self._decoder_net(
                previous_state=state,
                previous_steps_predictions=target_embedding[:, :-1, :],
                encoder_outputs=encoder_outputs,
                source_mask=source_mask,
                previous_steps_mask=target_mask[:, :-1])

            # shape: (group_size, max_target_sequence_length, num_classes)
            logits = self._output_projection_layer(decoder_output)
        else:
            batch_size = source_mask.size()[0]
            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1

            # Initialize target predictions with the start index.
            # shape: (batch_size,)
            last_predictions = source_mask.new_full(
                (batch_size, ), fill_value=self._start_index)

            # shape: (steps, batch_size, target_embedding_dim)
            steps_embeddings = torch.Tensor([])

            step_logits: List[torch.Tensor] = []

            for timestep in range(num_decoding_steps):
                if self.training and torch.rand(
                        1).item() < self._scheduled_sampling_ratio:
                    # Scheduled sampling: with probability ``_scheduled_sampling_ratio`` during
                    # training, condition on the model's own previous predictions instead of
                    # the gold tokens.
                    # shape: (batch_size, steps, target_embedding_dim)
                    state['previous_steps_predictions'] = steps_embeddings

                    # shape: (batch_size, )
                    effective_last_prediction = last_predictions
                else:
                    # shape: (batch_size, )
                    effective_last_prediction = targets[:, timestep]

                    if timestep == 0:
                        state['previous_steps_predictions'] = torch.Tensor([])
                    else:
                        # shape: (batch_size, steps, target_embedding_dim)
                        state[
                            'previous_steps_predictions'] = target_embedding[:, :
                                                                             timestep]

                # shape: (batch_size, num_classes)
                output_projections, state = self._prepare_output_projections(
                    effective_last_prediction, state)

                # list of tensors, shape: (batch_size, 1, num_classes)
                step_logits.append(output_projections.unsqueeze(1))

                # shape (predicted_classes): (batch_size,)
                _, predicted_classes = torch.max(output_projections, 1)

                # shape (predicted_classes): (batch_size,)
                last_predictions = predicted_classes

                # shape: (batch_size, 1, target_embedding_dim)
                last_predictions_embeddings = self.target_embedder(
                    last_predictions).unsqueeze(1)

                # This step is required, since we want to keep two different prediction
                # histories: gold and real.
                if steps_embeddings.shape[-1] == 0:  # pylint: disable=unsubscriptable-object
                    # There are no previous steps, except for start vectors in ``last_predictions``
                    # shape: (group_size, 1, target_embedding_dim)
                    steps_embeddings = last_predictions_embeddings
                else:
                    # shape: (group_size, steps_count, target_embedding_dim)
                    steps_embeddings = torch.cat(
                        [steps_embeddings, last_predictions_embeddings], 1)

            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

        # Compute loss.
        target_mask = util.get_text_field_mask(target_tokens)
        loss = self._get_loss(logits, targets, target_mask)

        # TODO: We will be using beam search to get predictions for validation, but if beam size is 1
        # we could consider taking the last_predictions here, building step_predictions from them,
        # and using those instead of running beam search again, if validation performance takes a hit.
        output_dict = {'loss': loss}

        return output_dict

    def _prepare_output_projections(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode the current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, steps_count, decoder_output_dim)
        previous_steps_predictions = state.get("previous_steps_predictions")

        # shape: (batch_size, 1, target_embedding_dim)
        last_predictions_embeddings = self.target_embedder(
            last_predictions).unsqueeze(1)

        if previous_steps_predictions is None or previous_steps_predictions.shape[
                -1] == 0:
            # There are no previous steps, except for start vectors in ``last_predictions``
            # shape: (group_size, 1, target_embedding_dim)
            previous_steps_predictions = last_predictions_embeddings
        else:
            # shape: (group_size, steps_count, target_embedding_dim)
            previous_steps_predictions = torch.cat(
                [previous_steps_predictions, last_predictions_embeddings], 1)

        decoder_state, decoder_output = self._decoder_net(
            previous_state=state,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_predictions=previous_steps_predictions)
        state["previous_steps_predictions"] = previous_steps_predictions

        # Update state with new decoder state, override previous state
        state.update(decoder_state)

        if self._decoder_net.decodes_parallel:
            decoder_output = decoder_output[:, -1, :]

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_output)

        return output_projections, state

    def _get_loss(self, logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(
            logits,
            relevant_targets,
            relevant_mask,
            label_smoothing=self._label_smoothing_ratio)

    def get_output_dim(self):
        return self._decoder_net.get_output_dim()

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(
                        reset=reset))  # type: ignore
            if self._token_based_metric is not None:
                all_metrics.update(
                    self._token_based_metric.get_metric(
                        reset=reset))  # type: ignore
        return all_metrics

    @overrides
    def forward(
        self,
        encoder_out: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        state = encoder_out
        decoder_init_state = self._decoder_net.init_decoder_state(state)
        state.update(decoder_init_state)

        output_dict = self._forward_loss(
            state, target_tokens) if target_tokens else {}

        if not self.training:
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

            if target_tokens:
                if self._tensor_based_metric is not None:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    # shape: (batch_size, target_sequence_length)

                    self._tensor_based_metric(
                        best_predictions,
                        target_tokens["tokens"])  # type: ignore

                if self._token_based_metric is not None:
                    output_dict = self.post_process(output_dict)
                    predicted_tokens = output_dict['predicted_tokens']

                    self._token_based_metric(
                        predicted_tokens,  # type: ignore
                        [y.text for y in target_tokens["tokens"][1:-1]])

        return output_dict

    @overrides
    def post_process(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self._vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict
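The target/logit alignment described in ``_get_loss`` above can be illustrated with toy tensors (values invented): a 3-word target padded to 7 positions yields 6 decoding steps of logits, which are compared against the targets shifted left by one position.

import torch
targets = torch.tensor([[2, 10, 11, 12, 3, 0, 0]])    # <S> w1  w2  w3  <E> <P> <P>
target_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]])
logits = torch.randn(1, 6, 20)                        # l1 .. l6 over a 20-word vocabulary
relevant_targets = targets[:, 1:].contiguous()        # w1  w2  w3  <E> <P> <P>
relevant_mask = target_mask[:, 1:].contiguous()       # 1   1   1   1   0   0
assert logits.size(1) == relevant_targets.size(1) == relevant_mask.size(1) == 6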
Example #11
class PhnMoChA(Model):
    """
    This ``PhnMoChA`` class is a :class:`Model` which takes a sequence, encodes it, and then
    uses the encoded representations to decode another sequence.  You can use this as the basis for
    a neural machine translation system, an abstractive summarization system, or any other common
    seq2seq problem.  The model here is simple, but should be a decent starting place for
    implementing recent models for these tasks.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    attention_function: ``SimilarityFunction``, optional (default = None)
        This is if you want to use the legacy implementation of attention. This will be deprecated
        since it consumes more memory than the specialized attention modules.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        encoder: Seq2SeqEncoder,
        input_size: int,
        target_embedding_dim: int,
        decoder_hidden_dim: int,
        max_decoding_steps: int,
        max_decoding_ratio: float = 1.5,
        dep_parser: Model = None,
        pos_tagger: Model = None,
        cmvn: str = 'none',
        delta: int = 0,
        time_mask_width: int = 0,
        freq_mask_width: int = 0,
        time_mask_max_ratio: float = 0.0,
        dec_layers: int = 1,
        layerwise_pretraining: List[Tuple[int, int]] = None,
        cnn: Seq2SeqEncoder = None,
        conv_lstm: Seq2SeqEncoder = None,
        train_at_phn_level: bool = False,
        rnnt_layer: Model = None,
        phn_ctc_layer: Model = None,
        ctc_layer: Model = None,
        projection_layer: nn.Module = None,
        tie_proj: bool = False,
        att_ratio: float = 0.0,
        dep_ratio: float = 0.0,
        pos_ratio: float = 0.0,
        attention: Attention = None,
        attention_function: SimilarityFunction = None,
        latency_penalty: float = 0.0,
        loss_type: str = "nll",
        beam_size: int = 1,
        target_namespace: str = "tokens",
        phoneme_target_namespace: str = "phonemes",
        dropout: float = 0.0,
        blank: str = "_",
        sampling_strategy: str = "max",
        from_candidates: bool = False,
        scheduled_sampling_ratio: float = 0.,
        initializer: InitializerApplicator = InitializerApplicator()
    ) -> None:
        super(PhnMoChA, self).__init__(vocab)
        self._input_size = input_size
        self._target_namespace = target_namespace
        self._phn_target_namespace = phoneme_target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._sampling_strategy = sampling_strategy
        self._train_at_phn_level = train_at_phn_level
        self._blank = blank

        self._dep_parser = dep_parser
        self._pos_tagger = pos_tagger
        self._ctc_layer = ctc_layer
        self._rnnt_layer = rnnt_layer
        self._phn_ctc_layer = phn_ctc_layer
        self._projection_layer = projection_layer
        if tie_proj:
            self._rnnt_layer.set_projection_layer(projection_layer)
        self._att_ratio = att_ratio
        self._dep_ratio = dep_ratio
        self._pos_ratio = pos_ratio
        self._loss_type = loss_type
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)  # pylint: disable=protected-access
        self._phn_pad_index = self.vocab.get_token_index(
            self.vocab._padding_token, self._phn_target_namespace)  # pylint: disable=protected-access

        exclude_indices = {self._pad_index, self._end_index, self._start_index}

        self._logs: Dict[str, Union[Metric, None]] = {
            "att_wer": (WER(exclude_indices=exclude_indices)
                        if self._att_ratio > 0 else None),
            "att_bleu": (BLEU(exclude_indices=exclude_indices)
                         if self._att_ratio > 0 else None),
            "att_loss": (Average() if self._att_ratio > 0 else None),
            "phn_ctc_loss": (Average() if self._phn_ctc_layer else None),
            "ctc_loss": (Average() if self._ctc_layer else None),
            "rnnt_loss": (Average() if self._rnnt_layer else None),
            "dal_loss": (Average() if latency_penalty > 0.0 else None),
            "dep_loss": (Average() if self._dep_parser else None),
            "pos_loss": (Average() if self._pos_tagger else None),
            "tag_loss": (Average() if self._dep_parser else None),
            "arc_loss": (Average() if self._dep_parser else None)
        }

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        self._max_decoding_steps = max_decoding_steps
        self._max_decoding_ratio = max_decoding_ratio
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        self._cnn = cnn
        self._conv_lstm = conv_lstm

        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        self._num_classes = num_classes

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            if attention_function:
                raise ConfigurationError(
                    "You can only specify an attention module or an "
                    "attention function, but not both.")
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = decoder_hidden_dim
        self._dec_layers = dec_layers
        if self._decoder_output_dim != self._encoder_output_dim:
            self.bridge = nn.Linear(self._encoder_output_dim,
                                    self._dec_layers *
                                    self._decoder_output_dim,
                                    bias=False)
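        # Note (added for clarity): when the encoder output size differs from the
        # stacked decoder size, ``self.bridge`` projects the final encoder state of
        # size ``encoder_output_dim`` to ``dec_layers * decoder_output_dim``;
        # ``_init_decoder_state`` then reshapes that vector to
        # ``(batch_size, dec_layers, decoder_output_dim)`` to seed the decoder LSTM.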

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
            self.att_out = Linear(self._decoder_output_dim +
                                  self._encoder_output_dim,
                                  self._decoder_output_dim,
                                  bias=True)
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder = nn.LSTM(self._decoder_input_dim,
                                self._decoder_output_dim,
                                num_layers=self._dec_layers,
                                batch_first=True)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)

        self._input_norm = lambda x: x
        if cmvn == 'global':
            self._input_norm = nn.BatchNorm1d(self._input_size * (delta + 1))
        elif cmvn == 'utt':
            self._input_norm = nn.InstanceNorm1d(self._input_size *
                                                 (delta + 1))

        self._delta = None
        if delta > 0:
            self._delta = Delta(order=delta)

        self._epoch_num = float("inf")
        self._layerwise_pretraining = layerwise_pretraining
        try:
            if isinstance(self._encoder, PytorchSeq2SeqWrapper):
                self._num_layers = self._encoder._module.num_layers
            else:
                self._num_layers = self._encoder.num_layers
        except AttributeError:
            self._num_layers = float("inf")

        self._output_layer_num = self._num_layers

        self._loss = None

        self._from_candidates = from_candidates
        if loss_type == "ocd":
            self._loss = OCDLoss(self._end_index, 1e-7, 1e-7, 5)
        elif loss_type == "edocd":
            self._loss = EDOCDLoss(self._end_index, 1e-7, 1e-7, 5)

        self._latency_penalty = latency_penalty
        self._target_granularity = self._target_namespace

        self.time_mask = TimeMask(time_mask_width, time_mask_max_ratio)
        self.freq_mask = FreqMask(freq_mask_width)

        initializer(self)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state
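    # Illustrative sketch (added, not part of the original code): ``take_step`` is
    # consumed by AllenNLP's ``BeamSearch``, which folds the beam into the batch
    # dimension, so ``group_size == batch_size * beam_size`` after the first step.
    # A minimal usage sketch, assuming ``model`` is an instance of this class and
    # ``state`` came from ``_encode`` followed by ``_init_decoder_state``:
    #
    #     start_predictions = state["source_mask"].new_full(
    #         (batch_size,), fill_value=model._start_index)
    #     top_k_predictions, log_probabilities = model._beam_search.search(
    #         start_predictions, state, model.take_step)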

    @overrides
    def forward(
            self,  # type: ignore
            source_features: torch.FloatTensor,
            source_lengths: torch.LongTensor,
            target_tokens: Dict[str, torch.LongTensor] = None,
            words: Dict[str, torch.LongTensor] = None,
            segments: torch.LongTensor = None,
            pos_tags: torch.LongTensor = None,
            head_tags: torch.LongTensor = None,
            head_indices: torch.LongTensor = None,
            epoch_num: int = None,
            dataset: str = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_features : ``torch.FloatTensor``
           Padded acoustic features of shape ``(batch_size, max_input_length, feature_dim)``.
        source_lengths : ``torch.LongTensor``
           Lengths of the un-padded source feature sequences, of shape ``(batch_size,)``.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `TextField.as_array()` applied on the target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        output_dict = {}
        if dataset is not None:
            self._target_granularity = dataset[0]

        if epoch_num is not None:
            self._epoch_num = epoch_num[0]
        self.set_output_layer_num()

        source_mask = util.get_mask_from_sequence_lengths(
            source_lengths, source_features.size(1)).bool()

        source_features = source_features.unsqueeze(1)  # make a channel dim
        if self._delta:
            source_features = self._delta(source_features)

        batch_size, n_channels, timesteps, feature_size = source_features.size()
        source_features = self._input_norm(
            source_features.transpose(-2, -1).reshape(batch_size, -1, timesteps)) \
            .view(batch_size, n_channels, feature_size, timesteps).transpose(-2, -1)
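        # Shape walk-through (added comment): ``nn.BatchNorm1d``/``nn.InstanceNorm1d``
        # normalize over a (batch, channels, length) layout, so the features are
        # transposed and flattened from (batch_size, n_channels, timesteps, feature_size)
        # to (batch_size, n_channels * feature_size, timesteps) for normalization, and
        # then restored to the original layout.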
        source_features = self.time_mask(source_features, source_mask)
        source_features = self.freq_mask(source_features, source_mask)

        source_features = source_features.masked_fill(
            ~source_mask.unsqueeze(1).unsqueeze(-1).expand_as(source_features),
            0.0)
        state = self._encode(source_features, source_lengths)
        source_lengths = util.get_lengths_from_binary_sequence_mask(
            state["source_mask"])
        target_tokens["mask"] = (target_tokens[self._target_namespace] !=
                                 self._pad_index).bool()

        if self._phn_ctc_layer and \
            (self._phn_target_namespace in self._target_granularity or self._train_at_phn_level):
            raise NotImplementedError
            # logits = self._projection_layer(state["encoder_outputs"])
            # phn_ctc_output_dict = self._phn_ctc_layer(logits, source_lengths, target_tokens)
            # output_dict.update({f"phn_ctc_{key}": value for key, value in phn_ctc_output_dict.items()})

        if self._rnnt_layer is not None and self._rnnt_layer.loss_ratio > 0.0:
            rnnt_output_dict = self._rnnt_layer(state["encoder_outputs"],
                                                source_lengths, target_tokens)
            output_dict.update({
                f"rnnt_{key}": value
                for key, value in rnnt_output_dict.items()
            })
        if self._ctc_layer is not None and self._ctc_layer.loss_ratio > 0.0:
            logits = self._projection_layer(state["encoder_outputs"])
            ctc_output_dict = self._ctc_layer(logits, source_lengths,
                                              target_tokens)
            output_dict.update({
                f"ctc_{key}": value
                for key, value in ctc_output_dict.items()
            })

        if target_tokens and self._att_ratio > 0.0 and \
            self._target_granularity == self._target_namespace:
            targets = target_tokens[self._target_namespace]
            output_dict["target_tokens"] = targets
            target_mask = util.get_text_field_mask(target_tokens)
            if self._train_at_phn_level:
                raise NotImplementedError
                # state = self._get_phn_level_representations(
                #     state["encoder_outputs"].detach().requires_grad_(True),
                #     state["source_mask"],
                #     output_dict["phn_ctc"])

            state = self._init_decoder_state(state)
            output_dict.update(self._forward_loop(state, target_tokens))
            self._logs["att_wer"](output_dict["predictions"], targets)

            if self._dep_parser or self._pos_tagger:
                relevant_mask = target_mask[:, 1:]
                attention_contexts, _ = _remove_eos(
                    output_dict["attention_contexts"], relevant_mask)
                if segments is not None:
                    segments, _ = remove_sentence_boundaries(
                        segments, target_mask)
                    attention_contexts, _ = \
                        char_to_word(attention_contexts, segments)
                contexts = {"tokens": attention_contexts}
                if self._dep_parser:
                    parser_outputs = self._dep_parser(contexts, pos_tags,
                                                      metadata, head_tags,
                                                      head_indices)
                    parser_outputs["dep_loss"] = parser_outputs.pop("loss")
                    output_dict.update(parser_outputs)
                if self._pos_tagger:
                    tagger_outputs = self._pos_tagger(contexts, pos_tags,
                                                      metadata)
                    tagger_outputs["pos_loss"] = tagger_outputs.pop("loss")
                    output_dict.update(tagger_outputs)

        if not self.training:
            if self._target_granularity == self._target_namespace:
                if self._att_ratio > 0.0:
                    state = self._init_decoder_state(state)
                    predictions = self._forward_beam_search(state)
                    output_dict.update(predictions)
                    if target_tokens:
                        targets = target_tokens[self._target_namespace]
                        # shape: (batch_size, beam_size, max_sequence_length)
                        top_k_predictions = output_dict["predictions"]
                        # shape: (batch_size, max_predicted_sequence_length)
                        best_predictions = top_k_predictions[:, 0, :]
                        self._logs["att_bleu"](best_predictions, targets)
                        self._logs["att_wer"](best_predictions, targets)
                    log_dict = self.decode(output_dict)
                    verbose_target = [
                        self._indices_to_tokens(tokens.tolist()[1:])
                        for tokens in target_tokens[self._target_namespace]
                    ]
                    verbose_best_pred = [
                        beams[0] for beams in log_dict["predicted_tokens"]
                    ]
                    sep = " " if self._target_namespace == 'tokens' else ""
                    with open(f"preds.{epoch_num[0]}.txt", "a+") as fp:
                        fp.write("\n".join([
                            sep.join(
                                map(lambda s: re.sub(self._blank, " ", s),
                                    words)) for words in verbose_best_pred
                        ]))
                        fp.write("\n")
                    with open(f"golds.{epoch_num[0]}.txt", "a+") as fp:
                        fp.write("\n".join([
                            sep.join(
                                map(lambda s: re.sub(self._blank, " ", s),
                                    words)) for words in verbose_target
                        ]))
                        fp.write("\n")
                    # for gold, pred in zip(verbose_target, verbose_best_pred):
                    #     print(gold, pred)

        if self.training:
            output_dict = self._collect_losses(
                output_dict,
                ctc=(self._ctc_layer.loss_ratio if self._ctc_layer else 0),
                rnnt=(self._rnnt_layer.loss_ratio if self._rnnt_layer else 0),
                att=self._att_ratio,
                dal=self._latency_penalty,
                dep=self._dep_ratio,
                pos=self._pos_ratio)
            if torch.isnan(output_dict["loss"]).any() or \
                    (torch.abs(output_dict["loss"]) == float('inf')).any():
                for key, _ in output_dict.items():
                    if "loss" in key:
                        output_dict[key] = output_dict[key].new_zeros(
                            size=(), requires_grad=True).clone()
        self._update_metrics(output_dict)

        return output_dict

    def _indices_to_tokens(self, indices):
        # Collect indices till the first end_symbol
        if self._end_index in indices:
            indices = indices[:indices.index(self._end_index)]
        predicted_tokens = [
            self.vocab.get_token_from_index(x,
                                            namespace=self._target_namespace)
            for x in indices
        ]
        return predicted_tokens
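    # Example (hypothetical indices and vocabulary, added for illustration): with a
    # vocabulary mapping {4: "hello", 7: "world"} and ``self._end_index == 3``,
    # ``self._indices_to_tokens([4, 7, 3, 0, 0])`` returns ["hello", "world"];
    # everything from the first end symbol onwards is dropped before the lookup.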

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        def _decode_predictions(input_key: str, output_key: str, beam=False):
            if input_key in output_dict:
                if beam:
                    all_predicted_tokens = [
                        list(map(self._indices_to_tokens, beams))
                        for beams in sanitize(output_dict[input_key])
                    ]
                else:
                    all_predicted_tokens = list(
                        map(self._indices_to_tokens,
                            sanitize(output_dict[input_key])))
                output_dict[output_key] = all_predicted_tokens

        _decode_predictions("predictions", "predicted_tokens", beam=True)
        _decode_predictions("ctc_predictions", "ctc_predicted_tokens")
        _decode_predictions("rnnt_predictions", "rnnt_predicted_tokens")
        _decode_predictions("target_tokens", "targets")

        return output_dict

    def _encode(self, source_features: torch.FloatTensor,
                source_lengths: torch.LongTensor) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        if self._cnn is not None:
            source_features, source_lengths = self._cnn(
                source_features, source_lengths)
        source_mask = util.get_mask_from_sequence_lengths(
            source_lengths, source_features.size(1))
        if self._conv_lstm is not None:
            source_features = self._conv_lstm(source_features, source_mask)
        if not isinstance(self._encoder, AWDRNN):
            encoder_outputs = self._encoder(source_features, source_mask)
        else:
            encoder_outputs, _, source_lengths = self._encoder(
                source_features, source_lengths, self._output_layer_num)
            source_mask = util.get_mask_from_sequence_lengths(
                source_lengths, encoder_outputs.size(1))
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}

    def _get_phn_level_representations(
            self, features: torch.FloatTensor, mask: torch.BoolTensor,
            phn_log_probs: torch.Tensor) -> Dict[str, torch.Tensor]:
        phn_enc_outs, segment_lengths = averaging_tensor_of_same_label(
            features, phn_log_probs, mask=mask)
        state = {
            "encoder_outputs":
            phn_enc_outs,
            "source_mask":
            util.get_mask_from_sequence_lengths(segment_lengths,
                                                int(max(segment_lengths)))
        }
        return state

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]
        source_mask = state["source_mask"]
        final_encoder_output = util.get_final_encoder_states(
            encoder_outputs, source_mask, self._encoder.is_bidirectional())
        if self._encoder_output_dim != self._dec_layers * self._decoder_output_dim:
            final_encoder_output = self.bridge(final_encoder_output)
        initial_decoder_input = final_encoder_output.view(-1, self._dec_layers,
                                                          self._decoder_output_dim) \
                                                          .contiguous()
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = initial_decoder_input
        state["decoder_output"] = initial_decoder_input[:, 0]
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = encoder_outputs.new_zeros(
            batch_size, self._dec_layers, self._decoder_output_dim)
        state["attention"] = None
        if isinstance(self._attention, StatefulAttention):
            state["att_keys"], state["att_values"] = \
                self._attention.init_state(encoder_outputs)

        return state
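    # Note (added for clarity): the decoder hidden/context states are stored as
    # (batch_size, dec_layers, decoder_output_dim) so that ``BeamSearch`` can reorder
    # them along the batch dimension; ``_prepare_output_projections`` transposes them
    # to the (num_layers, batch, hidden) layout that ``nn.LSTM`` expects.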

    def _forward_loop(
        self,
        state: Dict[str, torch.Tensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        candidates = None

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens[self._target_namespace]

            _, target_sequence_length = targets.size()

            if self._loss is not None:
                candidates = target_to_candidates(targets,
                                                  self._num_classes,
                                                  ignore_indices=[
                                                      self._pad_index,
                                                      self._start_index,
                                                      self._end_index
                                                  ])

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            if isinstance(self._loss, EDOCDLoss):
                num_decoding_steps = int(
                    target_sequence_length * self._max_decoding_ratio) - 1
            else:
                num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        step_attns: List[torch.Tensor] = []
        step_attn_cxts: List[torch.Tensor] = []

        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                # during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif self._loss is not None:
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(
                input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # list of tensors, shape: (batch_size, 1, num_encoding_steps)
            if self._attention:
                step_attns.append(state["attention"].unsqueeze(1))
                step_attn_cxts.append(state["attention_contexts"].unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)

            predicted_classes = maybe_sample_from_candidates(
                class_probabilities,
                candidates=(candidates if self._from_candidates else None),
                strategy=self._sampling_strategy)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {
            "predictions": predictions,
        }

        # shape: (batch_size, num_decoding_steps, num_encoding_steps)
        if self._attention:
            output_dict["attentions"] = torch.cat(step_attns, 1)
            output_dict["attention_contexts"] = torch.cat(step_attn_cxts, 1)

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, predictions, targets, target_mask,
                                  candidates)

            output_dict["att_loss"] = loss

            if self._latency_penalty > 0.0:
                DAL = differentiable_average_lagging(output_dict["attentions"],
                                                     source_mask,
                                                     target_mask[:, 1:])
                output_dict["dal"] = DAL

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, decoder_output_dim)
        decoder_output = state["decoder_output"]

        attention = state["attention"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        # shape: (group_size, decoder_output_dim + target_embedding_dim)
        decoder_input = torch.cat((embedded_input, decoder_output), -1)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        outputs, (decoder_hidden, decoder_context) = self._decoder(
            decoder_input.unsqueeze(1),
            (decoder_hidden.transpose(1, 0).contiguous(),
             decoder_context.transpose(1, 0).contiguous()))

        decoder_hidden = decoder_hidden.transpose(1, 0).contiguous()
        decoder_context = decoder_context.transpose(1, 0).contiguous()
        outputs = outputs.squeeze(1)
        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_output, attention = self._prepare_attended_output(
                outputs, state)

            # shape: (group_size, decoder_output_dim)
            decoder_output = torch.tanh(
                self.att_out(torch.cat((attended_output, outputs), -1)))
            state["attention"] = attention
            state["attention_contexts"] = attended_output

        else:
            # shape: (group_size, target_embedding_dim)
            decoder_output = outputs

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context
        state["decoder_output"] = decoder_output

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_output)

        return output_projections, state

    def _prepare_attended_output(
            self, decoder_hidden_state: torch.Tensor,
            state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention over encoder outputs and decoder state."""
        # The mask may need to be a FloatTensor, otherwise the multiplication inside
        # some attention modules will complain.
        # shape (source_mask): (batch_size, max_input_sequence_length)

        encoder_outputs = state["encoder_outputs"]
        source_mask = state["source_mask"]
        prev_attention = state["attention"]
        att_keys = state.get("att_keys")
        att_values = state.get("att_values")

        # shape: (batch_size, max_input_sequence_length)
        mode = "soft" if self.training else "hard"
        if isinstance(self._attention, MonotonicAttention):
            encoder_outs: Dict[str, torch.Tensor] = {
                "value": state["encoder_outputs"],
                "mask": state["source_mask"]
            }

            monotonic_attention, chunk_attention = self._attention(
                encoder_outs, decoder_hidden_state, prev_attention, mode=mode)
            # shape: (batch_size, encoder_output_dim)
            attended_output = util.weighted_sum(encoder_outputs,
                                                chunk_attention)
            attention = monotonic_attention
        elif isinstance(self._attention, StatefulAttention):
            attended_output, attention = self._attention(
                decoder_hidden_state, att_keys, att_values, source_mask)
        else:
            attention = self._attention(decoder_hidden_state, source_mask)
            attended_output = util.weighted_sum(encoder_outputs, attention)

        return attended_output, attention

    # @staticmethod
    def _get_loss(self,
                  logits: torch.FloatTensor,
                  predictions: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor,
                  candidates: torch.LongTensor = None) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        if self._loss is not None:
            if isinstance(self._loss, OCDLoss) or isinstance(
                    self._loss, EDOCDLoss):
                self._loss.update_temperature(self._epoch_num)

            if isinstance(self._loss, EDOCDLoss):
                log_probs = F.log_softmax(logits, dim=-1)
                return self._loss(log_probs, predictions, relevant_targets,
                                  relevant_mask)
            else:
                raise NotImplementedError

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    def _collect_losses(self,
                        output_dict: Dict[str, torch.Tensor],
                        phn_ctc: float = 1.0,
                        ctc: float = 1.0,
                        rnnt: float = 1.0,
                        att: float = 1.0,
                        dep: float = 1.0,
                        pos: float = 1.0,
                        dal: float = 1.0) -> torch.Tensor:
        loss = 0.0
        if "phn_ctc_loss" in output_dict:
            loss += phn_ctc * output_dict["phn_ctc_loss"]
        if "ctc_loss" in output_dict:
            loss += ctc * output_dict["ctc_loss"]
        if "rnnt_loss" in output_dict:
            loss += rnnt * output_dict["rnnt_loss"]
        if "att_loss" in output_dict:
            loss += att * output_dict["att_loss"]
        if "dep_loss" in output_dict:
            loss += dep * output_dict["dep_loss"]
        if "pos_loss" in output_dict:
            loss += pos * output_dict["pos_loss"]
        if "dal" in output_dict:
            loss += dal * output_dict["dal"]

        output_dict["loss"] = loss
        return output_dict
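    # Worked example (hypothetical ratios, added for illustration): calling
    # ``self._collect_losses(output_dict, ctc=0.3, att=0.7)`` on an ``output_dict``
    # that contains only ``ctc_loss`` and ``att_loss`` yields
    # ``loss = 0.3 * ctc_loss + 0.7 * att_loss``; any term missing from
    # ``output_dict`` simply contributes nothing to the total.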

    def _update_metrics(self, output_dict: Dict[str, torch.Tensor]) -> None:
        for key, track_func in self._logs.items():
            try:
                value = output_dict[key]
                value = value.item() if isinstance(value,
                                                   torch.Tensor) else value
                track_func(value)
            except KeyError:
                continue

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:

        all_metrics: Dict[str, float] = {}
        for key, metric_tracker in self._logs.items():
            if "phn" in key and self._phn_target_namespace not in self._target_granularity:
                continue
            if "att" in key and self._target_namespace not in self._target_granularity:
                continue
            if metric_tracker is not None:
                metric_values = metric_tracker.get_metric(reset=reset)
                if isinstance(metric_values, dict):
                    all_metrics.update(metric_values)
                else:
                    all_metrics[key] = metric_values
        if self._ctc_layer:
            all_metrics.update({
                f"ctc_{key}": value
                for key, value in self._ctc_layer.get_metrics(
                    reset=reset).items()
            })
        if self._rnnt_layer:
            all_metrics.update({
                f"rnnt_{key}": value
                for key, value in self._rnnt_layer.get_metrics(
                    reset=reset).items()
            })

        if not self.training:
            if self._dep_parser:
                all_metrics.update(self._dep_parser.get_metrics(reset=reset))
            if self._pos_tagger:
                all_metrics.update(self._pos_tagger.get_metrics(reset=reset))
        return all_metrics

    def set_output_layer_num(self):
        output_layer_num = self._num_layers
        if self._layerwise_pretraining is not None:
            for epoch, layer_num in self._layerwise_pretraining:
                if self._epoch_num < epoch:
                    break
                output_layer_num = layer_num
        self._output_layer_num = output_layer_num
        return output_layer_num
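    # Example schedule (hypothetical values, added for illustration): with
    # ``layerwise_pretraining = [(0, 2), (5, 4), (10, 6)]`` the model uses the first
    # 2 encoder layers until epoch 5, 4 layers from epoch 5, and all 6 layers from
    # epoch 10 onwards, because the loop keeps the last entry whose epoch is not
    # greater than ``self._epoch_num``.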
Example #12
class Seq2SeqClaimRank(Model):
    """
    A ``Seq2SeqClaimRank`` model. This model is intended to be trained with a multi-instance
    learning objective that simultaneously tries to:
        - Decode the given post modifier (e.g. the ``target`` sequence).
        - Ensure that the model is attending to the proper claims during decoding (which are
        identified by the ``labels`` variable).
    The basic architecture is a seq2seq model with attention where the input sequence is the source
    sentence (without post-modifier), and the output sequence is the post-modifier. The main
    difference is that instead of performing attention over the input sequence, attention is
    performed over a collection of claims.

    Parameters
    ----------
    text_field_embedder : ``TextFieldEmbedder``
        Embeds words in the source sentence / claims.
    sentence_encoder : ``Seq2VecEncoder``
        Encodes the entire source sentence into a single vector.
    claim_encoder : ``Seq2SeqEncoder``
        Encodes each claim into a single vector.
    attention : ``Attention``
        Type of attention mechanism used.
        WARNING: Do not normalize attention scores, and make sure to use a
        sigmoid activation. Otherwise the claim ranking loss will not work
        properly!
    max_steps : ``int``
        Maximum number of decoding steps. Default: 100 (same as ONMT).
    beam_size : ``int``
        Beam size used during evaluation. Default: 5 (same as ONMT).
    beta : ``float``
        Weight of attention loss term.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sentence_encoder: Seq2VecEncoder,
                 claim_encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_steps: int = 100,
                 beam_size: int = 5,
                 beta: float = 1.0) -> None:
        super(Seq2SeqClaimRank, self).__init__(vocab)

        self.text_field_embedder = text_field_embedder
        self.sentence_encoder = sentence_encoder
        self.claim_encoder = TimeDistributed(claim_encoder)  # Handles additional sequence dim
        self.claim_encoder_dim = claim_encoder.get_output_dim()
        self.attention = attention
        self.decoder_embedding_dim = text_field_embedder.get_output_dim()
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.beta = beta

        # self.target_embedder = torch.nn.Embedding(vocab.get_vocab_size(), decoder_embedding_dim)

        # Since we are using the sentence encoding as the initial hidden state to the decoder, the
        # decoder hidden dim must match the sentence encoder hidden dim.
        self.decoder_output_dim = sentence_encoder.get_output_dim()
        self.decoder_0_cell = torch.nn.LSTMCell(self.decoder_embedding_dim + self.claim_encoder_dim,
                                                self.decoder_output_dim)
        self.decoder_1_cell = torch.nn.LSTMCell(self.decoder_output_dim,
                                                self.decoder_output_dim)

        # When projecting out we will use attention to combine claim embeddings into a single
        # context embedding, this will be concatenated with the decoder cell output before being
        # fed to the projection layer. Hence the expected input size is:
        #   decoder output dim + claim encoder output dim
        projection_input_dim = self.decoder_output_dim + self.claim_encoder_dim
        self.output_projection_layer = torch.nn.Linear(projection_input_dim,
                                                       vocab.get_vocab_size())

        self._start_index = self.vocab.get_token_index('<s>')
        self._end_index = self.vocab.get_token_index('</s>')

        self.beam_search = BeamSearch(self._end_index, max_steps=max_steps, beam_size=beam_size)
        pad_index = vocab.get_token_index(vocab._padding_token)
        self.bleu = BLEU(exclude_indices={pad_index, self._start_index, self._end_index})
        self.avg_reconstruction_loss = Average()
        self.avg_claim_scoring_loss = Average()

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        output_projections, _, state = self._prepare_output_projections(last_predictions, state)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)
        return class_log_probabilities, state

    @overrides
    def forward(self,
                inputs: Dict[str, torch.LongTensor],
                claims: Dict[str, torch.LongTensor],
                targets: Dict[str, torch.LongTensor] = None,
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Forward pass of the model + decoder logic.

        Parameters
        ----------
        inputs : ``Dict[str, torch.LongTensor]``
            Output of `TextField.as_array()` from the `input` field.
        claims : ``Dict[str, torch.LongTensor]``
            Output of `ListField.as_array()` from the `claims` field.
        targets : ``Dict[str, torch.LongTensor]``
            Output of `TextField.as_array()` from the `target` field.
            Only expected during training and validation.
        labels : ``torch.Tensor``
            Output of `LabelField.as_array()` from the `labels` field, indicating which claims were
            used.
            Only expected during training and validation.

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary containing loss tensor and decoder outputs.
        """

        # Obtain an encoding for each input sentence (e.g. the contexts)
        input_mask = util.get_text_field_mask(inputs)
        input_word_embeddings = self.text_field_embedder(inputs)
        input_encodings = self.sentence_encoder(input_word_embeddings, input_mask)

        # Next we encode claims. Note that here we have two additional sequence dimensions (since
        # there are multiple claims per instance, and we want to apply attention at the word
        # level). To deal with this we need to set `num_wrapping_dims=1` for the embedder, and make
        # the claim encoder TimeDistributed.
        claim_mask = util.get_text_field_mask(claims, num_wrapping_dims=1)
        claim_word_embeddings = self.text_field_embedder(claims, num_wrapping_dims=1)
        claim_encodings = self.claim_encoder(claim_word_embeddings, claim_mask)

        # Package the encoder outputs into a state dictionary.
        state = {
            'input_mask': input_mask,
            'input_encodings': input_encodings,
            'claim_mask': claim_mask,
            'claim_encodings': claim_encodings
        }

        # If ``target`` (the post-modifier) and ``labels`` (indicator of which claims are used) are
        # provided then we use them to compute loss.
        if (targets is not None) and (labels is not None):
            state = self._init_decoder_state(state)
            output_dict = self._forward_loop(state, targets, labels)
        else:
            output_dict = {}

        # If model is not training, then we perform beam search for decoding to obtain higher
        # quality outputs.
        if not self.training:
            # Perform beam search
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            # Compute BLEU against the gold targets, when they are available.
            if targets is not None:
                top_k_predictions = output_dict['predictions']
                best_predictions = top_k_predictions[:, 0, :]
                self.bleu(best_predictions, targets['tokens'])

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.
        """
        predicted_indices = output_dict['predictions']
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x) for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _init_decoder_state(self,
                            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Adds fields to the state required to initialize the decoder."""

        batch_size = state['input_mask'].shape[0]

        # The first decoder layer starts from a zero state (roughly following the
        # structure in OpenNMT's diagram).
        state['decoder_0_h'] = state['input_encodings'].new_zeros(batch_size, self.decoder_output_dim)
        state['decoder_0_c'] = state['input_encodings'].new_zeros(batch_size, self.decoder_output_dim)

        # Initialize LSTM hidden state (e.g. h_0) with output of the sentence encoder.
        state['decoder_1_h'] = state['input_encodings']
        # Initialize LSTM context state (e.g. c_0) with zeros.
        state['decoder_1_c'] = state['input_encodings'].new_zeros(batch_size, self.decoder_output_dim)
        # Initialize previous context.
        state['prev_context'] = state['input_encodings'].new_zeros(batch_size, self.claim_encoder_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      targets: Dict[str, torch.Tensor],
                      labels: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute loss using greedy decoding."""
        batch_size = state['input_mask'].shape[0]
        target_tokens = targets['tokens']
        num_decoding_steps = target_tokens.shape[1] - 1

        # Greedy decoding phase
        output_logit_list = []
        attention_logit_list = []
        select_idx_list = []
        for timestep in range(num_decoding_steps):
            # Feed target sequence as input
            decoder_input = target_tokens[:, timestep]
            output_logits, attention_logits, state = self._prepare_output_projections(decoder_input, state)
            # Store output and attention logits
            output_logit_list.append(output_logits.unsqueeze(1))
            attention_logit_list.append(attention_logits.unsqueeze(1))

        # Compute reconstruction loss
        output_logit_tensor = torch.cat(output_logit_list, dim=1)
        relevant_target_tokens = target_tokens[:, 1:].contiguous()
        target_mask = util.get_text_field_mask(targets)[:, 1:].contiguous()
        reconstruction_loss = util.sequence_cross_entropy_with_logits(output_logit_tensor,
                                                                      relevant_target_tokens,
                                                                      target_mask)

        # Compute claim scoring loss. A loss is computed between **each** attention vector and the
        # true label. In order for that to work we need to:
        #   a. Tile the source labels (so that they are copied for each word)
        #   b. Mask out padding tokens - this requires taking the outer-product of the target mask
        #       and the claim mask
        attention_logit_tensor = torch.cat(attention_logit_list, dim=1)
        claim_level_mask = (state['claim_mask'].sum(-1) > 0).long()
        attention_mask = target_mask.unsqueeze(-1) * claim_level_mask.unsqueeze(1)
        labels = labels.unsqueeze(1).repeat(1, num_decoding_steps, 1).float()
        claim_scoring_loss = F.binary_cross_entropy_with_logits(attention_logit_tensor, labels, reduction='none')
        claim_scoring_loss *= attention_mask.float()  # Apply mask

        # We want to apply 'batch' reduction (as is done in
        # ``sequence_cross_entropy_with_logits``), which entails averaging over each dimension.
        denom = attention_mask
        for i in range(3):
            denom = denom.sum(-1)
            claim_scoring_loss =  claim_scoring_loss.sum(-1) / (denom.float() + 1e-13)
            denom = (denom > 0)

        total_loss = reconstruction_loss + self.beta * claim_scoring_loss

        # Update metrics
        self.avg_reconstruction_loss(reconstruction_loss)
        self.avg_claim_scoring_loss(claim_scoring_loss)

        output_dict =  {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "claim_scoring_loss": claim_scoring_loss,
            "attention_logits": attention_logit_tensor
        }

        return output_dict

    def _prepare_output_projections(self,
                                    decoder_input: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        # Embed decoder input
        decoder_word_embeddings = self.text_field_embedder({'tokens': decoder_input})

        # Concat with previous context
        concat = torch.cat((decoder_word_embeddings, state['prev_context']), dim=-1)

        # Run forward pass of decoder RNN
        decoder_0_h, decoder_0_c = self.decoder_0_cell(concat, (state['decoder_0_h'], state['decoder_0_c']))
        decoder_1_h, decoder_1_c = self.decoder_1_cell(decoder_0_h, (state['decoder_1_h'], state['decoder_1_c']))
        state['decoder_0_h'] = decoder_0_h
        state['decoder_0_c'] = decoder_0_c
        state['decoder_1_h'] = decoder_1_h
        state['decoder_1_c'] = decoder_1_c

        # Compute attention and get context embedding. We get an attention score for each word in
        # each claim. Then we sum up scores to get a claim level score (so we can use overlap as
        # supervision).
        claim_encodings = state['claim_encodings']
        claim_mask = state['claim_mask']
        batch_size, n_claims, claim_length, dim = claim_encodings.shape

        flattened_claim_encodings = claim_encodings.view(batch_size, -1, dim)
        flattened_claim_mask = claim_mask.view(batch_size, -1)
        flattened_attention_logits = self.attention(decoder_1_h, flattened_claim_encodings, flattened_claim_mask)
        attention_logits = flattened_attention_logits.view(batch_size, n_claims, claim_length)

        # Now get claim level encodings by summing word level attention.
        word_level_attention = util.masked_softmax(attention_logits, claim_mask)
        claim_encodings = util.weighted_sum(claim_encodings, word_level_attention)

        # If not training, get max attention word to replace unk
        if not self.training:
            max_word = word_level_attention.argmax(dim=-1, keepdim=True)
            gathered = word_level_attention.gather(dim=-1, index=max_word)
            max_claim = gathered.squeeze().argmax(dim=-1, keepdim=True)
            max_word = max_word.squeeze().gather(dim=1, index=max_claim)
            select_idx = torch.cat((max_claim, max_word), dim=-1)
        else:
            select_idx = None

        # We compute our context directly from the claim word embeddings
        claim_mask = (claim_mask.sum(-1) > 0).float()
        attention_logits = attention_logits.sum(-1)
        attention_weights = torch.sigmoid(attention_logits) * claim_mask
        normalized_attention_weights = attention_weights / (attention_weights.sum(-1, True) + 1e-13)
        context_embedding = util.weighted_sum(claim_encodings, normalized_attention_weights)
        state['prev_context'] = context_embedding

        # Concatenate RNN output w/ context vector and feed through final hidden layer
        projection_input = torch.cat((decoder_1_h, context_embedding), dim=-1)
        output_logits = self.output_projection_layer(projection_input)

        return output_logits, attention_logits, state

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state['input_mask'].size()[0]
        start_predictions = state['input_mask'].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self.beam_search.search(
                start_predictions, state, self.take_step)

        output_dict = {
                "class_log_probabilities": log_probabilities,
                "predictions": all_top_k_predictions,
        }
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {
            'recon': self.avg_reconstruction_loss.get_metric(reset=reset).data.item(),
            'claim': self.avg_claim_scoring_loss.get_metric(reset=reset).data.item()
        }
        # Only update BLEU score during validation and evaluation
        if not self.training:
            all_metrics.update(self.bleu.get_metric(reset=reset))
        return all_metrics
class CustomAutoRegressiveSeqDecoder(SeqDecoder):
    def __init__(
        self,
        vocab: Vocabulary,
        decoder_net: DecoderNet,
        max_decoding_steps: int,
        target_embedder: Embedding,
        target_namespace: str = "tokens",
        tie_output_embedding: bool = False,
        scheduled_sampling_ratio: float = 0,
        label_smoothing_ratio: Optional[float] = None,
        beam_size: int = 4,
        tensor_based_metric: Metric = None,
        token_based_metric: Metric = None,
    ) -> None:
        super().__init__(target_embedder)

        self._vocab = vocab

        self._decoder_net = decoder_net
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace
        self._label_smoothing_ratio = label_smoothing_ratio

        self._start_index = self._vocab.get_token_index(
            START_SYMBOL, self._target_namespace)
        self._end_index = self._vocab.get_token_index(END_SYMBOL,
                                                      self._target_namespace)
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        target_vocab_size = self._vocab.get_vocab_size(self._target_namespace)

        if self.target_embedder.get_output_dim(
        ) != self._decoder_net.target_embedding_dim:
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input."
            )

        self._output_projection_layer = Linear(
            self._decoder_net.get_output_dim(), target_vocab_size)

        if tie_output_embedding:
            if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape:
                raise ConfigurationError(
                    "Can't tie embeddings with output linear layer, due to shape mismatch"
                )
            self._output_projection_layer.weight = self.target_embedder.weight

        self._tensor_based_metric = tensor_based_metric
        self._token_based_metric = token_based_metric
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
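    # Note (added for clarity): ``tie_output_embedding`` reuses the target embedding
    # matrix as the output projection weights, which is only possible when the
    # decoder output dimension equals the target embedding dimension and the
    # projection covers the same vocabulary (hence the shape check above).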

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _forward_loss(
            self, state: Dict[str, torch.Tensor],
            target_tokens: Dict[str,
                                torch.LongTensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = target_tokens["tokens"]

        # Prepare embeddings for targets. They will be used as gold embeddings during decoder training
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self.target_embedder(targets)

        # shape: (batch_size, max_target_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        if self._scheduled_sampling_ratio == 0 and self._decoder_net.decodes_parallel:
            _, decoder_output = self._decoder_net(
                previous_state=state,
                previous_steps_predictions=target_embedding[:, :-1, :],
                encoder_outputs=encoder_outputs,
                source_mask=source_mask,
                previous_steps_mask=target_mask[:, :-1])

            # shape: (group_size, max_target_sequence_length, num_classes)
            logits = self._output_projection_layer(decoder_output)
        else:
            batch_size = source_mask.size()[0]
            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1

            # Initialize target predictions with the start index.
            # shape: (batch_size,)
            last_predictions = source_mask.new_full(
                (batch_size, ), fill_value=self._start_index)

            # shape: (steps, batch_size, target_embedding_dim)
            steps_embeddings = torch.Tensor([])

            step_logits: List[torch.Tensor] = []

            for timestep in range(num_decoding_steps):
                if self.training and torch.rand(
                        1).item() < self._scheduled_sampling_ratio:
                    # During training, with probability ``_scheduled_sampling_ratio`` we condition
                    # on the model's own previous predictions instead of the gold tokens.
                    # shape: (batch_size, steps, target_embedding_dim)
                    state['previous_steps_predictions'] = steps_embeddings

                    # shape: (batch_size, )
                    effective_last_prediction = last_predictions
                else:
                    # shape: (batch_size, )
                    effective_last_prediction = targets[:, timestep]

                    if timestep == 0:
                        state['previous_steps_predictions'] = torch.Tensor([])
                    else:
                        # shape: (batch_size, steps, target_embedding_dim)
                        state['previous_steps_predictions'] = \
                            target_embedding[:, :timestep]

                # shape: (batch_size, num_classes)
                output_projections, state = self._prepare_output_projections(
                    effective_last_prediction, state)

                # list of tensors, shape: (batch_size, 1, num_classes)
                step_logits.append(output_projections.unsqueeze(1))

                # shape (predicted_classes): (batch_size,)
                _, predicted_classes = torch.max(output_projections, 1)

                # shape (predicted_classes): (batch_size,)
                last_predictions = predicted_classes

                # shape: (batch_size, 1, target_embedding_dim)
                last_predictions_embeddings = self.target_embedder(
                    last_predictions).unsqueeze(1)

                # This step is required because we keep two different prediction histories: gold and real.
                if steps_embeddings.shape[-1] == 0:  # pylint: disable=unsubscriptable-object
                    # There are no previous steps yet, except for the start vectors in ``last_predictions``
                    # shape: (group_size, 1, target_embedding_dim)
                    steps_embeddings = last_predictions_embeddings
                else:
                    # shape: (group_size, steps_count, target_embedding_dim)
                    steps_embeddings = torch.cat(
                        [steps_embeddings, last_predictions_embeddings], 1)

            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

        # Compute loss.
        target_mask = util.get_text_field_mask(target_tokens)
        loss = self._get_loss(logits, targets, target_mask)

        # TODO: We will be using beam search to get predictions for validation, but if beam size is 1
        # we could consider taking the last_predictions here, building step_predictions,
        # and using those instead of running beam search again, if validation performance takes a hit.
        output_dict = {'loss': loss}

        return output_dict

    def _prepare_output_projections(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode the current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, steps_count, decoder_output_dim)
        previous_steps_predictions = state.get("previous_steps_predictions")

        # shape: (batch_size, 1, target_embedding_dim)
        last_predictions_embeddings = self.target_embedder(
            last_predictions).unsqueeze(1)

        if (previous_steps_predictions is None
                or previous_steps_predictions.shape[-1] == 0):
            # There are no previous steps yet, except for the start vectors in ``last_predictions``
            # shape: (group_size, 1, target_embedding_dim)
            previous_steps_predictions = last_predictions_embeddings
        else:
            # shape: (group_size, steps_count, target_embedding_dim)
            previous_steps_predictions = torch.cat(
                [previous_steps_predictions, last_predictions_embeddings], 1)

        decoder_state, decoder_output = self._decoder_net(
            previous_state=state,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_predictions=previous_steps_predictions)
        state["previous_steps_predictions"] = previous_steps_predictions

        # Update state with new decoder state, override previous state
        state.update(decoder_state)

        if self._decoder_net.decodes_parallel:
            decoder_output = decoder_output[:, -1, :]

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_output)

        return output_projections, state

    def _get_loss(self, logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(
            logits,
            relevant_targets,
            relevant_mask,
            label_smoothing=self._label_smoothing_ratio)

    def get_output_dim(self):
        return self._decoder_net.get_output_dim()

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(
                        reset=reset))  # type: ignore
            if self._token_based_metric is not None:
                all_metrics.update(
                    self._token_based_metric.get_metric(
                        reset=reset))  # type: ignore
        return all_metrics

    @overrides
    def forward(
        self,
        encoder_out: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        state = encoder_out
        decoder_init_state = self._decoder_net.init_decoder_state(state)
        state.update(decoder_init_state)

        output_dict = self._forward_loss(
            state, target_tokens) if target_tokens else {}

        if not self.training:
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

            if target_tokens:
                if self._tensor_based_metric is not None:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    # shape: (batch_size, target_sequence_length)

                    self._tensor_based_metric(
                        best_predictions,
                        target_tokens["tokens"])  # type: ignore

                if self._token_based_metric is not None:
                    output_dict = self.post_process(output_dict)
                    predicted_tokens = output_dict['predicted_tokens']

                    self._token_based_metric(
                        predicted_tokens,  # type: ignore
                        [y.text for y in target_tokens["tokens"][1:-1]])

        return output_dict

    @overrides
    def post_process(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self._vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict
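
# A minimal, self-contained sketch (assuming an AllenNLP install) of the contract that
# _forward_beam_search() above relies on: BeamSearch.search() repeatedly calls a step
# function (a toy stand-in for take_step here) that maps the last predictions and a
# state dict to per-class log-probabilities and an updated state. The transition matrix
# and state contents below are illustrative only.
import torch
import torch.nn.functional as F
from allennlp.nn.beam_search import BeamSearch

NUM_CLASSES, END_INDEX = 5, 4
transition = torch.randn(NUM_CLASSES, NUM_CLASSES)  # toy "decoder": fixed transition scores

def toy_step(last_predictions, state):
    # shape: (group_size, num_classes)
    log_probs = F.log_softmax(transition[last_predictions], dim=-1)
    return log_probs, state

beam = BeamSearch(END_INDEX, max_steps=10, beam_size=3)
start_predictions = torch.zeros(2, dtype=torch.long)  # (batch_size,), all start at index 0
start_state = {"dummy": torch.zeros(2, 1)}            # per-example tensors get tiled per beam
top_k_predictions, log_probabilities = beam.search(start_predictions, start_state, toy_step)
# top_k_predictions: (batch_size, beam_size, num_steps); log_probabilities: (batch_size, beam_size)
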
Example #14
class Editor(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        embed: TextFieldEmbedder,
        encoder_size: int,
        decoder_size: int,
        num_layers: int,
        beam_size: int,
        max_decoding_steps: int,
        use_bleu: bool = True,
        initializer: InitializerApplicator = InitializerApplicator()
    ) -> None:
        super().__init__(vocab)

        self.START, self.END = self.vocab.get_token_index(
            START_SYMBOL), self.vocab.get_token_index(END_SYMBOL)
        self.OOV = self.vocab.get_token_index(self.vocab._oov_token)  # pylint: disable=protected-access
        self.PAD = self.vocab.get_token_index(self.vocab._padding_token)  # pylint: disable=protected-access
        self.COPY = self.vocab.get_token_index("@@COPY@@")
        self.KEEP = self.vocab.get_token_index("@@KEEP@@")
        self.DROP = self.vocab.get_token_index("@@DROP@@")

        self.SYMBOL = (self.START, self.END, self.PAD, self.KEEP, self.DROP)
        self.vocab_size = vocab.get_vocab_size()
        self.EMB = embed

        self.emb_size = self.EMB.token_embedder_tokens.output_dim
        self.encoder_size, self.decoder_size = encoder_size, decoder_size
        self.FACT_ENCODER = FeedForward(3 * self.emb_size, 1, encoder_size,
                                        nn.Tanh())
        self.ATTN = AdditiveAttention(encoder_size + decoder_size,
                                      encoder_size)
        self.COPY_ATTN = AdditiveAttention(decoder_size, encoder_size)
        module = nn.LSTM(self.emb_size,
                         encoder_size // 2,
                         num_layers,
                         bidirectional=True,
                         batch_first=True)
        self.BUFFER = PytorchSeq2SeqWrapper(
            module)  # BiLSTM to encode draft text
        self.STREAM = nn.LSTMCell(2 * encoder_size,
                                  decoder_size)  # Store revised text

        self.BEAM = BeamSearch(self.END,
                               max_steps=max_decoding_steps,
                               beam_size=beam_size)

        self.U = nn.Sequential(nn.Linear(2 * encoder_size, decoder_size),
                               nn.Tanh())
        self.ADD = nn.Sequential(nn.Linear(self.emb_size, encoder_size),
                                 nn.Tanh())

        self.P = nn.Sequential(
            nn.Linear(encoder_size + decoder_size, decoder_size), nn.Tanh())
        self.W = nn.Linear(decoder_size, self.vocab_size)
        self.G = nn.Sequential(nn.Linear(decoder_size, 1), nn.Sigmoid())

        initializer(self)
        self._bleu = BLEU(
            exclude_indices=set(self.SYMBOL)) if use_bleu else None

    @overrides
    def forward(
            self,  # type: ignore
            metadata: List[Dict[str, Any]],
            triple_tokens: Dict[str, torch.LongTensor],
            triple_token_ids: torch.Tensor,
            predicate_tokens: Dict[str, torch.Tensor],
            draft_tokens: Dict[str, torch.LongTensor],
            action_tokens: Dict[str, torch.LongTensor] = None,
            revised_tokens: Dict[str, torch.LongTensor] = None,
            action_token_ids: torch.Tensor = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        state = self._init_state(triple_tokens, predicate_tokens, draft_tokens,
                                 triple_token_ids)
        if action_tokens:
            # Initialize Decoder
            state = self._decoder_init(state)
            output_dict = self._forward_loss(action_tokens, action_token_ids,
                                             state, **kwargs)
        else:
            output_dict = {}
        output_dict["metadata"] = metadata

        if not self.training:
            # Re-initialize decoder
            state = self._decoder_init(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

            if revised_tokens and self._bleu:
                top_k_predictions = output_dict["predictions"]
                best_actions = top_k_predictions[:, 0]
                best_predictions = self._action_to_token(
                    best_actions, draft_tokens["tokens"])
                gold_tokens = self._extend_gold_tokens(
                    revised_tokens["tokens"], action_tokens["tokens"],
                    triple_token_ids, action_token_ids)
                self._bleu(best_predictions, gold_tokens)

        return output_dict

    def _extend_gold_tokens(self, revised_tokens: torch.Tensor,
                            action_tokens: torch.Tensor,
                            triple_token_ids: torch.Tensor,
                            action_token_ids: torch.Tensor):
        batch_size, action_length = action_tokens.size()
        triple_size = triple_token_ids.size(1)
        expanded_triple_ids = triple_token_ids.unsqueeze(1).expand(
            batch_size, action_length, triple_size)
        expanded_revised_ids = action_token_ids.unsqueeze(-1).expand(
            batch_size, action_length, triple_size)
        match = expanded_triple_ids == expanded_revised_ids
        copied = match.sum(-1) > 0
        oov = action_tokens == self.OOV
        mask = (oov & copied).long()

        first_match = ((match.cumsum(-1) == 1) * match).byte().argmax(-1)
        new_action_tokens = action_tokens * (1 - mask) + (
            first_match.long() + self.vocab_size) * mask

        increment_mask = ~(new_action_tokens == self.DROP)
        pointer = revised_tokens.new_zeros((revised_tokens.size(0), ))
        end_point = ((revised_tokens != 0).sum(dim=1) - 1)

        for i in range(action_length):
            act_step, mask_step = new_action_tokens[:, i], mask[:, i].bool()
            revised_tokens[mask_step.nonzero().squeeze(1),
                           pointer[mask_step]] = act_step[mask_step]
            pointer[increment_mask[:, i]] += 1
            pointer = torch.min(pointer, end_point)
        return revised_tokens

    def _action_to_token(self, action_tokens: torch.LongTensor,
                         draft_tokens: torch.LongTensor) -> torch.LongTensor:
        predicted_pointer = action_tokens.new_zeros((draft_tokens.size(0), 1))
        draft_pointer = draft_tokens.new_ones((draft_tokens.size(0), 1))

        predicted_tokens = action_tokens.new_full((action_tokens.size()),
                                                  self.END)

        for act_step in action_tokens.t():
            # KEEP, DELETE, COPY, ADD (other)
            keep_mask = act_step == self.KEEP
            drop_mask = act_step == self.DROP
            add_mask = ~(keep_mask | drop_mask)

            predicted_tokens.scatter_(1, predicted_pointer,
                                      draft_tokens.gather(1, draft_pointer))
            predicted_tokens[add_mask] = predicted_tokens[add_mask].scatter(
                1, predicted_pointer[add_mask],
                act_step[add_mask].unsqueeze(1))

            draft_pointer[keep_mask | drop_mask] += 1
            predicted_pointer[~drop_mask] += 1
        return predicted_tokens

    def _decoder_init(self, state):
        mean_draft = util.masked_mean(state["encoded_draft"],
                                      state["draft_mask"].unsqueeze(-1), 1)
        mean_triple = util.masked_mean(state["encoded_triple"],
                                       state["triple_mask"].unsqueeze(-1), 1)
        concatenated = torch.cat((mean_draft, mean_triple), dim=-1)
        batch_size = state["draft_mask"].size(0)

        zeros = mean_draft.new_zeros((batch_size, self.decoder_size))
        state["stream_hidden"], state["stream_context"] = self.U(
            concatenated), zeros
        state["draft_pointer"] = state["draft_mask"].new_ones((batch_size, ))

        action_mask = mean_draft.new_ones((batch_size, self.vocab_size))
        action_mask[:, self.PAD] = 0
        action_mask[:, self.END] = 0

        state["action_mask"] = action_mask

        return state

    def _init_state(self, triples: Dict[str, torch.LongTensor],
                    predicate: Dict[str, torch.LongTensor],
                    draft: Dict[str, torch.LongTensor],
                    triple_ids: torch.LongTensor) -> Dict[str, torch.Tensor]:
        emb_pred = util.masked_mean(
            self.EMB(predicate),
            util.get_text_field_mask(
                predicate,
                num_wrapping_dims=1,
            ).unsqueeze(-1), 2)
        emb_triple = self.EMB(triples)
        triple_mask = util.get_text_field_mask(triples)
        flat_triples = torch.cat((emb_triple.flatten(2, 3), emb_pred), dim=-1)

        encoded_triples = self.FACT_ENCODER(flat_triples)

        emb_draft = self.EMB(draft)
        draft_mask = util.get_text_field_mask(draft)
        end_point = (draft_mask.sum(dim=1) - 1)
        encoded_draft = self.BUFFER(emb_draft, draft_mask)

        return {
            "draft_mask": draft_mask,
            "triple_mask": triple_mask,
            "end_point": end_point,
            "encoded_triple": encoded_triples,
            "encoded_draft": encoded_draft,
            "triple_tokens": triples["tokens"][:, :, -1],
            "triple_token_ids": triple_ids
        }

    def _forward_loss(
            self, target_actions: Dict[str, torch.LongTensor],
            target_token_ids: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size, target_sequence_length = target_actions["tokens"].size()
        num_decoding_steps = target_sequence_length - 1

        target_to_triple = state["triple_mask"].new_zeros(
            state["triple_mask"].size()).bool()
        copy_input_choice = state["triple_mask"].new_full((batch_size, ),
                                                          self.COPY)

        step_log_likelihoods = []
        for t in range(num_decoding_steps):
            input_actions = target_actions["tokens"][:, t]
            if t < num_decoding_steps - 1:
                copied = ((target_to_triple.sum(dim=-1) > 0) &
                          (input_actions == self.OOV))
                target_to_triple = (state["triple_token_ids"] ==
                                    target_token_ids[:, t + 1].unsqueeze(-1))
                input_actions = (copied.long() * (copy_input_choice - input_actions) +
                                 input_actions)

            state = self._decoder_step(input_actions, state)
            step_target_actions = target_actions["tokens"][:, t + 1]
            step_log_likelihoods.append(
                self._get_log_likelihood(state, step_target_actions,
                                         target_to_triple))

        log_likelihoods = torch.stack(step_log_likelihoods, dim=-1)
        target_mask = util.get_text_field_mask(target_actions)
        target_mask = target_mask[:, 1:].float()

        log_likelihood = (log_likelihoods * target_mask).sum(dim=-1)
        loss = -log_likelihood.sum()
        loss /= batch_size

        return {"loss": loss}

    @staticmethod
    def _get_query(state: Dict[str, torch.Tensor]):
        batch_size = state["encoded_draft"].size(0)
        buffer_head = state["encoded_draft"][torch.arange(batch_size),
                                             state["draft_pointer"]]

        query = torch.cat([buffer_head, state["stream_hidden"]], dim=1)
        return query

    def _get_log_likelihood(self, state: Dict[str, torch.Tensor],
                            target_actions: torch.Tensor,
                            target_to_source: torch.Tensor) -> torch.Tensor:
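        # The step likelihood is a gate-weighted mixture of generating the action from the
        # vocabulary and copying it from the triple; the generation term is zeroed out when
        # the gold action is an OOV token that can only be produced by copying.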
        hidden = self.P(self._get_query(state))
        gate_prob = self.G(hidden).squeeze(1)

        gen_prob = util.masked_softmax(self.W(hidden), state["action_mask"], memory_efficient=True)\
            .gather(1, target_actions.unsqueeze(1)).squeeze(1)
        gen_mask = (target_actions != self.OOV) | (target_to_source.sum(dim=-1)
                                                   == 0)
        gen_prob = gen_prob.min(gen_mask.float())

        copy_prob = self.COPY_ATTN(hidden, state["encoded_triple"], state["triple_mask"])\
            .masked_fill(~target_to_source, 0.).sum(dim=-1)

        step_prob = gen_prob * gate_prob + copy_prob * (-gate_prob + 1)
        step_log_likelihood = step_prob.clamp(1e-30).log()

        return step_log_likelihood

    def _decoder_step(
            self, last_actions: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        embed_actions = self.EMB({"tokens": last_actions})
        batch_size = embed_actions.size(0)
        # Update stack given draft pointer information

        draft_head = state["encoded_draft"][torch.arange(batch_size),
                                            state["draft_pointer"]]
        query = torch.cat([state["stream_hidden"], draft_head], dim=1)
        attend = self.ATTN(query, state["encoded_triple"],
                           state["triple_mask"])
        attended_triple = util.weighted_sum(state["encoded_triple"], attend)

        is_added = torch.stack([last_actions != tok
                                for tok in self.SYMBOL]).all(dim=0)
        draft_head[is_added] = self.ADD(embed_actions[is_added])

        hs, cs = self.STREAM(torch.cat((draft_head, attended_triple), dim=-1),
                             (state["stream_hidden"], state["stream_context"]))
        drop_mask = (last_actions != self.DROP).unsqueeze(1).float()
        hx = drop_mask * hs + (-drop_mask + 1) * state["stream_hidden"]
        cx = drop_mask * cs + (-drop_mask + 1) * state["stream_context"]
        state["stream_hidden"], state["stream_context"] = hx, cx

        # Update Pointer
        move_forward = ((last_actions == self.KEEP) |
                        (last_actions == self.DROP)).long()

        state["draft_pointer"] = state["draft_pointer"] + move_forward
        # Simple masking for pointer
        state["draft_pointer"] = torch.min(state["draft_pointer"],
                                           state["end_point"])

        is_ended = state["end_point"] == state["draft_pointer"]
        state["action_mask"][is_ended, self.KEEP] = 0
        state["action_mask"][is_ended, self.DROP] = 0
        state["action_mask"][is_ended, self.END] = 1

        return state

    def take_search_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        input_choices = self._get_input(last_predictions)
        state = self._decoder_step(input_choices, state)
        final_prob = self._make_prob(state)

        return final_prob.clamp(1e-30).log(), state

    def _get_input(
        self,
        last_predictions: torch.Tensor,
    ) -> torch.Tensor:
        group_size, = last_predictions.size()
        only_copy_mask = (last_predictions >= self.vocab_size).long()
        copy_input_choices = only_copy_mask.new_full((group_size, ), self.COPY)
        input_choices = (copy_input_choices -
                         last_predictions) * only_copy_mask + last_predictions
        return input_choices

    def _make_prob(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:

        triple_token_ids = state["triple_token_ids"]
        batch_size, triple_length = triple_token_ids.size()

        hidden = self.P(self._get_query(state))

        gate_prob = self.G(hidden)
        gen_prob = util.masked_softmax(self.W(hidden),
                                       state["action_mask"],
                                       memory_efficient=True) * gate_prob

        copy_prob = self.COPY_ATTN(hidden, state["encoded_triple"],
                                   state["triple_mask"]) * (-gate_prob + 1)
        modified_prob_list: List[torch.Tensor] = []
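        # For each triple position: (1) add its copy probability to the generation
        # probability of the matching in-vocabulary token, (2) fold the copy probability
        # of later duplicate occurrences into the first one, and (3) keep only the
        # left-over copy probability of OOV tokens as extra "copy" classes appended
        # after the vocabulary distribution.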
        for i in range(triple_length):
            copy_prob_slice = copy_prob[:, i]
            token_slice = state["triple_tokens"][:, i]
            copy_to_add_mask = token_slice != self.OOV
            copy_to_add = copy_prob_slice.min(
                copy_to_add_mask.float()).unsqueeze(-1)
            gen_prob = gen_prob.scatter_add(-1, token_slice.unsqueeze(1),
                                            copy_to_add)

            if i < (triple_length - 1):
                future_occurrences = (triple_token_ids[:, i + 1:] ==
                                      triple_token_ids[:, i].unsqueeze(-1)).float()
                future_copy_prob = copy_prob[:, i + 1:].min(future_occurrences)
                copy_prob_slice += future_copy_prob.sum(-1)

            if i > 0:
                prev_occurrences = (triple_token_ids[:, :i] ==
                                    triple_token_ids[:, i].unsqueeze(-1))
                duplicate_mask = (prev_occurrences.sum(-1) == 0).float()
                copy_prob_slice = copy_prob_slice.min(duplicate_mask)

            left_over_copy_prob = copy_prob_slice.min(
                (~copy_to_add_mask).float())
            modified_prob_list.append(left_over_copy_prob.unsqueeze(-1))

        modified_prob_list.insert(0, gen_prob)
        modified_prob = torch.cat(modified_prob_list, dim=-1)
        return modified_prob

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["draft_mask"].size(0)
        start_predictions = state["draft_mask"].new_full((batch_size, ),
                                                         self.START)
        all_top_k_predictions, log_probabilities = self.BEAM.search(
            start_predictions, state, self.take_search_step)
        return {
            "predicted_log_probs": log_probabilities,
            "predictions": all_top_k_predictions
        }

    def _get_predicted_tokens(self,
                              predicted_indices: Union[torch.Tensor,
                                                       numpy.ndarray],
                              batch_metadata,
                              n_best: int = None):
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        predicted_tokens = []

        for top_k_predictions, metadata in zip(predicted_indices,
                                               batch_metadata):
            batch_predicted_tokens = []
            draft, triple = metadata['draft'], metadata["triple"]
            for indices in top_k_predictions[:n_best]:
                pointer, tokens = 0, []
                indices = list(indices)
                if self.END in indices:
                    indices = indices[:indices.index(self.END)]
                for index in indices:
                    if index == self.KEEP:
                        tokens.append(draft[pointer])
                        pointer += 1
                    elif index == self.DROP:
                        pointer += 1
                    elif index >= self.vocab_size:
                        adjusted_index = index - self.vocab_size
                        tokens.append(triple[adjusted_index])
                    else:
                        tokens.append(
                            str(self.vocab.get_token_from_index(index)))
                batch_predicted_tokens.append(tokens)
            if n_best == 1:
                predicted_tokens.append(batch_predicted_tokens[0])
            else:
                predicted_tokens.append(batch_predicted_tokens)
        return predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        predicted_tokens = self._get_predicted_tokens(
            output_dict["predictions"], output_dict["metadata"])
        output_dict["predicted_tokens"] = predicted_tokens
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
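
# A minimal pure-Python sketch of the action-to-token reconstruction performed by
# Editor._get_predicted_tokens() above: KEEP copies the current draft token and advances
# the pointer, DROP only advances the pointer, indices at or beyond the vocabulary size
# copy a token from the triple, and anything else is generated from the vocabulary.
# The toy vocabulary and indices below are illustrative only.
VOCAB = {2: "the", 3: "city", 4: "is", 5: "large"}
KEEP, DROP, VOCAB_SIZE = 0, 1, 6

def actions_to_tokens(actions, draft, triple):
    pointer, tokens = 0, []
    for index in actions:
        if index == KEEP:
            tokens.append(draft[pointer])
            pointer += 1
        elif index == DROP:
            pointer += 1
        elif index >= VOCAB_SIZE:
            tokens.append(triple[index - VOCAB_SIZE])  # copy from the triple
        else:
            tokens.append(VOCAB[index])                # generate from the vocabulary
    return tokens

print(actions_to_tokens([KEEP, DROP, 5, VOCAB_SIZE + 1],
                        draft=["the", "town"], triple=["Berlin", "Germany"]))
# -> ['the', 'large', 'Germany']
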
Example #15
class DropSeq2Seq(Model):
    """
    This ``DropSeq2Seq`` class is a :class:`Model` which takes a sequence, encodes it, and then
    uses the encoded representations to decode another sequence.  You can use this as the basis for
    a neural machine translation system, an abstractive summarization system, or any other common
    seq2seq problem.  The model here is simple, but should be a decent starting place for
    implementing recent models for these tasks.
    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    attention_function: ``SimilarityFunction``, optional (default = None)
        This is if you want to use the legacy implementation of attention. This will be deprecated
        since it consumes more memory than the specialized attention modules.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <https://arxiv.org/abs/1506.03099>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 attention_function: SimilarityFunction = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 emb_dropout: float = 0.5,
                 dec_dropout: float = 0.5) -> None:
        super(DropSeq2Seq, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None

        self._token_based_metric = TokenSequenceAccuracy()

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder
        self._emb_dropout = Dropout(p=emb_dropout)
        self._dec_dropout = Dropout(p=dec_dropout)

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        if attention:
            if attention_function:
                raise ConfigurationError("You can only specify an attention module or an "
                                         "attention function, but not both.")
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
        else:
            self._attention = None

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.
        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.
        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.
        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                metadata: List[Dict[str, Any]],
                target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make a forward pass with decoder logic for producing the entire target sequence.
        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.
        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens)

        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                predicted_tokens = self.decode(output_dict)["predicted_tokens"]
                self._token_based_metric(predicted_tokens, [x["target_tokens"] for x in metadata])
                if self._bleu:
                    self._bleu(best_predictions, target_tokens["tokens"])

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                                for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        encoder_outputs = self._emb_dropout(encoder_outputs)
        return {
                "source_mask": source_mask,
                "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
                state["encoder_outputs"],
                state["source_mask"],
                self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.
        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # During training, with probability _scheduled_sampling_ratio we feed back the
                # model's own prediction from the previous step instead of the gold token.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            output_dict["target_mask"] = target_mask
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
                start_predictions, state, self.take_step)

        output_dict = {
                "class_log_probabilities": log_probabilities,
                "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode the current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.
        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask)

            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        decoder_input = self._dec_dropout(decoder_input)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input,
                (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(self._dec_dropout(decoder_hidden))

        return output_projections, state

    def _prepare_attended_input(self,
                                decoder_hidden_state: torch.LongTensor = None,
                                encoder_outputs: torch.LongTensor = None,
                                encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure the mask is also a FloatTensor, or else the multiplication within
        # the attention module will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(
                decoder_hidden_state, encoder_outputs, encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1), and computes cross
        entropy loss while taking the mask into account.
        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.
        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            all_metrics.update(self._token_based_metric.get_metric(reset=reset))
            if self._bleu:
                all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
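
# A minimal sketch (assuming an AllenNLP install) of the target shift described in
# DropSeq2Seq._get_loss() above: logits for decoding step i are scored against the gold
# token at position i + 1, so the leading start symbol is dropped from both the targets
# and the mask before the masked cross entropy is computed. All shapes and index values
# below are illustrative only.
import torch
from allennlp.nn import util as nn_util

batch_size, num_decoding_steps, num_classes = 2, 4, 7
logits = torch.randn(batch_size, num_decoding_steps, num_classes)
targets = torch.tensor([[1, 2, 3, 6, 0],    # <S> w1  w2  <E> <P>
                        [1, 4, 5, 2, 6]])   # <S> w1  w2  w3  <E>
target_mask = (targets != 0).long()

relevant_targets = targets[:, 1:].contiguous()   # shape: (batch_size, num_decoding_steps)
relevant_mask = target_mask[:, 1:].contiguous()  # shape: (batch_size, num_decoding_steps)
loss = nn_util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
print(loss)  # scalar loss averaged over the non-padded decoding steps
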
class RecombinationSeq2SeqWithCopy(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 seq_metrics: Metric,
                 attention: Attention,
                 beam_size: int = None,
                 source_namespace: str = 'source_tokens',
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = False,
                 encoder_input_dropout: float = 0.0,
                 encoder_output_dropout: float = 0.0,
                 dropout: float = 0.0,
                 feed_output_attention_to_decoder: bool = False,
                 keep_decoder_output_dim_same_as_encoder: bool = True,
                 initializer: InitializerApplicator = InitializerApplicator()) -> None:

        super(RecombinationSeq2SeqWithCopy, self).__init__(vocab)
        self._source_namespace = source_namespace
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)  # pylint: disable=protected-access

        # Evaluation Metrics
        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            self._bleu = None
        self._seq_metric = seq_metrics

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Encoder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder
        self._encoder_output_dim = self._encoder.get_output_dim()

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention
        self._feed_output_attention_to_decoder = feed_output_attention_to_decoder
        # Resolve the target embedding dimensionality before it is used to size the decoder input.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
        if self._feed_output_attention_to_decoder:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._encoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # Decoder

        # Dense embedding of vocab words in the target space.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        self._num_classes = num_classes
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # TODO: relax this assumption
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._keep_decoder_output_dim_same_as_encoder = keep_decoder_output_dim_same_as_encoder
        if not self._keep_decoder_output_dim_same_as_encoder:
            self._decoder_output_dim = int(self._encoder_output_dim / 2) if encoder.is_bidirectional() \
                else self._encoder_output_dim
        else:
            self._decoder_output_dim = self._encoder_output_dim

        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        self._transform_decoder_init_state = torch.nn.Sequential(
            torch.nn.Linear(self._encoder_output_dim, self._decoder_output_dim),
            torch.nn.Tanh()
        )

        # Generate Score
        self._output_projection_layer = Linear(self._decoder_output_dim + self._encoder_output_dim, num_classes)

        # Dropout Layers
        self._encoder_input_dropout = torch.nn.Dropout(p=encoder_input_dropout)
        self._encoder_output_dropout = torch.nn.Dropout(p=encoder_output_dropout)
        self._output_dropout = torch.nn.Dropout(p=dropout)
        self._embedded_dropout = torch.nn.Dropout(p=dropout)

        initializer(self)

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor])\
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.
        Dropout is applied before the softmax classifier (following "Language to Logical Form with Neural Attention").
        Inputs are the same as for `take_step()`.

        last_predictions: (group_size,)

        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size,); 1 where the last prediction is a vocabulary token, 0 where it is a copy index.
        copy_mask = (last_predictions < self._num_classes).long()
        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions * copy_mask)
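        # Multiplying by ``copy_mask`` maps copy indices (>= num_classes) to index 0, so the
        # embedding lookup never goes out of range; at inference time the embedding of the
        # actually-copied source token is substituted in the block below.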

        if not self.training and copy_mask.sum() < copy_mask.size(0):
            # Copy: retrieve the target-vocabulary index of each copied source token.
            mapped_indices = list()
            source_token_ids = state['source_token_ids']
            for gidx, idx in enumerate(last_predictions):
                if idx >= self._num_classes:
                    source_idx = idx - self._num_classes
                    source_token_id = int(source_token_ids[gidx, source_idx])
                    token = self.vocab.get_token_from_index(source_token_id, self._source_namespace)
                    tid = self.vocab.get_token_index(token, self._target_namespace)
                    mapped_indices.append(tid)
                else:
                    mapped_indices.append(self._pad_index)
            # Convert mapped_indices to a tensor on the right device.
            mapped_indices = torch.from_numpy(numpy.array(mapped_indices))
            mapped_indices = mapped_indices.to(last_predictions.device)

            copied_embedded_input = self._target_embedder(mapped_indices)
            unsqueezed_copy_mask = copy_mask.unsqueeze(dim=1).float()
            embedded_input = embedded_input * unsqueezed_copy_mask + copied_embedded_input * (1 - unsqueezed_copy_mask)

        embedded_input = self._embedded_dropout(embedded_input)

        if self._feed_output_attention_to_decoder:
            # shape: (group_size, encoder_output_dim + target_embedding_dim)
            decoder_input = torch.cat((embedded_input, state["attention_context"]), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        # shape (decoder_hidden): (group_size, decoder_output_dim)
        # shape (decoder_context): (group_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input,
            (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # output_attended_input: shape: (group_size, encoder_output_dim)
        # attention_weights shape: (group_size, max_input_sequence_length)
        output_attended_input, attention_weights = self._prepare_output_attended_input(
            decoder_hidden,
            encoder_outputs,
            source_mask
        )
        if self._feed_output_attention_to_decoder:
            state["attention_context"] = output_attended_input

        output_projection_input = torch.cat((decoder_hidden, output_attended_input), -1)
        dropped_output_projection_input = self._output_dropout(output_projection_input)

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(dropped_output_projection_input)
        # shape: (group_size, num_classes + max_input_sequence_length)
        output_projections = torch.cat((output_projections, attention_weights), -1)
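        # By convention, positions [0, num_classes) score generating a target-vocabulary token,
        # while position num_classes + j scores copying the j-th source token; ``decode`` and
        # ``map_predictions`` invert this mapping.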

        return output_projections, state

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes + max_input_sequence_length)
        output_projections, state = self._prepare_output_projections(last_predictions, state)

        source_mask = state['source_mask']
        group_size = source_mask.size(0)

        # shape: (group_size, num_classes + max_input_sequence_length)
        normalization_mask = torch.cat([source_mask.new_ones((group_size, self._num_classes)),
                                        source_mask], dim=-1)
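        # The mask keeps every vocabulary position and only the non-padded source positions,
        # so padded source tokens receive (near-)zero probability under masked_log_softmax.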

        # shape: (group_size, num_classes + max_input_sequence_length)
        class_log_probabilities = util.masked_log_softmax(output_projections, normalization_mask, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                target_tokens: Dict[str, torch.LongTensor] = None,
                target_source_token_map: torch.Tensor = None,
                meta_field: List[Dict] = None,
                ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make a forward pass with decoder logic for producing the entire target sequence.
        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.
        target_source_token_map: (batch_size, target_length, source_length)
        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens)

        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens, target_source_token_map)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            output_dict.update({"source_token_ids": source_tokens['tokens']})
            if target_tokens:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = self.map_predictions(top_k_predictions[:, 0, :], source_tokens['tokens'], meta_field)
                if self._bleu:
                    self._bleu(best_predictions, target_tokens["tokens"])
                if self._seq_metric:
                    self._seq_metric(
                        best_predictions.float(),
                        gold_labels=target_tokens["tokens"][:, 1:].float(),
                        mask=util.get_text_field_mask(
                            target_tokens).float()[:, 1:]
                    )

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]

            predicted_tokens = list()
            for x in indices:
                if x < self._num_classes:
                    predicted_tokens.append(self.vocab.get_token_from_index(x, namespace=self._target_namespace))
                else:
                    source_idx = x - self._num_classes
                    text = "@@copy@@%d" % int(source_idx)
                    token = Token(text)
                    # source_token_id = int(output_dict['source_token_ids'][0][source_idx])
                    # token = self.vocab.get_token_from_index(source_token_id, self._source_namespace)
                    predicted_tokens.append(token)
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        embedded_input = self._encoder_input_dropout(embedded_input)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        encoder_outputs = self._encoder_output_dropout(encoder_outputs)

        return {
            "source_token_ids": source_tokens['tokens'],
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"],
            state["source_mask"],
            self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = self._transform_decoder_init_state(final_encoder_output)
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
        if self._feed_output_attention_to_decoder:
            state["attention_context"] = state["encoder_outputs"].new_zeros(batch_size, self._encoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, torch.LongTensor] = None,
                      target_source_token_map: torch.Tensor = None
                      ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # During training, feed the model's own predictions back in with probability
                # ``_scheduled_sampling_ratio``; gold tokens are used the rest of the time.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes + max_input_sequence_length)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes + max_input_sequence_length)
            step_logits.append(output_projections.unsqueeze(1))

            # (batch_size, num_classes + max_input_sequence_length)
            normalization_mask = torch.cat([source_mask.new_ones((batch_size, self._num_classes)),
                                            source_mask], dim=-1)

            class_probabilities = util.masked_softmax(output_projections, normalization_mask, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes + max_input_sequence_length)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask, target_source_token_map)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
                start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_attended_input(self,
                                       decoder_hidden_state: torch.Tensor = None,
                                       encoder_outputs: torch.Tensor = None,
                                       encoder_outputs_mask: torch.LongTensor = None) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply ouput attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(
            decoder_hidden_state, encoder_outputs, encoder_outputs_mask)

        normalized_weights = util.masked_softmax(input_weights, encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, normalized_weights)
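        # NOTE: the raw (unnormalized) attention scores are returned alongside the attended
        # input; they are concatenated to the generation scores and normalized jointly in
        # ``take_step`` and ``_forward_loop``, where they serve as copy scores.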

        return attended_input, input_weights

    def _get_loss(self,
                  logits: torch.FloatTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor,
                  target_source_token_map: torch.Tensor) -> torch.Tensor:
        """
        Compute loss.
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.
        ``target_source_token_map``: (batch_size, target_length, source_length)

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()
        batch_size, num_decoding_steps = relevant_targets.size()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps, source_length)
        target_source_token_map = target_source_token_map[:, 1:, :]

        probs = F.softmax(logits, dim=-1)
        # (batch_size * num_decoding_steps, num_classes)
        generate_probs_flat = probs[:, :, :self._num_classes].view(-1, self._num_classes)
        relevant_targets_flat = relevant_targets.view(-1, 1).long()
        # (batch_size, num_decoding_steps)
        generate_probs = torch.gather(generate_probs_flat, dim=1, index=relevant_targets_flat).reshape(batch_size,
                                                                                                       num_decoding_steps)
        # (batch_size, num_decoding_steps)
        copy_probs = (probs[:, :, self._num_classes:] * target_source_token_map).sum(dim=-1)
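        # Example (assumed toy values): if the gold token at step t is "paris",
        # generate_probs[:, t] is the probability of generating "paris" from the vocabulary
        # part of the distribution, while copy_probs[:, t] sums the copy probabilities at
        # every source position holding "paris" (as marked by target_source_token_map).
        # The loss below is -log of their sum, marginalizing over generate vs. copy.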

        target_log_probs = torch.log(generate_probs + copy_probs + 1e-13)
        target_log_probs *= relevant_mask.float()
        negative_log_likelihood = -1 * target_log_probs
        weights_batch_sum = relevant_mask.sum(-1).float()

        per_batch_loss = negative_log_likelihood.sum(dim=1) / (weights_batch_sum + 1e-13)
        num_non_empty_sequences = ((weights_batch_sum > 0).float().sum() + 1e-13)
        return per_batch_loss.sum() / num_non_empty_sequences

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._bleu:
                all_metrics.update(self._bleu.get_metric(reset=reset))
            if self._seq_metric:
                all_metrics.update(
                    {"accuracy": self._seq_metric.get_metric(reset)['accuracy']})
        return all_metrics

    def map_predictions(self, predictions: torch.LongTensor,
                        source_token_ids: torch.LongTensor,
                        meta_field: List[Dict]) -> torch.LongTensor:
        """
        Map those copy indices to target idx
        :return:
        """
        batch_size, max_length = predictions.size()
        mapped_predictions = predictions.new_full((batch_size, max_length), fill_value=self._pad_index)
        for i in range(batch_size):
            source_tokens_to_copy = meta_field[i]['source_tokens_to_copy']
            for j in range(max_length):
                idx = predictions[i, j]
                if idx < self._num_classes:
                    mapped_predictions[i, j] = idx
                else:
                    # Copy
                    source_idx = idx - self._num_classes
                    if source_idx >= len(source_tokens_to_copy):
                        tid = self._pad_index
                    else:
                        token = source_tokens_to_copy[source_idx]
                        # source_token_id = int(source_token_ids[i, source_idx])
                        # token = self.vocab.get_token_from_index(source_token_id, self._source_namespace)
                        tid = self.vocab.get_token_index(token, self._target_namespace)
                    mapped_predictions[i, j] = tid
        return mapped_predictions.long()
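
# A minimal sketch (hypothetical values) of the extended-index convention used above:
# predictions below ``num_classes`` are ordinary target-vocabulary ids, while
# ``num_classes + j`` means "copy the j-th source token".
num_classes = 100                    # assumed target vocabulary size
source_tokens_to_copy = ["paris"]    # assumed copyable source tokens for one instance
prediction = num_classes + 0         # an extended index produced by the decoder
if prediction < num_classes:
    token = prediction               # already a vocabulary id
else:
    token = source_tokens_to_copy[prediction - num_classes]
assert token == "paris"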
Beispiel #17
class Event2Mind(Model):
    """
    This ``Event2Mind`` class is a :class:`Model` which takes an event
    sequence, encodes it, and then uses the encoded representation to decode
    several mental state sequences.

    It is based on `the paper by Rashkin et al.
    <https://www.semanticscholar.org/paper/Event2Mind/b89f8a9b2192a8f2018eead6b135ed30a1f2144d>`_

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (``tokens``) or the target tokens can have a different namespace, in which case it needs to
        be specified as ``target_namespace``.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences.
    embedding_dropout: float, required
        The amount of dropout to apply after the source tokens have been embedded.
    encoder : ``Seq2VecEncoder``, required
        The encoder of the "encoder/decoder" model.
    max_decoding_steps : int, required
        Length of decoded sequences.
    beam_size : int, optional (default = 10)
        The width of the beam search.
    target_names: ``List[str]``, optional, (default = ['xintent', 'xreact', 'oreact'])
        Names of the target fields matching those in the ``Instance`` objects.
    target_namespace : str, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : int, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 beam_size: int = 10,
                 target_names: List[str] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None) -> None:
        super().__init__(vocab)
        target_names = target_names or ["xintent", "xreact", "oreact"]

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as a well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim(
        )

        self._states = ModuleDict()
        for name in target_names:
            self._states[name] = StateDecoder(num_classes,
                                              target_embedding_dim,
                                              self._decoder_output_dim)

        self._beam_search = BeamSearch(self._end_index,
                                       beam_size=beam_size,
                                       max_steps=max_decoding_steps)

    def _update_recall(self, all_top_k_predictions: torch.Tensor,
                       target_tokens: Dict[str, torch.LongTensor],
                       target_recall: UnigramRecall) -> None:
        targets = target_tokens["tokens"]
        target_mask = get_text_field_mask(target_tokens)
        # See comment in _get_loss.
        # TODO(brendanr): Do we need contiguous here?
        relevant_targets = targets[:, 1:].contiguous()
        relevant_mask = target_mask[:, 1:].contiguous()
        target_recall(all_top_k_predictions, relevant_targets, relevant_mask,
                      self._end_index)

    def _get_num_decoding_steps(
            self, target_tokens: Optional[Dict[str, torch.LongTensor]]) -> int:
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end
            # symbol.  Either way, we don't have to process it. (To be clear,
            # we do still output and compare against the end symbol, but there
            # is no need to take the end symbol as input to the decoder.)
            return target_sequence_length - 1
        else:
            return self._max_decoding_steps

    @overrides
    def forward(
        self,  # type: ignore
        source: Dict[str, torch.LongTensor],
        **target_tokens: Dict[str, Dict[str, torch.LongTensor]]
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the target sequences.

        Parameters
        ----------
        source : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the source
            ``TextField``. This will be passed through a ``TextFieldEmbedder``
            and then through an encoder.
        target_tokens : ``Dict[str, Dict[str, torch.LongTensor]]``:
            Dictionary from name to output of ``Textfield.as_array()`` applied
            on target ``TextField``. We assume that the target tokens are also
            represented as a ``TextField``.
        """
        # (batch_size, input_sequence_length, embedding_dim)
        embedded_input = self._embedding_dropout(self._source_embedder(source))
        source_mask = get_text_field_mask(source)
        # (batch_size, encoder_output_dim)
        final_encoder_output = self._encoder(embedded_input, source_mask)
        output_dict = {}

        # Perform greedy search so we can get the loss.
        if target_tokens:
            if target_tokens.keys() != self._states.keys():
                target_only = target_tokens.keys() - self._states.keys()
                states_only = self._states.keys() - target_tokens.keys()
                raise Exception(
                    "Mismatch between target_tokens and self._states. Keys in "
                    +
                    f"targets only: {target_only} Keys in states only: {states_only}"
                )
            total_loss = 0
            for name, state in self._states.items():
                loss = self.greedy_search(
                    final_encoder_output=final_encoder_output,
                    target_tokens=target_tokens[name],
                    target_embedder=state.embedder,
                    decoder_cell=state.decoder_cell,
                    output_projection_layer=state.output_projection_layer)
                total_loss += loss
                output_dict[f"{name}_loss"] = loss

            # Use mean loss (instead of the sum of the losses) to be comparable to the paper.
            output_dict["loss"] = total_loss / len(self._states)

        # Perform beam search to obtain the predictions.
        if not self.training:
            batch_size = final_encoder_output.size()[0]
            for name, state in self._states.items():
                start_predictions = final_encoder_output.new_full(
                    (batch_size, ),
                    fill_value=self._start_index,
                    dtype=torch.long)
                start_state = {"decoder_hidden": final_encoder_output}

                # (batch_size, beam_size, num_decoding_steps)
                all_top_k_predictions, log_probabilities = self._beam_search.search(
                    start_predictions, start_state, state.take_step)

                if target_tokens:
                    self._update_recall(all_top_k_predictions,
                                        target_tokens[name], state.recall)
                output_dict[
                    f"{name}_top_k_predictions"] = all_top_k_predictions
                output_dict[
                    f"{name}_top_k_log_probabilities"] = log_probabilities

        return output_dict

    def greedy_search(self, final_encoder_output: torch.LongTensor,
                      target_tokens: Dict[str, torch.LongTensor],
                      target_embedder: Embedding, decoder_cell: GRUCell,
                      output_projection_layer: Linear) -> torch.FloatTensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the cross entropy between this sequence and ``target_tokens``.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()`` applied on some target ``TextField``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._get_num_decoding_steps(target_tokens)
        targets = target_tokens["tokens"]
        decoder_hidden = final_encoder_output
        step_logits = []
        for timestep in range(num_decoding_steps):
            # See https://github.com/allenai/allennlp/issues/1134.
            input_choices = targets[:, timestep]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
        # (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        target_mask = get_text_field_mask(target_tokens)
        return self._get_loss(logits, targets, target_mask)

    def greedy_predict(self, final_encoder_output: torch.LongTensor,
                       target_embedder: Embedding, decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [
            final_encoder_output.new_full((batch_size, ),
                                          fill_value=self._start_index,
                                          dtype=torch.long)
        ]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]

    @staticmethod
    def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.FloatTensor:
        """
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()
        # (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()
        loss = sequence_cross_entropy_with_logits(logits, relevant_targets,
                                                  relevant_mask)
        return loss

    def decode_all(self, predicted_indices: torch.Tensor) -> List[List[str]]:
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    @overrides
    def decode(
            self,
            output_dict: Dict[str,
                              torch.Tensor]) -> Dict[str, List[List[str]]]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds fields for the tokens to the ``output_dict``.
        """
        for name in self._states:
            top_k_predicted_indices = output_dict[f"{name}_top_k_predictions"][0]
            output_dict[f"{name}_top_k_predicted_tokens"] = [
                self.decode_all(top_k_predicted_indices)
            ]

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics = {}
        # Recall@10 needs beam search which doesn't happen during training.
        if not self.training:
            for name, state in self._states.items():
                all_metrics[name] = state.recall.get_metric(reset=reset)
        return all_metrics
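
# A minimal sketch (toy values) of the trimming done in ``decode_all``: everything after the
# first end symbol is dropped before indices are mapped back to tokens.
indices = [7, 12, 3, 5, 9]   # assume 3 is the end index
end_index = 3
if end_index in indices:
    indices = indices[:indices.index(end_index)]
assert indices == [7, 12]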
Beispiel #18
class WNCGTransformerModel(WNCGBaseModel):
    def __init__(self,
                 input_gpv_dim,
                 d_model,
                 nhead,
                 num_layers,
                 n_vocab,
                 input_amedas_seqlen,
                 weather,
                 cl,
                 lr=0.001,
                 dropout_p=0.2,
                 warm_up_steps=4000):
        super().__init__(d_model, weather, cl, dropout_p)
        """ Encoder """
        # encoder for gpv
        self.gpv_encoder = MLPEncoder(input_gpv_dim,
                                      d_model,
                                      dropout_p=dropout_p)
        self.pos_encoder = PositionalEncoding(d_model)
        # encoder for amedas
        self.amedas_to_dmodel = nn.Linear(input_amedas_seqlen, d_model)
        # encoder for meta-data
        metaenc = {}
        metaenc["area"] = nn.Embedding(277, d_model)
        metaenc["month"] = nn.Embedding(12, d_model)
        metaenc["day"] = nn.Embedding(31, d_model)
        metaenc["time"] = nn.Embedding(24, d_model)
        metaenc["week"] = nn.Embedding(7, d_model)
        self.meta_encoders = nn.ModuleDict(metaenc)
        # encoder
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead),
            num_layers=num_layers)
        """ Decoder """
        # word decoder
        self.token_embedder = nn.Embedding(n_vocab,
                                           d_model,
                                           padding_idx=IDs.PAD.value)
        self.token_position = PositionalEncoding(d_model)
        self.token_decoder = nn.TransformerDecoder(nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead),
                                                   num_layers=num_layers)
        self.token_output = nn.Linear(d_model, n_vocab)

        # make the arguments global
        self.lr = lr
        # weather label
        self.weather = weather
        # content agreement loss
        self.cl = cl
        # warm up steps for learning rate
        self.warm_up_steps = warm_up_steps
        # save the arguments
        self.save_hyperparameters()

    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), 1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
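        # For example (illustrative), with sz=3 the returned mask is
        #     [[0., -inf, -inf],
        #      [0.,   0., -inf],
        #      [0.,   0.,   0.]]
        # i.e. position i may only attend to positions <= i.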
        return mask

    def encode(self, src_gpv, src_amedas, src_meta, src_comment):
        """ encode """
        # encode gpv-data
        src_gpv = self.gpv_encoder(src_gpv)
        src_gpv = self.pos_encoder(src_gpv)
        # encode amedas-data
        src_amedas = self.amedas_to_dmodel(src_amedas)
        # encode meta-data
        emb_area = self.meta_encoders["area"](src_meta[0, :])
        emb_month = self.meta_encoders["month"](src_meta[1, :])
        emb_day = self.meta_encoders["day"](src_meta[2, :])
        emb_time = self.meta_encoders["time"](src_meta[3, :])
        emb_week = self.meta_encoders["week"](src_meta[4, :])
        src_meta = torch.stack(
            [emb_area, emb_month, emb_day, emb_time, emb_week], dim=0)
        # concatenate input-data (gpv, amedas, meta)
        src_data = torch.cat(
            [src_gpv, src_amedas, src_meta], dim=0
        )  # (seq_len[9(gpv) + 4(amedas) + 5(meta)], batch_size, d_model)
        # encode input-data by transformer
        src_memory = self.transformer_encoder(
            src_data)  # (seq_len, batch_size, d_model)
        return src_gpv, src_amedas, src_meta, src_memory

    def forward(self, src_gpv, src_amedas, src_meta, src_comment):
        """[summary]

        Args:
            src_gpv ([type]): [description]
            src_amedas ([type]): [description]
            src_meta ([type]): [description]
            src_comment ([type]): [description]

        Returns:
            [type]: [description]
        """
        """ encode gpv/amedas/meta """
        src_gpv, src_amedas, src_meta, src_memory = \
            self.encode(src_gpv, src_amedas, src_meta, src_comment)

        # initialize outputs of weather labels and weather hidden
        ZERO = torch.zeros(1, 1).to(self.device)
        sunny_out, cloudy_out, rain_out, snow_out, weather_hidden = \
            ZERO, ZERO, ZERO, ZERO, None
        """ decode weather labels """
        if self.weather == "label":
            sunny_out, sunny_hidden = self.sunny_decoder(src_memory[0])
            cloudy_out, cloudy_hidden = self.cloudy_decoder(src_memory[0])
            rain_out, rain_hidden = self.rain_decoder(src_memory[0])
            snow_out, snow_hidden = self.snow_decoder(src_memory[0])
            weather_hidden = torch.stack(
                [sunny_hidden, cloudy_hidden, rain_hidden, snow_hidden], dim=0)

        # append weather_hidden to the encoder memory fed to the decoder
        if self.weather is not None:
            src_memory = torch.cat([src_memory, weather_hidden], dim=0)
        """ decode tokens """
        # prepare masks for word decoder
        src_comment_len = src_comment.size(0)  # seq_len
        # mask for padding token
        src_comment_padd_mask = (src_comment == IDs.PAD.value).transpose(
            0, 1).to(self.device)  # (batch_size, seq_len)
        # mask for subsequence
        src_comment_attn_mask = self.generate_square_subsequent_mask(
            src_comment_len).to(self.device)  # (seq_len, seq_len)
        # embedding
        src_comment_emb = self.token_embedder(
            src_comment)  # (seqlen, batch_size, d_model)
        src_comment_emb_pos = self.token_position(src_comment_emb)
        # decode
        token_hidden = self.token_decoder(
            src_comment_emb_pos,
            src_memory,
            tgt_mask=src_comment_attn_mask,
            tgt_key_padding_mask=src_comment_padd_mask
        )  # (seqlen, batch_size, d_model)
        # output distribution over vocabularies
        token_out = self.token_output(token_hidden)

        return (F.log_softmax(token_out, dim=-1),
                F.log_softmax(sunny_out, dim=-1), F.log_softmax(cloudy_out, dim=-1),
                F.log_softmax(rain_out, dim=-1), F.log_softmax(snow_out, dim=-1),
                weather_hidden, src_comment_emb)

    # learning rate warm-up
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp,
                       using_lbfgs):
        """[summary]

        Args:
            epoch ([type]): [description]
            batch_idx ([type]): [description]
            optimizer ([type]): [description]
            optimizer_idx ([type]): [description]
            optimizer_closure ([type]): [description]
            on_tpu ([type]): [description]
            using_native_amp ([type]): [description]
            using_lbfgs ([type]): [description]
        """
        # warm up lr
        for pg in optimizer.param_groups:
            self.lr = (self.hparams.d_model**-0.5) * min(
                float(self.trainer.global_step + 1)**-0.5,
                float(self.trainer.global_step + 1) * self.warm_up_steps**-1.5)
            pg['lr'] = self.lr
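        # Illustrative numbers (assuming d_model=512, warm_up_steps=4000): the rate grows
        # roughly linearly up to step 4000, peaks at 512**-0.5 * 4000**-0.5 ≈ 7.0e-4, and
        # then decays as step**-0.5.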
        # update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=self.lr,
                                     betas=(0.9, 0.98),
                                     eps=1e-09)
        return optimizer

    def greedy_token_decode(self,
                            src_memory,
                            gpv_output,
                            amedas_output,
                            meta_output,
                            weather_hidden,
                            token_generation_limit=128):
        """ decode tokens """
        _, batch_size, d_model = src_memory.size()
        src_comment = torch.tensor([[IDs.BOS.value
                                     for _ in range(batch_size)]],
                                   dtype=torch.long).to(self.device)
        decoded_batch = torch.zeros((batch_size, token_generation_limit))

        # append weather_hidden to the encoder memory fed to the decoder
        if self.weather is not None:
            src_memory = torch.cat([src_memory, weather_hidden], dim=0)

        for idx in range(token_generation_limit):
            # prepare masks for word decoder
            src_comment_len = idx + 1  # seq_len
            # mask for padding token
            src_comment_padd_mask = (src_comment == IDs.PAD.value).transpose(
                0, 1).to(self.device)  # (batch_size, seq_len)
            # mask for subsequence
            src_comment_attn_mask = self.generate_square_subsequent_mask(
                src_comment_len).to(self.device)  # (seq_len, seq_len)
            # embedding
            src_comment_emb = self.token_embedder(
                src_comment)  # (seqlen, batch_size, d_model)
            src_comment_emb = self.token_position(src_comment_emb)
            # decode
            token_hidden = self.token_decoder(
                src_comment_emb,
                src_memory,
                tgt_mask=src_comment_attn_mask,
                tgt_key_padding_mask=src_comment_padd_mask
            )  # (seqlen, batch_size, d_model)
            # output distribution over vocabularies
            token_out = self.token_output(token_hidden)

            _, topi = token_out[-1, :, :].data.topk(1)
            decoded_batch[:, idx] = topi.view(-1)
            topi = topi.transpose(0, 1)

            # concat source with output
            src_comment = torch.cat([src_comment, topi], dim=0)

        return decoded_batch.detach().tolist()

    @torch.no_grad()
    def beam_token_decode(self,
                          src_memory,
                          gpv_output,
                          amedas_output,
                          meta_output,
                          weather_hidden,
                          beam_width=5):
        max_steps = 128  # the maximum number of decoding steps to take
        self.beam_search = BeamSearch(end_index=IDs.EOS.value,
                                      max_steps=max_steps,
                                      beam_size=beam_width)
        batch_size = src_memory.size(1)

        # append weather_hidden to the encoder memory fed to the decoder
        if self.weather is not None:
            src_memory = torch.cat([src_memory, weather_hidden], dim=0)

        start_predictions = torch.tensor([IDs.BOS.value] * batch_size,
                                         dtype=torch.long,
                                         device=self.device)
        start_state = {
            "prev_tokens":
            torch.zeros(batch_size, 0, dtype=torch.long,
                        device=self.device),  # set none of prev_tokens
            "decoder_hidden":
            src_memory  # (seq_len, batch_size, d_model)
        }

        def step(last_tokens, current_state, t):
            """
            Args:
                last_tokens: (group_size,) token indices chosen at the previous step.
                current_state: dict of state tensors carried across decoding steps.
                t: the current decoding timestep (int).
            """
            # concatenate prev_tokens with last_tokens
            prev_tokens = torch.cat(
                [current_state["prev_tokens"],
                 last_tokens.unsqueeze(1)],
                dim=-1)  # [batch_size * beam_width, t+1]
            # embedding
            prev_tokens_emb = self.token_embedder(prev_tokens).transpose(
                0, 1)  # (seqlen, batch_size, d_model)
            prev_tokens_emb = self.token_position(prev_tokens_emb)
            prev_tokens_len = prev_tokens.size(1)
            # mask for padding token
            prev_token_padd_mask = (prev_tokens == IDs.PAD.value).to(
                self.device)  # (batch_size, seq_len)
            # mask for subsequence
            prev_token_attn_mask = self.generate_square_subsequent_mask(
                prev_tokens_len).to(self.device)  # (seq_len, seq_len)
            # decode
            token_hidden = self.token_decoder(
                prev_tokens_emb,
                current_state["decoder_hidden"],
                tgt_mask=prev_token_attn_mask,
                tgt_key_padding_mask=prev_token_padd_mask
            )  # (seqlen, batch_size, d_model)

            # output distribution over vocabularies
            token_out = self.token_output(token_hidden)
            # get output distribution for the last token
            decoder_output = F.log_softmax(token_out[-1, :, :], dim=-1)
            # update prev_tokens
            current_state["prev_tokens"] = prev_tokens
            return (decoder_output, current_state)

        predictions, log_probs = self.beam_search.search(
            start_predictions=start_predictions,
            start_state=start_state,
            step=step)

        return predictions, log_probs
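
# A minimal helper sketch for consuming the beam-search output above, assuming the
# usual AllenNLP shapes: `predictions` is (batch_size, beam_width, num_steps) with
# the best hypothesis at beam index 0, and `eos_id` would be IDs.EOS.value in this
# codebase. The function name is illustrative, not part of the model.
def best_beam_token_ids(predictions, eos_id):
    results = []
    for ids in predictions[:, 0, :].tolist():  # keep only the best beam
        if eos_id in ids:
            ids = ids[:ids.index(eos_id)]      # drop EOS and everything after it
        results.append(ids)
    return results
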
Beispiel #19
class MSPointerNetwork(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder_1: TextFieldEmbedder,
                 source_encoder_1: Seq2SeqEncoder,
                 beam_size: int,
                 max_decoding_steps: int,
                 decoder_output_dim: int,
                 target_embedding_dim: int = 30,
                 namespace: str = "tokens",
                 tensor_based_metric: Metric = None,
                 align_embeddings: bool = True,
                 source_embedder_2: TextFieldEmbedder = None,
                 source_encoder_2: Seq2SeqEncoder = None) -> None:
        super().__init__(vocab)
        self._source_embedder_1 = source_embedder_1
        self._source_embedder_2 = source_embedder_2 or self._source_embedder_1
        self._source_encoder_1 = source_encoder_1
        self._source_encoder_2 = source_encoder_2 or self._source_encoder_1

        self._source_namespace = namespace
        self._target_namespace = namespace

        self.encoder_output_dim_1 = self._source_encoder_1.get_output_dim()
        self.encoder_output_dim_2 = self._source_encoder_2.get_output_dim()
        self.cated_encoder_out_dim = self.encoder_output_dim_1 + self.encoder_output_dim_2
        self.decoder_output_dim = decoder_output_dim

        # TODO: AllenNLP's AdditiveAttention implementation may not include a bias term
        self._attention_1 = AdditiveAttention(self.decoder_output_dim,
                                              self.encoder_output_dim_1)
        self._attention_2 = AdditiveAttention(self.decoder_output_dim,
                                              self.encoder_output_dim_2)

        if not align_embeddings:
            self.target_embedding_dim = target_embedding_dim
            self._target_vocab_size = self.vocab.get_vocab_size(
                namespace=self._target_namespace)
            self._target_embedder = Embedding(self._target_vocab_size,
                                              target_embedding_dim)
        else:
            self._target_embedder = self._source_embedder_1._token_embedders[
                "tokens"]
            self._target_vocab_size = self.vocab.get_vocab_size(
                namespace=self._target_namespace)
            self.target_embedding_dim = self._target_embedder.get_output_dim()

        self.decoder_input_dim = self.encoder_output_dim_1 + self.encoder_output_dim_2 + \
                                 self.target_embedding_dim

        self._decoder_cell = LSTMCell(self.decoder_input_dim,
                                      self.decoder_output_dim)

        # Projects the concatenated final hidden states of the two encoders into the decoder's initial state
        self._encoder_out_projection_layer = torch.nn.Linear(
            in_features=self.cated_encoder_out_dim,
            out_features=self.decoder_output_dim
        )  # TODO: bias - True or False?

        # Soft gating parameters, used to compute lambda
        self._gate_projection_layer = torch.nn.Linear(
            in_features=self.decoder_output_dim + self.decoder_input_dim,
            out_features=1,
            bias=False)

        self._start_index = self.vocab.get_token_index(START_SYMBOL, namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     namespace)
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        self._tensor_based_metric = tensor_based_metric or \
            BLEU(exclude_indices={self._pad_index, self._end_index, self._start_index})

    def _encode(
            self, source_tokens_1: Dict[str, torch.Tensor],
            source_tokens_2: Dict[str,
                                  torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        分别将source1和source2的token ids经过encoder编码,输出各自的mask以及encoder_out。
        同时token_ids信息也会附加。
        """

        # 1. Encode source1
        # shape: (batch_size, seq_max_len_1)
        source_mask_1 = util.get_text_field_mask(source_tokens_1)
        # shape: (batch_size, seq_max_len_1, encoder_input_dim_1)
        embedder_out_1 = self._source_embedder_1(source_tokens_1)
        # shape: (batch_size, seq_max_len_1, encoder_output_dim_1)
        encoder_out_1 = self._source_encoder_1(embedder_out_1, source_mask_1)

        # 2. Encode source2
        # shape: (batch_size, seq_max_len_2)
        source_mask_2 = util.get_text_field_mask(source_tokens_2)
        # shape: (batch_size, seq_max_len_2, encoder_input_dim_2)
        embedder_out_2 = self._source_embedder_2(source_tokens_2)
        # shape: (batch_size, seq_max_len_2, encoder_output_dim_2)
        encoder_out_2 = self._source_encoder_2(embedder_out_2, source_mask_2)

        return {
            "source_mask_1": source_mask_1,
            "source_mask_2": source_mask_2,
            "source_token_ids_1": source_tokens_1["tokens"],
            "source_token_ids_2": source_tokens_2["tokens"],
            "encoder_out_1": encoder_out_1,
            "encoder_out_2": encoder_out_2,
        }

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        初始化decoder:更新传入的state,使之带有decoder的context和hidden向量。
                      其中hidden向量(h_0)通过两个编码器的最终隐层状态经过一个
                      映射得到,context初始化为0向量。
        """
        batch_size = state["encoder_out_1"].size()[0]

        # Use each sequence's mask to pick out the final RNN hidden state
        # shape: (batch_size, encoder_output_dim_1)
        encoder_final_output_1 = util.get_final_encoder_states(
            state["encoder_out_1"], state["source_mask_1"],
            self._source_encoder_1.is_bidirectional())
        # shape: (batch_size, encoder_output_dim_2)
        encoder_final_output_2 = util.get_final_encoder_states(
            state["encoder_out_2"], state["source_mask_2"],
            self._source_encoder_2.is_bidirectional())

        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = torch.relu(
            self._encoder_out_projection_layer(
                torch.cat([encoder_final_output_1, encoder_final_output_2],
                          dim=-1)))
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["decoder_hidden"].new_zeros(
            batch_size, self.decoder_output_dim)

        return state

    @overrides
    def forward(
        self,
        source_tokens_1: Dict[str, torch.LongTensor],
        source_tokens_2: Dict[str, torch.LongTensor],
        metadata: List[Dict[str, Any]],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:

        # Handle three cases separately: training, validation/testing, and prediction.

        # 1. Training: target_tokens must be provided as the ground truth.
        #    Only the loss is computed; no beam search is needed.
        if self.training:
            assert target_tokens is not None

            state = self._encode(source_tokens_1, source_tokens_2)
            state["target_token_ids"] = target_tokens["tokens"]
            state = self._init_decoder_state(state)
            output_dict = self._forward_loss(target_tokens, state)
            output_dict["metadata"] = metadata
            return output_dict  # contains loss and metadata

        # 2. Validation/testing: self.training is False, but target_tokens is provided.
        #    Compute the loss, run beam search, and compute the evaluation metrics.
        elif target_tokens:

            # Compute the loss
            state = self._encode(source_tokens_1, source_tokens_2)
            state["target_token_ids"] = target_tokens["tokens"]
            state = self._init_decoder_state(state)
            output_dict = self._forward_loss(target_tokens, state)

            # Run beam search
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

            # Compute the evaluation metric (BLEU)
            if self._tensor_based_metric is not None:
                # shape: (batch_size, beam_size, max_decoding_steps)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_decoding_steps)
                best_predictions = top_k_predictions[:, 0, :]
                # shape: (batch_size, target_seq_len)
                gold_tokens = target_tokens["tokens"]
                self._tensor_based_metric(best_predictions, gold_tokens)
            output_dict["metadata"] = metadata
            return output_dict  # contains loss, metadata, top-k predictions, and top-k log probs

        # 3. Prediction: self.training is False and target_tokens is not provided.
        #    Only beam search is run to obtain the top-k predictions.
        else:
            state = self._encode(source_tokens_1, source_tokens_2)
            state = self._init_decoder_state(state)
            output_dict = {"metadata": metadata}
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            return output_dict  # contains metadata, top-k predictions, and top-k log probs

    def _forward_loss(
            self, target_tokens: Dict[str, torch.Tensor],
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        为输入的一个batch计算损失(仅在训练时调用)。
        """
        batch_size, target_seq_len = target_tokens["tokens"].size()

        # shape: (batch_size, seq_max_len_1)
        source_mask_1 = state["source_mask_1"]
        # shape: (batch_size, seq_max_len_2)
        source_mask_2 = state["source_mask_2"]

        # The number of decoding steps is always one less than the maximum length
        # of the target sequence (<start> ... <end>)
        num_decoding_steps = target_seq_len - 1

        step_log_likelihoods = []  # log-likelihood of the target token at each time step
        for timestep in range(num_decoding_steps):  # t: 0..T

            # Token id fed to the decoder at this time step, shape: (batch_size,)
            input_choices = target_tokens["tokens"][:, timestep]

            # Advance the decoder state by one step (computing intermediate values
            # such as attention scores and the soft-gate score)
            state = self._decoder_step(input_choices, state)

            # Attention scores of decoder_hidden over the two sources
            # shape: (batch_size, seq_max_len_1)
            attentive_weights_1 = state["attentive_weights_1"]
            # shape: (batch_size, seq_max_len_2)
            attentive_weights_2 = state["attentive_weights_2"]

            # Compute target_to_source: whether the current gold target token appears in each source
            # shape: (batch_size, seq_max_len_1)
            target_to_source_1 = (state["source_token_ids_1"] ==
                                  state["target_token_ids"][:, timestep +
                                                            1].unsqueeze(-1))
            # shape: (batch_size, seq_max_len_2)
            target_to_source_2 = (state["source_token_ids_2"] ==
                                  state["target_token_ids"][:, timestep +
                                                            1].unsqueeze(-1))

            # Use the values above to compute the log-likelihood of the target token at this step
            step_log_likelihood = self._get_ll_contrib(
                attentive_weights_1, attentive_weights_2, source_mask_1,
                source_mask_2, target_to_source_1, target_to_source_2,
                state["target_token_ids"][:,
                                          timestep + 1], state["gate_score"])
            step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))

        # Concatenate the per-step log-likelihoods into one tensor
        # shape: (batch_size, num_decoding_steps = target_seq_len - 1)
        log_likelihoods = torch.cat(step_log_likelihoods, 1)

        # Target mask including START and END
        # shape: (batch_size, target_seq_len)
        target_mask = util.get_text_field_mask(target_tokens)

        # Drop the first position: START is never a prediction target
        # shape: (batch_size, num_decoding_steps = target_seq_len - 1)
        target_mask = target_mask[:, 1:].float()

        # Sum the masked per-step log-likelihoods to get the log-likelihood of the whole sequence
        log_likelihood = (log_likelihoods * target_mask).sum(dim=-1)

        loss = -log_likelihood.sum() / batch_size

        return {"loss": loss}

    def _decoder_step(
            self, last_predictions: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        更新一步decoder状态。
        """

        # shape: (group_size, seq_max_len_1)
        source_mask_1 = state["source_mask_1"].float()
        # shape: (group_size, seq_max_len_2)
        source_mask_2 = state["source_mask_2"].float()
        # y_{t-1}, shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        # a_t, shape: (group_size, seq_max_len_1)
        state["attentive_weights_1"] = self._attention_1(
            state["decoder_hidden"], state["encoder_out_1"], source_mask_1)
        # a'_t, shape: (group_size, seq_max_len_2)
        state["attentive_weights_2"] = self._attention_2(
            state["decoder_hidden"], state["encoder_out_2"], source_mask_2)

        # c_t, shape: (group_size, encoder_output_dim_1)
        attentive_read_1 = util.weighted_sum(state["encoder_out_1"],
                                             state["attentive_weights_1"])
        # c'_t, shape: (group_size, encoder_output_dim_2)
        attentive_read_2 = util.weighted_sum(state["encoder_out_2"],
                                             state["attentive_weights_2"])

        # Compute the soft gate: lambda
        # shape: (group_size, target_embedding_dim + encoder_output_dim_1 + encoder_output_dim_2 + decoder_output_dim)
        gate_input = torch.cat((embedded_input, attentive_read_1,
                                attentive_read_2, state["decoder_hidden"]),
                               dim=-1)
        # shape: (group_size,)
        gate_projected = self._gate_projection_layer(gate_input).squeeze(-1)
        # shape: (group_size,)
        state["gate_score"] = torch.sigmoid(gate_projected)

        # shape: (group_size, target_embedding_dim + encoder_output_dim_1 + encoder_output_dim_2)
        decoder_input = torch.cat(
            (embedded_input, attentive_read_1, attentive_read_2), dim=-1)

        # Update the decoder state (hidden and context/cell)
        state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
            decoder_input, (state["decoder_hidden"], state["decoder_context"]))

        return state

    def _get_ll_contrib(self, copy_scores_1: torch.Tensor,
                        copy_scores_2: torch.Tensor,
                        source_mask_1: torch.Tensor,
                        source_mask_2: torch.Tensor,
                        target_to_source_1: torch.Tensor,
                        target_to_source_2: torch.Tensor,
                        target_tokens: torch.Tensor,
                        gate_score: torch.Tensor) -> torch.Tensor:
        """
        根据一个时间步的attention分数、黄金token,计算黄金token的对数似然。

        参数:
            - copy_scores_1:对第一个source的注意力分值。
                    shape: (batch_size, seq_max_len_1)
            - copy_scores_2:对第二个source的注意力分值。
                    shape: (batch_size, seq_max_len_2)
            - source_mask_1:第一个source的mask
                    shape: (batch_size, seq_max_len_1)
            - source_mask_2:第二个source的mask
                    shape: (batch_size, seq_max_len_2)
            - target_to_source_1:目标词是否为第一个source对应位置的词
                    shape: (batch_size, seq_max_len_1)
            - target_to_source_2:目标词是否为第二个source对应位置的词
                    shape: (batch_size, seq_max_len_2)
            - target_tokens:当前时间步的目标词
                    shape: (batch_size,)
            - gate_score:从第一个source拷贝词语的概率(0-1之间)
                    shape: (batch_size,)

        返回:
            当前时间步,生成目标词的对数似然(log-likelihood)
                    shape: (batch_size,)
        """
        # Score contributed by the first source
        # shape: (batch_size, seq_max_len_1)
        combined_log_probs_1 = (copy_scores_1 + 1e-45).log() + (
            target_to_source_1.float() +
            1e-45).log() + (source_mask_1.float() + 1e-45).log()
        # shape: (batch_size,)
        log_probs_1 = util.logsumexp(
            combined_log_probs_1)  # log(exp(a[0]) + ... + exp(a[L]))

        # Score contributed by the second source
        # shape: (batch_size, seq_max_len_2)
        combined_log_probs_2 = (copy_scores_2 + 1e-45).log() + (
            target_to_source_2.float() +
            1e-45).log() + (source_mask_2.float() + 1e-45).log()
        # shape: (batch_size,)
        log_probs_2 = util.logsumexp(
            combined_log_probs_2)  # log(exp(a[0]) + ... + exp(a[L]))

        # Compute log(p1 * gate + p2 * (1 - gate))
        log_gate_score_1 = gate_score.log()  # shape: (batch_size,)
        log_gate_score_2 = (1 - gate_score).log()  # shape: (batch_size,)
        item_1 = (log_gate_score_1 + log_probs_1).unsqueeze(
            -1)  # shape: (batch_size, 1)
        item_2 = (log_gate_score_2 + log_probs_2).unsqueeze(
            -1)  # shape: (batch_size, 1)
        step_log_likelihood = util.logsumexp(torch.cat(
            (item_1, item_2), -1))  # shape: (batch_size,)
        return step_log_likelihood
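
    # Worked example: with gate g = 0.7, p1 = 0.2 (attention mass on the source-1
    # positions that match the gold token) and p2 = 0.1 (the same for source 2), the
    # step log-likelihood is log(g * p1 + (1 - g) * p2) = log(0.7 * 0.2 + 0.3 * 0.1)
    # = log(0.17). The method computes this in log space as
    # logsumexp([log g + log p1, log(1 - g) + log p2]) for numerical stability; the
    # 1e-45 terms above only guard against taking log(0).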

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask_1"].size()[0]
        start_predictions = state["source_mask_1"].new_full(
            (batch_size, ), fill_value=self._start_index)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_search_step)
        return {
            "predicted_log_probs": log_probabilities,
            "predictions": all_top_k_predictions
        }

    def take_search_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        用于beam_search。

        参数:
            - last_predictions:上一时间步的预测结果
                    shape: (group_size,)
            - state:状态
        
        返回:
            - final_log_probs:在全词表上的对数似然
                    shape: (group_size, target_vocab_size)
            - state:更新后的状态

        说明:该函数用于提供给Beam Search使用,输入为上一个时间步的预测id(last_predictions,
              初始为start_index),输出为全词表上的对数似然概率(final_log_probs)。
        
        TODO: 考虑OOV情况(需要整体大改)
        """
        # Advance the decoder state by one step
        state = self._decoder_step(last_predictions, state)

        # Copy probabilities over the first source, shape: (group_size, seq_max_len_1)
        copy_scores_1 = state["attentive_weights_1"]
        # Copy probabilities over the second source, shape: (group_size, seq_max_len_2)
        copy_scores_2 = state["attentive_weights_2"]
        # Gate between the two copy distributions, shape: (group_size,)
        gate_score = state["gate_score"]

        # Compute the log probabilities over the full vocabulary
        final_log_probs = self._gather_final_log_probs(copy_scores_1,
                                                       copy_scores_2,
                                                       gate_score, state)

        return final_log_probs, state

    def _gather_final_log_probs(
            self, copy_scores_1: torch.Tensor, copy_scores_2: torch.Tensor,
            gate_score: torch.Tensor,
            state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        根据三个概率,计算全词表上的对数似然。

        参数:
            - copy_scores_1:第一个source的复制概率(经过归一化)
                    shape: (group_size, seq_max_len_1)
            - copy_scores_2:第二个source的复制概率(经过归一化)
                    shape: (group_size, seq_max_len_2)
            - gate_score:门控的分数,决定source1共享多少比例(source2即贡献1-gate_score)
                    shape: (group_size,)
            - state:当前时间步,更新后的解码状态
        
        返回:
            - final_log_probs:全词表上的概率
                    shape: (group_size, target_vocab_size)
        """
        # Get group_size and the lengths of the two source sequences
        group_size, seq_max_len_1 = copy_scores_1.size()
        group_size, seq_max_len_2 = copy_scores_2.size()

        # TODO: this assumes source and target share the same vocabulary mapping; otherwise
        #       a source-to-target mapping (the index of each source word in the target
        #       vocabulary) would be needed to do the matching
        # shape: (group_size, seq_max_len_1)
        source_token_ids_1 = state["source_token_ids_1"]
        # shape: (group_size, seq_max_len_2)
        source_token_ids_2 = state["source_token_ids_2"]

        # Broadcast gate_score along the sequence dimension
        # Gate probability applied to source1, shape: (group_size, seq_max_len_1)
        gate_1 = gate_score.expand(seq_max_len_1, -1).t()
        # Gate probability applied to source2, shape: (group_size, seq_max_len_2)
        gate_2 = (1 - gate_score).expand(seq_max_len_2, -1).t()

        # Gated source1 scores, shape: (group_size, seq_max_len_1)
        copy_scores_1 = copy_scores_1 * gate_1
        # Gated source2 scores, shape: (group_size, seq_max_len_2)
        copy_scores_2 = copy_scores_2 * gate_2

        # shape: (group_size, seq_max_len_1)
        log_probs_1 = (copy_scores_1 + 1e-45).log()
        # shape: (group_size, seq_max_len_2)
        log_probs_2 = (copy_scores_2 + 1e-45).log()

        # Initialize the full-vocabulary probabilities to (near) zero, shape: (group_size, target_vocab_size)
        final_log_probs = (state["decoder_hidden"].new_zeros(
            (group_size, self._target_vocab_size)) + 1e-45).log()

        for i in range(seq_max_len_1):  # iterate over every position of source1
            # Copy probability at this position, shape: (group_size, 1)
            log_probs_slice = log_probs_1[:, i].unsqueeze(-1)
            # Token id at this position, shape: (group_size, 1)
            source_to_target_slice = source_token_ids_1[:, i].unsqueeze(-1)

            # Existing vocabulary probability at the position to update, shape: (group_size, 1)
            selected_log_probs = final_log_probs.gather(
                -1, source_to_target_slice)
            # Updated value (existing probability mixed with the new one), shape: (group_size, 1)
            combined_scores = util.logsumexp(
                torch.cat((selected_log_probs, log_probs_slice),
                          dim=-1)).unsqueeze(-1)
            # Write combined_scores back into final_log_probs
            final_log_probs = final_log_probs.scatter(-1,
                                                      source_to_target_slice,
                                                      combined_scores)

        # Do the same for source2
        for i in range(seq_max_len_2):
            log_probs_slice = log_probs_2[:, i].unsqueeze(-1)
            source_to_target_slice = source_token_ids_2[:, i].unsqueeze(-1)
            selected_log_probs = final_log_probs.gather(
                -1, source_to_target_slice)
            combined_scores = util.logsumexp(
                torch.cat((selected_log_probs, log_probs_slice),
                          dim=-1)).unsqueeze(-1)
            final_log_probs = final_log_probs.scatter(-1,
                                                      source_to_target_slice,
                                                      combined_scores)

        return final_log_probs
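
    # Worked example: the gather / logsumexp / scatter loop accumulates probability
    # mass for tokens that occur more than once in a source. If token id 7 appears
    # at two source-1 positions with gated copy probabilities 0.10 and 0.05, the
    # vocabulary entry for id 7 ends up holding log(0.10 + 0.05) = log(0.15),
    # because each pass combines the new slice with the value already stored in
    # final_log_probs via logsumexp.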

    def _get_predicted_tokens(
            self,
            predicted_indices: Union[torch.Tensor, numpy.ndarray],
            batch_metadata: List[Any],
            n_best: int = None) -> List[Union[List[List[str]], List[str]]]:
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        predicted_tokens: List[Union[List[List[str]], List[str]]] = []
        for top_k_predictions, metadata in zip(predicted_indices,
                                               batch_metadata):
            batch_predicted_tokens: List[List[str]] = []
            for indices in top_k_predictions[:n_best]:
                tokens: List[str] = []
                indices = list(indices)
                if self._end_index in indices:
                    indices = indices[:indices.index(self._end_index)]
                for index in indices:
                    token = self.vocab.get_token_from_index(
                        index, self._target_namespace)
                    tokens.append(token)
                batch_predicted_tokens.append(tokens)
            if n_best == 1:
                predicted_tokens.append(batch_predicted_tokens[0])
            else:
                predicted_tokens.append(batch_predicted_tokens)
        return predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        将预测结果(tensor)解码成token序列。
        """
        predicted_tokens = self._get_predicted_tokens(
            output_dict["predictions"], output_dict["metadata"])
        output_dict["predicted_tokens"] = predicted_tokens
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(reset=reset))
        return all_metrics
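
# A minimal, self-contained sketch of the BeamSearch contract this model relies on,
# under the assumption that `allennlp.nn.beam_search.BeamSearch` is the class used
# throughout: `search` takes start predictions of shape (batch_size,), a state dict
# whose tensors have batch_size as their first dimension, and a step function that
# returns per-token log probabilities; it returns predictions of shape
# (batch_size, beam_size, num_steps) and log probabilities of shape
# (batch_size, beam_size). The tiny vocabulary and `toy_step` are illustrative only.
def _beam_search_contract_demo():
    import torch
    from allennlp.nn.beam_search import BeamSearch

    vocab_size, eos_index = 5, 4

    def toy_step(last_predictions, state):
        # Deterministically prefer the token after the previous one, so every
        # hypothesis walks 0 -> 1 -> 2 -> 3 -> 4 (EOS) and then stops.
        logits = torch.full((last_predictions.size(0), vocab_size), -10.0)
        favored = ((last_predictions + 1) % vocab_size).unsqueeze(-1)
        logits.scatter_(1, favored, 0.0)
        return torch.log_softmax(logits, dim=-1), state

    beam = BeamSearch(end_index=eos_index, max_steps=10, beam_size=3)
    start_predictions = torch.zeros(2, dtype=torch.long)  # (batch_size,)
    predictions, log_probs = beam.search(start_predictions, {}, toy_step)
    # predictions: (2, 3, num_steps); log_probs: (2, 3); the best beam is index 0.
    return predictions, log_probs
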
Beispiel #20
class WNCGRnnModel(WNCGBaseModel):
    def __init__(self,
                 input_gpv_dim,
                 d_model,
                 num_layers,
                 n_vocab,
                 input_amedas_seqlen,
                 weather,
                 cl,
                 lr=0.001,
                 dropout_p=0.2):
        super().__init__(d_model, weather, cl, dropout_p)
        d_meta_amedas = 64
        self.num_layers = num_layers
        self.d_model = d_model
        """ Encoder """
        # encoder for gpv
        self.gpv_encoder = MLPEncoder(
            input_gpv_dim, d_model,
            dropout_p=dropout_p)  # single-layer MLP for encoding a gpv data
        self.rnn_encoder = nn.GRU(
            d_model, d_model, num_layers=1, bidirectional=True
        )  # single-layer BiGRU for encoding sequence of gpv data
        self.gpv_to_dmodel = nn.Linear(d_model * 2, d_model)
        # encoder for amedas
        self.amedas_to_dmodel = nn.Linear(input_amedas_seqlen, d_meta_amedas)
        # encoder for meta-data
        metaenc = {}
        metaenc["area"] = nn.Embedding(277, d_meta_amedas)
        metaenc["month"] = nn.Embedding(12, d_meta_amedas)
        metaenc["day"] = nn.Embedding(31, d_meta_amedas)
        metaenc["time"] = nn.Embedding(24, d_meta_amedas)
        metaenc["week"] = nn.Embedding(7, d_meta_amedas)
        self.meta_encoders = nn.ModuleDict(metaenc)
        # BiGRU for gpv, Linear for Amedas, Linear for Metadata
        self.input_to_dmodel = nn.Linear(
            (d_model * 2) + (d_meta_amedas * 4) + (d_meta_amedas * 5), d_model)
        self.relu = nn.ReLU()
        """ Decoder """
        # word decoder
        self.token_decoder = TokenAttnDecoderRNN(d_model, d_meta_amedas,
                                                 d_model, n_vocab, num_layers,
                                                 weather, dropout_p)
        # weather label
        self.weather = weather
        # option for content agreement loss
        self.cl = cl
        # make the arguments global
        self.lr = lr
        # save the arguments
        self.save_hyperparameters()

    def encode(self, src_gpv, src_amedas, src_meta, src_comment):
        """ encode 
        """
        _, batch_size = src_comment.size()
        # encode gpv-data
        src_gpv = self.gpv_encoder(src_gpv)
        gpv_output, gpv_hidden = self.rnn_encoder(src_gpv)
        gpv_output = self.gpv_to_dmodel(gpv_output)
        # encode amedas-data
        src_amedas = self.amedas_to_dmodel(src_amedas)
        # encode meta-data
        emb_area = self.meta_encoders["area"](src_meta[0, :])
        emb_month = self.meta_encoders["month"](src_meta[1, :])
        emb_day = self.meta_encoders["day"](src_meta[2, :])
        emb_time = self.meta_encoders["time"](src_meta[3, :])
        emb_week = self.meta_encoders["week"](src_meta[4, :])
        src_meta = torch.stack(
            [emb_area, emb_month, emb_day, emb_time, emb_week], dim=0)

        gpv_hidden = torch.cat([gpv_output[0, :, :], gpv_output[-1, :, :]],
                               dim=1)  # (batch_size, d_model * 2)
        amedas_hidden = src_amedas.transpose(0, 1).reshape(
            batch_size, -1)  # (batch_size, num_amedas_types * d_meta_amedas)
        meta_hidden = src_meta.transpose(0, 1).reshape(
            batch_size, -1)  # (batch_size, num_meta_types * d_meta_amedas)

        # initial state of decoder
        data_h = self.relu(
            self.input_to_dmodel(
                torch.cat([gpv_hidden, amedas_hidden, meta_hidden],
                          dim=1)))  # (batch_size, d_model)
        encoder_hidden = self.reset(data_h)
        return gpv_output, src_amedas, src_meta, encoder_hidden

    def reset(self, hidden_state):
        # initialize hidden states of word decoder
        batch_size = hidden_state.size(0)
        decoder_hidden = torch.zeros(
            (self.num_layers, batch_size, self.d_model),
            dtype=torch.float32).to(self.device)
        nn.init.normal_(decoder_hidden, mean=0, std=0.05)
        decoder_hidden[0, :, :] = hidden_state
        return decoder_hidden

    def forward(self, src_gpv, src_amedas, src_meta, src_comment):
        """[summary]

        Args:
            src_gpv ([type]): [description]
            src_amedas ([type]): [description]
            src_meta ([type]): [description]
            src_comment ([type]): [description]

        Returns:
            [type]: [description]
        """
        """ encode GPV/AMeDAS/Meta"""
        gpv_output, amedas_output, meta_output, encoder_hidden = \
            self.encode(src_gpv, src_amedas, src_meta, src_comment)

        # initialize outputs of weather labels and weather hidden
        ZERO = torch.zeros(1, 1).to(self.device)
        sunny_out, cloudy_out, rain_out, snow_out, weather_hidden = \
            ZERO, ZERO, ZERO, ZERO, None
        """ decode weather labels """
        if self.weather == "label":
            sunny_out, sunny_hidden = self.sunny_decoder(encoder_hidden[0])
            cloudy_out, cloudy_hidden = self.cloudy_decoder(encoder_hidden[0])
            rain_out, rain_hidden = self.rain_decoder(encoder_hidden[0])
            snow_out, snow_hidden = self.snow_decoder(encoder_hidden[0])
            weather_hidden = torch.stack(
                [sunny_hidden, cloudy_hidden, rain_hidden, snow_hidden], dim=0)
        """ decode tokens """
        token_out = []
        tgt_word_embeddings = []
        hidden = encoder_hidden  # initial state of decoder
        for word_input in src_comment:
            output, hidden, word_emb = self.token_decoder(
                word_input, hidden, gpv_output, amedas_output, meta_output,
                weather_hidden)
            token_out.append(output)
            tgt_word_embeddings.append(word_emb)

        token_out = torch.stack(token_out, dim=0)
        tgt_text_embed = torch.cat(tgt_word_embeddings, dim=0)

        return (F.log_softmax(token_out, dim=-1), \
            F.log_softmax(sunny_out, dim=-1), F.log_softmax(cloudy_out, dim=-1), \
            F.log_softmax(rain_out, dim=-1), F.log_softmax(snow_out, dim=-1), \
            weather_hidden, tgt_text_embed)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp,
                       using_lbfgs):
        """[summary]

        Args:
            epoch ([type]): [description]
            batch_idx ([type]): [description]
            optimizer ([type]): [description]
            optimizer_idx ([type]): [description]
            optimizer_closure ([type]): [description]
            on_tpu ([type]): [description]
            using_native_amp ([type]): [description]
            using_lbfgs ([type]): [description]
        """
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        lr_scheduler = ReduceLROnPlateau(optimizer,
                                         'min',
                                         factor=0.5,
                                         patience=1,
                                         min_lr=1e-5,
                                         verbose=True)
        scheduler = {
            'scheduler': lr_scheduler,  # The LR scheduler instance (required)
            'interval': 'epoch',  # The unit of the scheduler's step size
            'frequency': 1,  # The frequency of the scheduler
            'reduce_on_plateau': True,  # For ReduceLROnPlateau scheduler
            'monitor': 'val_loss',  # Metric for ReduceLROnPlateau to monitor
            'strict':
            True,  # Whether to crash the training if `monitor` is not found
            'name': None,  # Custom name for LearningRateMonitor to use
        }
        return [optimizer], [scheduler]

    def greedy_token_decode(self,
                            hidden,
                            gpv_output,
                            amedas_output,
                            meta_output,
                            weather_hidden,
                            token_generation_limit=128):
        _, batch_size, hidden_size = hidden.size()
        decoded_batch = torch.zeros((batch_size, token_generation_limit))
        word_input = torch.tensor([[IDs.BOS.value] for _ in range(batch_size)],
                                  dtype=torch.long).to(self.device)
        for idx in range(token_generation_limit):
            output, hidden, _ = self.token_decoder(word_input, hidden,
                                                   gpv_output, amedas_output,
                                                   meta_output, weather_hidden)
            topv, topi = output.data.topk(
                1)  # [batch_size, vocab_size] get candidates
            decoded_batch[:, idx] = topi.view(-1)
            word_input = topi
        return decoded_batch.detach().tolist()

    @torch.no_grad()
    def beam_token_decode(self,
                          hidden,
                          gpv_output,
                          amedas_output,
                          meta_output,
                          weather_hidden,
                          beam_width=5):
        max_steps = 128  # maximum number of decoding steps to take
        self.beam_search = BeamSearch(end_index=IDs.EOS.value,
                                      max_steps=max_steps,
                                      beam_size=beam_width)
        batch_size = hidden.size(1)

        start_predictions = torch.tensor([IDs.BOS.value] * batch_size,
                                         dtype=torch.long,
                                         device=self.device)
        start_state = {
            "prev_tokens":
            torch.zeros(batch_size, 0, dtype=torch.long, device=self.device),
            "decoder_hidden":
            hidden
        }

        def step(last_tokens, current_state, t):
            """
            Args:
                last_tokens: (group_size,)
                current_state: {}
                t: int
            """
            nonlocal gpv_output
            nonlocal amedas_output
            nonlocal meta_output
            nonlocal weather_hidden
            group_size = last_tokens.size(0)
            # concatenate prev_tokens with last_tokens
            prev_tokens = torch.cat(
                [current_state["prev_tokens"],
                 last_tokens.unsqueeze(1)],
                dim=-1)  # [B*k, t+1]

            # expand context hiddens for beam search decoding
            if group_size != gpv_output.size(1):
                gpv_output = gpv_output.unsqueeze(2)\
                    .expand(gpv_output.size(0), gpv_output.size(1), beam_width, gpv_output.size(-1))\
                    .reshape(gpv_output.size(0), gpv_output.size(1) * beam_width, gpv_output.size(-1))
                amedas_output = amedas_output.unsqueeze(2)\
                    .expand(amedas_output.size(0), amedas_output.size(1), beam_width, amedas_output.size(-1))\
                    .reshape(amedas_output.size(0), amedas_output.size(1) * beam_width, amedas_output.size(-1))
                meta_output = meta_output.unsqueeze(2)\
                    .expand(meta_output.size(0), meta_output.size(1), beam_width, meta_output.size(-1))\
                    .reshape(meta_output.size(0), meta_output.size(1) * beam_width, meta_output.size(-1))
                weather_hidden = weather_hidden.unsqueeze(2)\
                    .expand(weather_hidden.size(0), weather_hidden.size(1), beam_width, weather_hidden.size(-1))\
                    .reshape(weather_hidden.size(0), weather_hidden.size(1) * beam_width, weather_hidden.size(-1)) if weather_hidden is not None else None

            # decode for one step using decoder
            decoder_output, decoder_hidden, _ = self.token_decoder(
                prev_tokens[:, -1], current_state["decoder_hidden"],
                gpv_output, amedas_output, meta_output, weather_hidden)

            current_state["prev_tokens"] = prev_tokens  # update prev_tokens
            current_state[
                "decoder_hidden"] = decoder_hidden  # update decoder_hidden

            return (decoder_output, current_state)

        predictions, log_probs = self.beam_search.search(
            start_predictions=start_predictions,
            start_state=start_state,
            step=step)

        return predictions, log_probs
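
# A small sketch of the expand-to-beam pattern used in the `step` closure above,
# assuming encoder memories of shape (seq_len, batch_size, dim): each example is
# repeated beam_width times along the batch axis so the memory lines up with the
# batch_size * beam_width hypotheses BeamSearch tracks after the first step. The
# function name is illustrative only.
def expand_memory_for_beam(memory, beam_width):
    seq_len, batch_size, dim = memory.size()
    return (memory.unsqueeze(2)                               # (seq, batch, 1, dim)
            .expand(seq_len, batch_size, beam_width, dim)     # (seq, batch, k, dim)
            .reshape(seq_len, batch_size * beam_width, dim))  # (seq, batch * k, dim)

# e.g. expand_memory_for_beam(torch.randn(7, 2, 16), 5).shape == (7, 10, 16)
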
Beispiel #21
class Seq2seqPlmsGenerator(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 pretrained_model_path,
                 beam_size=5,
                 max_decoding_steps=140,
                 indexer=None):
        super().__init__(vocab)
        self.plm = MT5ForConditionalGeneration.from_pretrained(pretrained_model_path)
        self._indexer = indexer or PretrainedTransformerIndexer(pretrained_model_path, namespace="tokens")
        self._start_id = self.plm.config.decoder_start_token_id
        self._decoder_start_id = self.plm.config.decoder_start_token_id
        self._end_id = self.plm.config.eos_token_id
        self._pad_id = self.plm.config.pad_token_id

        self._beam_search = BeamSearch(
            self._end_id, max_steps=max_decoding_steps, beam_size=beam_size or 1
        )
        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

    @overrides
    def forward(self,
                source_tokens,
                target_tokens=None) -> Dict[str, torch.Tensor]:
        inputs = source_tokens
        targets = target_tokens
        input_ids, input_mask = inputs["tokens"]["token_ids"], inputs["tokens"]["mask"]

        outputs = {}

        # If no targets are provided, shift the input right by 1 and use it as the target.
        # The underlying pretrained seq2seq model already does this shift internally,
        # but it does not use the shifted ids for loss calculation.
        if targets is not None:
            target_ids, target_mask = targets["tokens"]["token_ids"], targets["tokens"]["mask"]
        else:
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

        if self.training: # training
            outputs = self.plm(input_ids=input_ids, attention_mask=input_mask,
                               decoder_input_ids=target_ids[:, :-1].contiguous(),
                               decoder_attention_mask=target_mask[:, :-1].contiguous(),
                               use_cache=False, return_dict=True)
            outputs["decoder_logits"] = outputs.logits
            outputs["loss"] = sequence_cross_entropy_with_logits(
                outputs.logits,
                cast(torch.LongTensor, target_ids[:, 1:].contiguous()),
                cast(torch.BoolTensor, target_mask[:, 1:].contiguous()),
                label_smoothing=0.1,
                average="token",
            )
        elif targets is not None: # validation
            outputs = self.plm(input_ids=input_ids, attention_mask=input_mask,
                               decoder_input_ids=target_ids[:, :-1].contiguous(),
                               decoder_attention_mask=target_mask[:, :-1].contiguous(),
                               use_cache=False, return_dict=True)
            outputs["decoder_logits"] = outputs.logits
            outputs["loss"] = sequence_cross_entropy_with_logits(
                outputs.logits,
                cast(torch.LongTensor, target_ids[:, 1:].contiguous()),
                cast(torch.BoolTensor, target_mask[:, 1:].contiguous()),
                label_smoothing=0.1,
            )
            self._rouge(torch.argmax(outputs.logits, -1), target_ids)
            self._bleu(torch.argmax(outputs.logits, -1), target_ids)
        else: #prediction
            # Use decoder start id and start of sentence to start decoder
            initial_decoder_ids = torch.tensor(
                [[self._decoder_start_id]],
                dtype=input_ids.dtype,
                device=input_ids.device,
            ).repeat(input_ids.shape[0], 1)

            initial_state = {
                "input_ids": input_ids,
                "input_mask": input_mask,
            }
            beam_result = self._beam_search.search(
                initial_decoder_ids, initial_state, self.take_step
            )

            predictions = beam_result[0]
            logger.info(beam_result)

            max_pred_indices = (
                beam_result[1].argmax(dim=-1).view(-1, 1, 1).expand(-1, -1, predictions.shape[-1])
            )
            predictions = predictions.gather(dim=1, index=max_pred_indices).squeeze(dim=1)

            self._rouge(predictions, target_ids)
            self._bleu(predictions, target_ids)

            outputs["predictions"] = predictions
            outputs["log_probabilities"] = (
                beam_result[1].gather(dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1)
            )

            self.make_output_human_readable(outputs)

        return outputs

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics.update(self._rouge.get_metric(reset=reset))
            metrics.update(self._bleu.get_metric(reset=reset))
        return metrics

    @staticmethod
    def _decoder_cache_to_dict(decoder_cache: DecoderCacheType) -> Dict[str, torch.Tensor]:
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            # Each layer caches the key and value tensors for its self-attention and cross-attention.
            # Hence the `layer_cache` tuple has 4 elements.
            assert len(layer_cache) == 4
            for tensor_index, tensor in enumerate(layer_cache):
                key = f"decoder_cache_{layer_index}_{tensor_index}"
                cache_dict[key] = tensor
        return cache_dict

    def _dict_to_decoder_cache(self, cache_dict: Dict[str, torch.Tensor]) -> DecoderCacheType:
        decoder_cache = []
        for layer_index in range(self.plm.config.num_layers):
            base_key = f"decoder_cache_{layer_index}_"
            layer_cache = (
                cache_dict[base_key + "0"],
                cache_dict[base_key + "1"],
                cache_dict[base_key + "2"],
                cache_dict[base_key + "3"],
            )
            decoder_cache.append(layer_cache)
        assert decoder_cache
        return tuple(decoder_cache)

    def take_step(
            self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.
        # Parameters
        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions
        step : `int`
            The time step in beam search decoding.
        # Returns
        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)` and
            an updated state dictionary.
        """
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        decoder_cache = None
        decoder_cache_dict = {
            k: state[k].contiguous()
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        }
        if len(decoder_cache_dict) != 0:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        encoder_outputs = (state["encoder_states"],) if "encoder_states" in state else None
        outputs = self.plm(
            input_ids=state["input_ids"] if encoder_outputs is None else None,
            attention_mask=state["input_mask"],
            encoder_outputs=encoder_outputs,
            decoder_input_ids=last_predictions,
            past_key_values=decoder_cache,
            use_cache=True,
            return_dict=True,
        )

        logits = outputs.logits[:, -1, :]
        log_probabilities = F.log_softmax(logits, dim=-1)

        decoder_cache = outputs.past_key_values
        if decoder_cache is not None:
            decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache)
            state.update(decoder_cache_dict)

        state["encoder_states"] = outputs.encoder_last_hidden_state

        return log_probabilities, state

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        # Parameters
        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor should have
            shape `(batch_size, max_sequence_length)`
        # Returns
        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list of lists of
            tokens.
        """
        predictions = output_dict["predictions"]
        predicted_tokens = [None] * predictions.shape[0]
        for i in range(predictions.shape[0]):
            predicted_tokens[i] = self._indexer.indices_to_tokens(
                {"token_ids": predictions[i].tolist()},
                self.vocab,
            )
        output_dict["predicted_tokens"] = predicted_tokens  # type: ignore
        output_dict["predicted_text"] = self._indexer._tokenizer.batch_decode(
            predictions.tolist(), skip_special_tokens=True
        )

        return output_dict
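
# A small sketch of the cache round-trip performed by `_decoder_cache_to_dict` and
# `_dict_to_decoder_cache`, shown on a toy `past_key_values` structure: flattening
# the nested tuples into a Dict[str, Tensor] lets AllenNLP's BeamSearch reorder
# every cached tensor along the beam dimension between steps. Function names here
# are illustrative, not part of the model.
def _flatten_decoder_cache(cache):
    flat = {}
    for layer_index, layer_cache in enumerate(cache):   # 4 tensors per layer
        for tensor_index, tensor in enumerate(layer_cache):
            flat[f"decoder_cache_{layer_index}_{tensor_index}"] = tensor
    return flat

def _unflatten_decoder_cache(flat, num_layers):
    return tuple(
        tuple(flat[f"decoder_cache_{layer_index}_{tensor_index}"]
              for tensor_index in range(4))
        for layer_index in range(num_layers)
    )

# toy = tuple(tuple(torch.zeros(1, 2, 3, 4) for _ in range(4)) for _ in range(2))
# restored = _unflatten_decoder_cache(_flatten_decoder_cache(toy), num_layers=2)
# restored[0][0] is toy[0][0]  # True: the same tensors come back in the same slots
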
Beispiel #22
class MachampSeq2SeqDecoder(Model):
    """
    An autoregressive decoder that can be used for most seq2seq tasks.

    # Parameters

    vocab : `Vocabulary`, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    max_decoding_steps : `int`
        Maximum length of decoded sequences.
    attention : `Attention`, optional (default = `None`)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    target_namespace : `str`, optional (default = `'tokens'`)
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : `int`, optional (default = `'source_embedding_dim'`)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    beam_size : `int`, optional (default = `None`)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : `float`, optional (default = `0.`)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        [Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015](https://arxiv.org/abs/1506.03099).
    use_bleu : `bool`, optional (default = `True`)
        If True, the BLEU metric will be calculated during validation.
    bleu_ngram_weights : `Iterable[float]`, optional (default = `(0.25, 0.25, 0.25, 0.25)`)
        Weights to assign to scores for each ngram size.
    """
    def __init__(
        self,
        task: str,
        vocab: Vocabulary,
        input_dim: int,
        max_decoding_steps: int,
        loss_weight: float = 1.0,
        attention: Attention = None,
        beam_size: int = None,
        target_namespace: str = "target_tokens",
        target_embedding_dim: int = None,
        scheduled_sampling_ratio: float = 0.0,
        use_bleu: bool = True,
        bleu_ngram_weights: Iterable[float] = (0.25, 0.25, 0.25, 0.25),
        target_decoder_layers: int = 1,
        **kwargs,
    ) -> None:

        super().__init__(vocab, **kwargs)

        self.task = task
        self.vocab = vocab
        self.loss_weight = loss_weight
        self._target_namespace = task + '_target_words'
        self._target_decoder_layers = target_decoder_layers
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)
            self._bleu = BLEU(bleu_ngram_weights,
                              exclude_indices={
                                  pad_index, self._end_index, self._start_index
                              })
        else:
            self._bleu = None
        self.metrics = {"bleu": self._bleu}

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        num_classes = self.vocab.get_vocab_size(
            namespace=self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention

        # Default the target embedding size to the encoder (input) dimensionality.
        target_embedding_dim = target_embedding_dim or input_dim
        self._decoder_input_dim = target_embedding_dim

        # Dense embedding of vocab words in the target space.
        self._target_embedder = Embedding(num_embeddings=num_classes,
                                          embedding_dim=target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = input_dim
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        if self._target_decoder_layers > 1:
            self._decoder_cell = LSTM(
                self._decoder_input_dim,
                self._decoder_output_dim,
                self._target_decoder_layers,
            )
        else:
            self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                          self._decoder_output_dim)
        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)

    @overrides
    def forward(
            self,  # type: ignore
            embedded_text: torch.LongTensor,
            source_mask: torch.LongTensor,
            target_tokens: TextFieldTensors = None) -> Dict[str, torch.Tensor]:

        state = {"encoder_outputs": embedded_text, "source_mask": source_mask}
        if target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, target_tokens["tokens"]["tokens"])

        return output_dict

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str, Any]) -> Dict[str, Any]:
        """
        Finalize predictions.
        This method overrides `Model.make_output_human_readable`, which gets called after `Model.forward`, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the `forward` method.
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called `predicted_tokens` to the `output_dict`.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for top_k_predictions in predicted_indices:
            # Beam search gives us the top k results for each source sentence
            # in the batch, and we keep all k of them here.
            if len(top_k_predictions.shape) == 1:
                top_k_predictions = [top_k_predictions]

            batch_predicted_tokens = []
            for indices in top_k_predictions:
                indices = list(indices)
                # Collect indices till the first end_symbol
                if self._end_index in indices:
                    indices = indices[:indices.index(self._end_index)]
                predicted_tokens = [
                    self.vocab.get_token_from_index(
                        x, namespace=self._target_namespace) for x in indices
                ]
                batch_predicted_tokens.append(predicted_tokens)

            all_predicted_tokens.append(batch_predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def take_step(self, last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor],
                  step: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.
        # Parameters
        last_predictions : `torch.Tensor`
            A tensor of shape `(group_size,)`, which gives the indices of the predictions
            during the last time step.
        state : `Dict[str, torch.Tensor]`
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape `(group_size, *)`, where `*` can be any other number
            of dimensions.
        step : `int`
            The time step in beam search decoding.
        # Returns
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of `(log_probabilities, updated_state)`, where `log_probabilities`
            is a tensor of shape `(group_size, num_classes)` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while `updated_state` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.
        # Notes
            We treat the inputs as a batch, even though `group_size` is not necessarily
            equal to `batch_size`, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"],
            state["source_mask"],
            bidirectional=False)
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self._decoder_output_dim)
        if self._target_decoder_layers > 1:
            # shape: (num_layers, batch_size, decoder_output_dim)
            state["decoder_hidden"] = (
                state["decoder_hidden"].unsqueeze(0).repeat(
                    self._target_decoder_layers, 1, 1))

            # shape: (num_layers, batch_size, decoder_output_dim)
            state["decoder_context"] = (
                state["decoder_context"].unsqueeze(0).repeat(
                    self._target_decoder_layers, 1, 1))

        return state

    def _forward_loop(
            self,
            state: Dict[str, torch.Tensor],
            target_tokens: TextFieldTensors = None) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.
        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index,
                                                dtype=torch.long)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []

        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # During training, feed the model's own predictions back in at a
                # rate of `_scheduled_sampling_ratio`; otherwise gold tokens are used below.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(
                input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets,
                                  target_mask) * self.loss_weight
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index, dtype=torch.long)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode the current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.
        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (num_layers, group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (num_layers, group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            if self._target_decoder_layers > 1:
                attended_input = self._prepare_attended_input(
                    decoder_hidden[0], encoder_outputs, source_mask)
            else:
                attended_input = self._prepare_attended_input(
                    decoder_hidden, encoder_outputs, source_mask)
            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        if self._target_decoder_layers > 1:
            # shape: (1, batch_size, decoder_input_dim)
            # The multi-layer LSTM expects input of shape (seq_len, batch_size, input_size),
            # so we add a length-1 sequence dimension.
            decoder_input = decoder_input.unsqueeze(0).contiguous()
            decoder_context = decoder_context.contiguous()
            decoder_hidden = decoder_hidden.contiguous()

            # shape (decoder_hidden): (num_layers, batch_size, decoder_output_dim)
            # shape (decoder_context): (num_layers, batch_size, decoder_output_dim)
            # TODO (epwalsh): remove the autocast(False) once torch's AMP is working for LSTMCells.
            with torch.cuda.amp.autocast(False):
                _, (decoder_hidden, decoder_context) = self._decoder_cell(
                    decoder_input.float(),
                    (decoder_hidden.float(), decoder_context.float()))
        else:
            # shape (decoder_hidden): (batch_size, decoder_output_dim)
            # shape (decoder_context): (batch_size, decoder_output_dim)
            # TODO (epwalsh): remove the autocast(False) once torch's AMP is working for LSTMCells.
            with torch.cuda.amp.autocast(False):
                decoder_hidden, decoder_context = self._decoder_cell(
                    decoder_input.float(),
                    (decoder_hidden.float(), decoder_context.float()))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        if self._target_decoder_layers > 1:
            output_projections = self._output_projection_layer(
                decoder_hidden[-1])
        else:
            output_projections = self._output_projection_layer(decoder_hidden)
        return output_projections, state

    def _prepare_attended_input(
        self,
        decoder_hidden_state: torch.LongTensor = None,
        encoder_outputs: torch.LongTensor = None,
        encoder_outputs_mask: torch.BoolTensor = None,
    ) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(decoder_hidden_state, encoder_outputs,
                                        encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input

    @staticmethod
    def _get_loss(
        logits: torch.LongTensor,
        targets: torch.LongTensor,
        target_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        """
        Compute loss.
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) and computes cross
        entropy loss while taking the mask into account.
        The length of `targets` is expected to be greater than that of `logits` because the
        decoder does not need to compute the output corresponding to the last timestep of
        `targets`. This method aligns the inputs appropriately to compute the loss.
        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        main_metrics: Dict[str, float] = {}
        if self._bleu:  # and not self.training:
            main_metrics = {
                f".run/{self.task}/{metric_name}": metric.get_metric(reset)
                for metric_name, metric in self.metrics.items()
            }
        return {**main_metrics}
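
# --- Hedged sketch (illustrative, not part of the model above) -------------
# A minimal plain-PyTorch demo of the target/logit alignment documented in
# `_get_loss`: the logit produced at step i is scored against the target at
# step i + 1, so targets and mask are simply sliced from position 1 onward.
# The toy tensors below are made up for this example.
import torch

# <S> w1 w2 w3 <E> <P> <P>  ->  7 target positions but only 6 decoding steps.
toy_targets = torch.tensor([[0, 11, 12, 13, 1, 2, 2]])                    # (1, 7)
toy_mask = torch.tensor([[True, True, True, True, True, False, False]])   # (1, 7)
toy_logits = torch.randn(1, 6, 20)                                        # (1, 6, vocab)

relevant_targets = toy_targets[:, 1:].contiguous()   # (1, 6): w1 w2 w3 <E> <P> <P>
relevant_mask = toy_mask[:, 1:].contiguous()         # (1, 6): 1  1  1  1  0  0
assert toy_logits.size(1) == relevant_targets.size(1)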
Beispiel #23
0
class Bart(Model):
    """
    BART model from the paper "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation,
    Translation, and Comprehension" (https://arxiv.org/abs/1910.13461). The Bart model here uses a language
    modeling head and thus can be used for text generation.
    """
    def __init__(
        self,
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
    ):
        """
        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model to use. Available options can be found in
            `transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP`.
        vocab : `Vocabulary`, required
            Vocabulary containing source and target vocabularies.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer to be used for converting decoded sequences of ids to sequences of tokens.
        max_decoding_steps : `int`, optional (default = `140`)
            Number of decoding steps during beam search.
        beam_size : `int`, optional (default = `4`)
            Number of beams to use in beam search. The default is from the BART paper.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Encoder to be used in BART. By default, the original BART encoder is used.
        """
        super().__init__(vocab)
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(
            model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_id,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            assert (encoder.get_input_dim() == encoder.get_output_dim() ==
                    self.bart.config.hidden_size)
            self.bart.model.encoder = _BartEncoderWrapper(
                encoder,
                self.bart.model.encoder.embed_tokens,
                self.bart.model.encoder.embed_positions,
            )

    @overrides
    def forward(
            self,
            source_tokens: TextFieldTensors,
            target_tokens: TextFieldTensors = None) -> Dict[str, torch.Tensor]:
        """
        Performs the forward step of Bart.

        # Parameters

        source_tokens : `TextFieldTensors`, required
            The source tokens for the encoder. We assume they are stored under the `tokens` key.
        target_tokens : `TextFieldTensors`, optional (default = `None`)
            The target tokens for the decoder. We assume they are stored under the `tokens` key. If no target
            tokens are given, the source tokens are shifted to the right by 1.


        # Returns

        `Dict[str, torch.Tensor]`
            During training, this dictionary contains the `decoder_logits` of shape `(batch_size,
            max_target_length, target_vocab_size)` and the `loss`. During inference, it contains `predictions`
            of shape `(batch_size, max_decoding_steps)` and `log_probabilities` of shape `(batch_size,)`.

        """
        inputs = source_tokens
        targets = target_tokens
        input_ids, input_mask = inputs["tokens"]["token_ids"], inputs[
            "tokens"]["mask"]

        outputs = {}

        # If no targets are provided, then shift input to right by 1. Bart already does this internally
        # but it does not use them for loss calculation.
        if targets is not None:
            target_ids, target_mask = targets["tokens"]["token_ids"], targets[
                "tokens"]["mask"]
        else:
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

        if self.training:
            decoder_logits = self.bart(
                input_ids=input_ids,
                attention_mask=input_mask,
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),
                use_cache=False,
            )[0]

            outputs["decoder_logits"] = decoder_logits

            # The BART paper mentions label smoothing of 0.1 for sequence generation tasks
            outputs["loss"] = sequence_cross_entropy_with_logits(
                decoder_logits,
                target_ids[:, 1:].contiguous(),
                target_mask[:, 1:].contiguous(),
                label_smoothing=0.1,
                average="token",
            )
        else:
            # Use decoder start id and start of sentence to start decoder
            initial_decoder_ids = torch.tensor(
                [[self._decoder_start_id, self._start_id]],
                dtype=input_ids.dtype,
                device=input_ids.device,
            ).repeat(input_ids.shape[0], 1)

            initial_state = {
                "input_ids": input_ids,
                "input_mask": input_mask,
                "encoder_states": None,
            }
            beam_result = self._beam_search.search(initial_decoder_ids,
                                                   initial_state,
                                                   self.take_step)

            predictions = beam_result[0]
            max_pred_indices = (beam_result[1].argmax(dim=-1).view(
                -1, 1, 1).expand(-1, -1, predictions.shape[-1]))
            predictions = predictions.gather(
                dim=1, index=max_pred_indices).squeeze(dim=1)

            self._rouge(predictions, target_ids)
            self._bleu(predictions, target_ids)

            outputs["predictions"] = predictions
            outputs["log_probabilities"] = (beam_result[1].gather(
                dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1))

            self.make_output_human_readable(outputs)

        return outputs

    @staticmethod
    def _decoder_cache_to_dict(decoder_cache):
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            for attention_name, attention_cache in layer_cache.items():
                for tensor_name, cache_value in attention_cache.items():
                    key = (layer_index, attention_name, tensor_name)
                    cache_dict[key] = cache_value
        return cache_dict

    @staticmethod
    def _dict_to_decoder_cache(cache_dict):
        decoder_cache = []
        for key, cache_value in cache_dict.items():
            # Split key and extract index and dict keys
            layer_idx, attention_name, tensor_name = key
            # Extend decoder_cache to fit layer_idx + 1 layers
            decoder_cache = decoder_cache + [
                {} for _ in range(layer_idx + 1 - len(decoder_cache))
            ]
            cache = decoder_cache[layer_idx]
            if attention_name not in cache:
                cache[attention_name] = {}
            assert tensor_name not in cache[attention_name]
            cache[attention_name][tensor_name] = cache_value
        return decoder_cache

    def take_step(self, last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor],
                  step: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions
        step : `int`
            The time step in beam search decoding.


        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)` and
            an updated state dictionary.
        """
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        # Only the last predictions are needed for the decoder, but we need to pad the decoder ids
        # to not mess up the positional embeddings in the decoder.
        padding_size = 0
        if step > 0:
            padding_size = step + 1
            padding = torch.full(
                (last_predictions.shape[0], padding_size),
                self._pad_id,
                dtype=last_predictions.dtype,
                device=last_predictions.device,
            )
            last_predictions = torch.cat([padding, last_predictions], dim=-1)

        decoder_cache = None
        decoder_cache_dict = {
            k: (state[k].contiguous() if state[k] is not None else None)
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        }
        if len(decoder_cache_dict) != 0:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        log_probabilities = None
        for i in range(padding_size, last_predictions.shape[1]):
            encoder_outputs = ((state["encoder_states"], ) if
                               state["encoder_states"] is not None else None)
            outputs = self.bart(
                input_ids=state["input_ids"],
                attention_mask=state["input_mask"],
                encoder_outputs=encoder_outputs,
                decoder_input_ids=last_predictions[:, :i + 1],
                past_key_values=decoder_cache,
                use_cache=True,
            )

            decoder_log_probabilities = F.log_softmax(outputs[0][:, 0], dim=-1)

            if log_probabilities is None:
                log_probabilities = decoder_log_probabilities
            else:
                idx = last_predictions[:, i].view(-1, 1)
                log_probabilities = decoder_log_probabilities + log_probabilities.gather(
                    dim=-1, index=idx)

            decoder_cache = outputs[1]

            state["encoder_states"] = outputs[2]

        if decoder_cache is not None:
            decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache)
            state.update(decoder_cache_dict)

        return log_probabilities, state

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """

        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor should have
            shape `(batch_size, max_sequence_length)`

        # Returns

        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list of lists of
            tokens.

        """
        predictions = output_dict["predictions"]
        predicted_tokens = [None] * predictions.shape[0]
        for i in range(predictions.shape[0]):
            predicted_tokens[i] = self._indexer.indices_to_tokens(
                {"token_ids": predictions[i].tolist()}, self.vocab)
        output_dict["predicted_tokens"] = predicted_tokens

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics.update(self._rouge.get_metric(reset=reset))
            metrics.update(self._bleu.get_metric(reset=reset))
        return metrics
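
# --- Hedged sketch (illustrative, not part of the Bart model above) --------
# How the most likely beam is pulled out of `BeamSearch.search` output with
# argmax + gather, mirroring the inference branch of `Bart.forward`. The toy
# tensors stand in for (predictions, log_probabilities) with shapes
# (batch_size, beam_size, length) and (batch_size, beam_size).
import torch

toy_predictions = torch.tensor([[[4, 5, 2], [6, 7, 2]],
                                [[8, 9, 2], [3, 1, 2]]])       # (2, 2, 3)
toy_log_probs = torch.tensor([[-1.0, -0.2], [-0.5, -3.0]])     # (2, 2)

best_beam = (toy_log_probs.argmax(dim=-1)                      # best beam per batch item
             .view(-1, 1, 1)
             .expand(-1, -1, toy_predictions.size(-1)))        # (2, 1, 3)
best_predictions = toy_predictions.gather(dim=1, index=best_beam).squeeze(dim=1)
# best_predictions -> [[6, 7, 2], [8, 9, 2]]
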
class EncoderDecoder(Model):
    def __init__(self,
                 source_embedder: TextFieldEmbedder,
                 target_embedder: TextFieldEmbedder,
                 max_steps: int,
                 encoder: Encoder,
                 decoder: Decoder,
                 hidden_size: int,
                 vocab: Vocabulary,
                 teacher_force_ratio: float,
                 regularizer: RegularizerApplicator = None) -> None:
        super().__init__(vocab, regularizer)
        # TODO: Work on BeamSearch; try to switch to OpenNMT BeamSearch, but implement our own beam search first
        self.max_steps = max_steps
        self.hidden_size = hidden_size
        self.source_embedder = source_embedder
        self.target_embedder = target_embedder
        self.encoder = encoder
        self.decoder = decoder
        self.teacher_force_ratio = teacher_force_ratio
        self.decoder.add_vocab(self.vocab)
        self.padding_idx = self.vocab.get_token_index(DEFAULT_PADDING_TOKEN)
        self.start_idx = self.vocab.get_token_index(START_SYMBOL)
        self.end_idx = self.vocab.get_token_index(END_SYMBOL)
        self.unk_idx = self.vocab.get_token_index(DEFAULT_OOV_TOKEN)
        self.beam = BeamSearch(self.end_idx,
                               max_steps=self.max_steps,
                               beam_size=5)
        self.criterion = CrossEntropyLoss(ignore_index=self.padding_idx)

    # noinspection PyMethodMayBeStatic
    def init_enc_state(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        source_mask = util.get_text_field_mask(source_tokens)
        source_lengths = get_lengths_from_binary_sequence_mask(source_mask)
        state = {
            'source_mask': source_mask,  # (B, L)
            'source_lengths': source_lengths,  # (B,)
            'source_tokens': source_tokens['tokens'],
        }
        return state

    def init_dec_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        states = state['encoder_states']
        batch_size = states.size(0)
        length = states.size(1)
        state['context'] = states.new_zeros((batch_size, 1, self.hidden_size))
        state['dec_state'] = state['hidden']
        state['coverage'] = states.new_zeros(
            (batch_size, length, 1))  # (B, L, 1)
        return state

    def forward(self,
                source_tokens: Dict[str, torch.Tensor],
                source_text: Dict[str, Any],
                source_ids: Dict[str, torch.Tensor],
                target_tokens: Dict[str, torch.Tensor] = None,
                saliency_values: torch.Tensor = None) \
            -> Dict[str, torch.Tensor]:
        """
        The forward function of the encoder-decoder model.

        :param source_ids: The source ids that are unique to the document
        :param source_text: The raw text of the source sequence
        :param saliency_values: The saliency values for the source tokens
        :param source_tokens: Indexes of the source tokens
        :param target_tokens: Indexes of the target tokens
        :return: The loss and predictions of the model
        """
        state = self._encode(source_tokens)
        output_dict = {}

        if target_tokens:
            state = self._decode(source_ids, target_tokens, state)
            output_dict['loss'] = self._compute_loss(target_tokens, state)

        if not self.training and not target_tokens:
            output_dict['predictions'] = self._forward_beam_search(
                state, source_ids)
        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor],
            source_ids: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        state = self.init_dec_state(state)
        state['source_ids'] = source_ids['ids']
        state['max_oov'] = source_ids['max_oov']
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self.start_idx)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self.beam.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        emb = self.target_embedder({'tokens': last_predictions})
        state = self.decoder(emb, state)
        return Softmax(dim=-1)(state['class_logits'].squeeze(1)).log(), state

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Encode the source tokens.

        :param source_tokens: The indexes of the source tokens
        :return: All the encoder states and the final encoder state
        """
        state = self.init_enc_state(source_tokens)
        # (Batch, Seq, Emb Dim)
        embedded_src = self.source_embedder(source_tokens)

        # final_state = (last state, last context)
        states, final_state = self.encoder(embedded_src,
                                           state['source_lengths'])
        state['encoder_states'] = states  # (B, L, Num Direction * D_h)
        state['hidden'] = final_state  # (last state, last context)
        assert state['encoder_states'].size(2) == (2 * self.hidden_size)
        return state

    def _decode(self, source_ids: Dict[str, torch.Tensor],
                target_tokens: Dict[str, torch.Tensor],
                state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Decode the encoder state.

        :param source_ids: The source ids that are unique to the document
        :param target_tokens: The indexes of the target tokens
        :param state: The state dictionary holding the encoder outputs
        :return: The state updated with the decoder outputs, attentions, and coverages
        """
        state = self.init_dec_state(state)
        state['source_ids'] = source_ids['ids']
        state['max_oov'] = source_ids['max_oov']
        all_class_logits = []
        all_coverages = []
        all_attentions = []
        # Teacher Forcing
        if torch.rand(1).item() <= self.teacher_force_ratio:
            embedded_tgt = self.target_embedder(target_tokens)
            for step, emb in enumerate(embedded_tgt.split(1, dim=1)):
                state = self.decoder(emb, state)
                all_class_logits.append(state['class_logits'])
                all_coverages.append(state['coverage'])
                all_attentions.append(state['attention'])
        else:
            tokens = state["encoder_states"].new_full(
                (state["encoder_states"].size(0), ),
                fill_value=self.start_idx,
                dtype=torch.long)
            emb = self.target_embedder({'tokens': tokens})
            for step in range(self.max_steps):
                state = self.decoder(emb, state)
                all_class_logits.append(state['class_logits'])
                all_coverages.append(state['coverage'])
                all_attentions.append(state['attention'])
                # prob_dist = Categorical(Softmax(dim=-1)(all_class_logits[-1]))
                # tokens = prob_dist.sample()
                _, tokens = torch.topk(
                    Softmax(dim=-1)(all_class_logits[-1]), 1)
                tokens[tokens >= self.vocab.get_vocab_size()] = self.unk_idx
                emb = self.target_embedder({'tokens': tokens.squeeze(1)})
        state['all_class_logits'] = torch.cat(all_class_logits, dim=1)
        state['all_coverages'] = torch.cat(all_coverages, dim=1)
        state['all_attentions'] = torch.cat(all_attentions, dim=1)
        state.pop('class_logits', None)
        state.pop('coverage', None)
        state.pop('attention', None)
        return state

    def _compute_loss(self, target_tokens: Dict[str, torch.Tensor],
                      state: Dict[str, torch.Tensor]):
        # (B, L, V) -> (B, V, L) for CrossEntropyLoss
        all_class_logits = state['all_class_logits'].transpose(1,
                                                               2).contiguous()
        attentions = state['all_attentions']
        coverages = state['all_coverages']
        tokens = target_tokens['tokens'][:, 1:]
        batch_size = tokens.size(0)
        dim = all_class_logits.size(2) - 1
        pad_tokens = all_class_logits.new_full((all_class_logits.size(0), dim),
                                               fill_value=self.padding_idx,
                                               dtype=torch.long)
        pad_tokens[:, :tokens.size(1)] = tokens

        # (B, L, 1)
        loss = self.criterion(all_class_logits[:, :, :-1], pad_tokens)
        coverage_loss = torch.min(attentions, coverages).sum() / batch_size
        total_loss = loss + coverage_loss
        return total_loss
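
# --- Hedged sketch (illustrative, not part of the model above) -------------
# A toy-tensor view of the coverage penalty added in `_compute_loss`: at each
# decoding step the element-wise minimum of the current attention and the
# accumulated coverage is summed, discouraging repeated attention to the same
# source positions. Here we assume the usual definition where coverage at
# step t is the sum of attention from all earlier steps; shapes are made up.
import torch

batch_size, num_steps, src_len = 2, 3, 4
toy_attentions = torch.rand(batch_size, num_steps, src_len).softmax(dim=-1)
toy_coverages = torch.cumsum(toy_attentions, dim=1) - toy_attentions

coverage_loss = torch.min(toy_attentions, toy_coverages).sum() / batch_size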
Beispiel #25
0
class BeamSearchTest(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.end_index = transition_probabilities.size()[0] - 1
        self.beam_search = BeamSearch(self.end_index,
                                      max_steps=10,
                                      beam_size=3)

        # This is what the top k should look like for each item in the batch.
        self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5],
                                        [3, 4, 5, 5, 5]])

        # This is what the log probs should look like for each item in the batch.
        self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))

    def _check_results(
        self,
        batch_size: int = 5,
        expected_top_k: np.array = None,
        expected_log_probs: np.array = None,
        beam_search: BeamSearch = None,
        state: Dict[str, torch.Tensor] = None,
        take_step=take_step_with_timestep,
    ) -> None:
        expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
        expected_log_probs = (expected_log_probs if expected_log_probs
                              is not None else self.expected_log_probs)
        state = state or {}

        beam_search = beam_search or self.beam_search
        beam_size = beam_search.beam_size

        initial_predictions = torch.tensor([0] * batch_size)
        top_k, log_probs = beam_search.search(initial_predictions, state,
                                              take_step)  # type: ignore

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
        np.testing.assert_allclose(log_probs[0].numpy(),
                                   expected_log_probs,
                                   rtol=1e-6)

    @pytest.mark.parametrize("step_function",
                             [take_step_with_timestep, take_step_no_timestep])
    def test_search(self, step_function):
        self._check_results(take_step=step_function)

    def test_finished_state(self):
        state = {}
        state["foo"] = torch.tensor([[1, 0, 1], [2, 0, 1], [0, 0, 1],
                                     [1, 1, 1], [0, 0, 0]])
        # shape: (batch_size, 3)

        expected_finished_state = {}
        expected_finished_state["foo"] = np.array([
            [1, 0, 1],
            [1, 0, 1],
            [1, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
        ])
        # shape: (batch_size x beam_size, 3)

        self._check_results(state=state)

        # check finished state.
        for key, array in expected_finished_state.items():
            np.testing.assert_allclose(state[key].numpy(), array)

    def test_diff_shape_state(self):
        state = {}
        state["decoder_hidden"] = torch.tensor([[1, 0, 1], [2, 0,
                                                            1], [0, 0, 1],
                                                [1, 1, 1], [0, 0, 0]])
        state["decoder_hidden"] = state["decoder_hidden"].unsqueeze(0).repeat(
            2, 1, 1)
        # shape: (2, batch_size, 3)

        seq = [
            [1, 0, 1],
            [1, 0, 1],
            [1, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
        ]
        seq = [seq] * 2
        expected_finished_state = {}
        expected_finished_state["decoder_hidden"] = np.array(seq)
        # shape: (2, batch_size x beam_size, 3)

        self._check_results(state=state)

        # check finished state.
        for key, array in expected_finished_state.items():
            np.testing.assert_allclose(state[key].numpy(), array)

    def test_batch_size_of_one(self):
        self._check_results(batch_size=1)

    def test_greedy_search(self):
        beam_search = BeamSearch(self.end_index, beam_size=1)
        expected_top_k = np.array([[1, 2, 3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_single_step(self):
        self.beam_search.max_steps = 1
        expected_top_k = np.array([[1], [2], [3]])
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
        )

    def test_early_stopping(self):
        """
        Checks case where beam search will reach `max_steps` before finding end tokens.
        """
        beam_search = BeamSearch(self.end_index, beam_size=3, max_steps=3)
        expected_top_k = np.array([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_take_short_sequence_step(self):
        """
        Tests to ensure the top-k from the short_sequence_transition_probabilities
        transition matrix matches what is expected.
        """
        self.beam_search.beam_size = 5
        expected_top_k = np.array([[5, 5, 5, 5, 5], [1, 5, 5, 5, 5],
                                   [1, 2, 5, 5, 5], [1, 2, 3, 5, 5],
                                   [1, 2, 3, 4, 5]])
        expected_log_probs = np.log(
            np.array([0.9, 0.09, 0.009, 0.0009, 0.0001]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=take_short_sequence_step,
        )

    def test_min_steps(self):
        """
        Tests to ensure all output sequences are greater than a specified minimum length.
        It uses the `take_short_sequence_step` step function, which favors shorter sequences.
        See `test_take_short_sequence_step`.
        """
        self.beam_search.beam_size = 1

        # An empty sequence is allowed under this step function
        self.beam_search.min_steps = 0
        expected_top_k = np.array([[5]])
        expected_log_probs = np.log(np.array([0.9]))
        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            self._check_results(
                expected_top_k=expected_top_k,
                expected_log_probs=expected_log_probs,
                take_step=take_short_sequence_step,
            )

        self.beam_search.min_steps = 1
        expected_top_k = np.array([[1, 5]])
        expected_log_probs = np.log(np.array([0.09]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=take_short_sequence_step,
        )

        self.beam_search.min_steps = 2
        expected_top_k = np.array([[1, 2, 5]])
        expected_log_probs = np.log(np.array([0.009]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=take_short_sequence_step,
        )

        self.beam_search.beam_size = 3
        self.beam_search.min_steps = 2
        expected_top_k = np.array([[1, 2, 5, 5, 5], [1, 2, 3, 5, 5],
                                   [1, 2, 3, 4, 5]])
        expected_log_probs = np.log(np.array([0.009, 0.0009, 0.0001]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=take_short_sequence_step,
        )

    def test_different_per_node_beam_size(self):
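        # `per_node_beam_size` caps how many candidate successors are considered
        # from each beam at every step; it defaults to `beam_size` when not given.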
        # per_node_beam_size = 1
        beam_search = BeamSearch(self.end_index,
                                 beam_size=3,
                                 per_node_beam_size=1)
        self._check_results(beam_search=beam_search)

        # per_node_beam_size = 2
        beam_search = BeamSearch(self.end_index,
                                 beam_size=3,
                                 per_node_beam_size=2)
        self._check_results(beam_search=beam_search)

    def test_catch_bad_config(self):
        """
        If `per_node_beam_size` (which defaults to `beam_size`) is larger than
        the size of the target vocabulary, `BeamSearch.search` should raise
        a ConfigurationError.
        """
        beam_search = BeamSearch(self.end_index, beam_size=20)
        with pytest.raises(ConfigurationError):
            self._check_results(beam_search=beam_search)

    def test_warn_for_bad_log_probs(self):
        # The only valid next step from the initial predictions is the end index.
        # But with a beam size of 3, the call to `topk` to find the 3 most likely
        # next beams will result in 2 new beams that are invalid, in that they have a probability of 0.
        # The beam search should warn us of this.
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        with pytest.warns(RuntimeWarning,
                          match="Negligible log probabilities"):
            self.beam_search.search(initial_predictions, {},
                                    take_step_no_timestep)

    def test_empty_sequences(self):
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        beam_search = BeamSearch(self.end_index, beam_size=1)
        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = beam_search.search(
                initial_predictions, {}, take_step_with_timestep)
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()

    def test_default_from_params_params(self):
        beam_search = BeamSearch.from_params(
            Params({
                "beam_size": 2,
                "end_index": 7
            }))
        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7

    def test_top_p_search(self):
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep
        p_sampler = TopPSampler(p=0.8)

        top_p, log_probs = BeamSearch(self.end_index,
                                      beam_size=beam_size,
                                      max_steps=10,
                                      sampler=p_sampler).search(
                                          initial_predictions, {}, take_step)

        beam_size = beam_size or 1
        batch_size = 5

        # top_p should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_p.size())[:-1] == [batch_size, beam_size]

        assert ((0 <= top_p) & (top_p <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]

    @pytest.mark.parametrize("p_val", [-1.0, 1.2, 1.1, float("inf")])
    def test_p_val(self, p_val):
        with pytest.raises(ValueError):
            initial_predictions = torch.tensor([0] * 5)
            take_step = take_step_with_timestep
            beam_size = 3
            p_sampler = TopPSampler(p=p_val, with_replacement=True)

            top_k, log_probs = BeamSearch(self.end_index,
                                          beam_size=beam_size,
                                          max_steps=10,
                                          sampler=p_sampler).search(
                                              initial_predictions, {},
                                              take_step)

    def test_top_k_search(self):
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep
        k_sampler = TopKSampler(k=5, with_replacement=True)

        top_k, log_probs = BeamSearch(self.end_index,
                                      beam_size=beam_size,
                                      max_steps=10,
                                      sampler=k_sampler).search(
                                          initial_predictions, {}, take_step)

        beam_size = beam_size or 1
        batch_size = 5

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]

        assert ((0 <= top_k) & (top_k <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]

    @pytest.mark.parametrize("k_val", [-1, 0])
    def test_k_val(self, k_val):
        with pytest.raises(ValueError):
            initial_predictions = torch.tensor([0] * 5)
            take_step = take_step_with_timestep
            beam_size = 3
            k_sampler = TopKSampler(k=k_val, with_replacement=True)

            top_k, log_probs = BeamSearch(self.end_index,
                                          beam_size=beam_size,
                                          max_steps=10,
                                          sampler=k_sampler).search(
                                              initial_predictions, {},
                                              take_step)

    def test_stochastic_beam_search(self):
        initial_predictions = torch.tensor([0] * 5)
        batch_size = 5
        beam_size = 3
        take_step = take_step_with_timestep

        gumbel_sampler = GumbelSampler()

        top_k, log_probs = BeamSearch(self.end_index,
                                      beam_size=beam_size,
                                      max_steps=10,
                                      sampler=gumbel_sampler).search(
                                          initial_predictions, {}, take_step)

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]

        assert ((0 <= top_k) & (top_k <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]

        # Check to make sure that once the end index is predicted, all subsequent
        # tokens are also the end index.
        for batch in top_k:
            for beam in batch:
                reached_end = False
                for token in beam:
                    if token == self.end_index:
                        reached_end = True
                    if reached_end:
                        assert token == self.end_index

    def test_params_sampling(self):
        beam_search = BeamSearch.from_params(
            Params({
                "sampler": {
                    "type": "top-k",
                    "k": 4,
                },
                "beam_size": 2,
                "end_index": 7,
            }))
        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7
        assert beam_search.sampler is not None

    def test_params_p_sampling(self):
        beam_search = BeamSearch.from_params(
            Params({
                "sampler": {
                    "type": "top-p",
                    "p": 0.8,
                },
                "beam_size": 2,
                "end_index": 7,
            }))
        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7
        assert beam_search.sampler is not None

    def test_multinomial_sampler(self):
        sampler = MultinomialSampler(temperature=0.9)

        probabilities, classes, state = sampler.sample_nodes(
            log_probabilities, 3, {"foo": "bar"})

        assert probabilities.size() == classes.size()
        assert classes.size() == (2, 3)
        assert all([x < 4 for x in classes[0]])
        assert all([x > 1 for x in classes[1]])

    def test_top_k_sampler(self):
        sampler = TopKSampler(k=3, temperature=0.9)

        probabilities, classes, state = sampler.sample_nodes(
            log_probabilities, 3, {"foo": "bar"})

        assert probabilities.size() == classes.size()
        assert classes.size() == (2, 3)

        assert all([x > 0 and x < 4 for x in classes[0]])
        assert all([x > 1 and x < 5 for x in classes[1]])

    def test_top_p_sampler(self):
        sampler = TopPSampler(p=0.8, temperature=0.9)

        probabilities, classes, state = sampler.sample_nodes(
            log_probabilities, 3, {"foo": "bar"})

        assert probabilities.size() == classes.size()
        assert classes.size() == (2, 3)

        assert all([x > 0 and x < 4 for x in classes[0]])
        assert all([x > 1 and x < 5 for x in classes[1]])

        # Make sure the filtered classes include the first class that exceeds p
        sampler = TopPSampler(p=0.7, temperature=1.0)

        probabilities, classes, state = sampler.sample_nodes(
            log_probabilities, 2, {"foo": "bar"})

        assert all([x == 2 or x == 3 or x == 1 for x in classes[0]])
        assert all([x == 2 or x == 3 for x in classes[1]])

    def test_gumbel_sampler(self):
        sampler = GumbelSampler()
        num_classes = len(log_probabilities[0])
        sampler_state = sampler.init_state(log_probabilities,
                                           batch_size=2,
                                           num_classes=num_classes)

        log_probs, indices, state = sampler.sample_beams(
            log_probabilities, 3, sampler_state)

        assert log_probs.size() == indices.size()
        assert indices.size() == (2, 3)

        # Make sure the probabilities are sorted.
        _, sorted_indices = log_probs.sort(dim=-1, descending=True)
        assert (sorted_indices == torch.arange(3).unsqueeze(0)).all()

        assert all([x >= 0 and x < 4 for x in indices[0]])
        assert all([x > 1 and x <= 5 for x in indices[1]])

    def test_length_normalized_sequence_log_prob_scorer(self):
        """
        Tests to ensure the sequences are normalized by the correct values. The end token is
        included in the length. The start token is not.
        """
        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer()
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        length_normalization = np.array([5, 4, 3])
        expected_scores = expected_log_probs / length_normalization
        self._check_results(expected_log_probs=expected_scores)

        # Introduce a length penalty
        length_penalty = 2.0
        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
            length_penalty=length_penalty)
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        length_normalization = np.array(
            [5**length_penalty, 4**length_penalty, 3**length_penalty])
        expected_scores = expected_log_probs / length_normalization
        self._check_results(expected_log_probs=expected_scores)

        # Pick a length penalty so extreme that the order of the sequences is reversed
        length_penalty = -2.0
        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
            length_penalty=length_penalty)
        expected_top_k = np.array([[3, 4, 5, 5, 5], [2, 3, 4, 5, 5],
                                   [1, 2, 3, 4, 5]])
        expected_log_probs = np.log(np.array([0.2, 0.3, 0.4]))
        length_normalization = np.array(
            [3**length_penalty, 4**length_penalty, 5**length_penalty])
        expected_scores = expected_log_probs / length_normalization
        self._check_results(expected_top_k=expected_top_k,
                            expected_log_probs=expected_scores)

        # Here, we set the max_steps = 4. This prevents the first sequence from finishing,
        # so its length does not include the end token, whereas the other sequences do.
        length_penalty = 2.0
        self.beam_search.max_steps = 4
        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
            length_penalty=length_penalty)
        expected_top_k = np.array([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 5]])
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        length_normalization = np.array(
            [4**length_penalty, 4**length_penalty, 3**length_penalty])
        expected_scores = expected_log_probs / length_normalization
        self._check_results(expected_top_k=expected_top_k,
                            expected_log_probs=expected_scores)

    def test_repeated_ngram_blocking_constraint_init_state(self):
        ngram_size = 3
        batch_size = 2
        constraint = RepeatedNGramBlockingConstraint(ngram_size)

        state = constraint.init_state(batch_size)
        assert len(state) == batch_size
        for beam_states in state:
            assert len(beam_states) == 1
            beam_state = beam_states[0]
            assert len(beam_state.keys()) == 2
            assert len(beam_state["current_prefix"]) == 0
            assert len(beam_state["seen_ngrams"]) == 0

    def test_repeated_ngram_blocking_constraint_apply(self):
        ngram_size = 3
        batch_size = 2
        beam_size = 2
        num_classes = 10
        constraint = RepeatedNGramBlockingConstraint(ngram_size)

        state = [
            [
                {
                    "current_prefix": [0, 1],
                    "seen_ngrams": {}
                },
                {
                    "current_prefix": [2, 3],
                    "seen_ngrams": {
                        (2, 3): [4]
                    }
                },
            ],
            [
                {
                    "current_prefix": [4, 5],
                    "seen_ngrams": {
                        (8, 9): []
                    }
                },
                {
                    "current_prefix": [6, 7],
                    "seen_ngrams": {
                        (6, 7): [0, 1, 2]
                    }
                },
            ],
        ]
        log_probabilities = torch.rand(batch_size, beam_size, num_classes)
        constraint.apply(state, log_probabilities)
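        # Beam (0, 1) has prefix [2, 3] and has already seen (2, 3) -> [4], so class 4 is
        # blocked. Beam (1, 1) has prefix [6, 7] and has already seen (6, 7) -> [0, 1, 2], so
        # classes 0, 1, and 2 are blocked. Those are the four disallowed locations below.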

        disallowed_locations = torch.nonzero(
            log_probabilities == min_value_of_dtype(
                log_probabilities.dtype)).tolist()
        assert len(disallowed_locations) == 4
        assert [0, 1, 4] in disallowed_locations
        assert [1, 1, 0] in disallowed_locations
        assert [1, 1, 1] in disallowed_locations
        assert [1, 1, 2] in disallowed_locations

    def test_repeated_ngram_blocking_constraint_update_state(self):
        ngram_size = 3
        constraint = RepeatedNGramBlockingConstraint(ngram_size)

        # We will have [2, 3] -> {5, 6} from batch index 0 and [4, 5] -> {0} and [6, 7] -> {3}
        # from batch index 1.
        state = [
            [
                {
                    "current_prefix": [0, 1],
                    "seen_ngrams": {}
                },
                {
                    "current_prefix": [2, 3],
                    "seen_ngrams": {
                        (2, 3): [4]
                    }
                },
            ],
            [
                {
                    "current_prefix": [4, 5],
                    "seen_ngrams": {
                        (8, 9): []
                    }
                },
                {
                    "current_prefix": [6, 7],
                    "seen_ngrams": {
                        (6, 7): [0, 1, 2]
                    }
                },
            ],
        ]
        predictions = torch.LongTensor([[5, 6], [0, 3]])
        backpointers = torch.LongTensor([[1, 1], [0, 1]])

        expected_state = [
            [
                {
                    "current_prefix": [3, 5],
                    "seen_ngrams": {
                        (2, 3): [4, 5]
                    }
                },
                {
                    "current_prefix": [3, 6],
                    "seen_ngrams": {
                        (2, 3): [4, 6]
                    }
                },
            ],
            [
                {
                    "current_prefix": [5, 0],
                    "seen_ngrams": {
                        (8, 9): [],
                        (4, 5): [0]
                    }
                },
                {
                    "current_prefix": [7, 3],
                    "seen_ngrams": {
                        (6, 7): [0, 1, 2, 3]
                    }
                },
            ],
        ]
        updated_state = constraint.update_state(state, predictions,
                                                backpointers)
        assert updated_state == expected_state

    def test_take_repeated_ngram_step(self):
        """
        Tests that the top k produced from the `repeated_ngram_transition_probabilities_0`
        transition matrix is as expected (the same transitions are restated as a matrix in the
        comment after this test). The transitions are:

            - p(1|start) = 1.0
            - p(2|1) = 0.4
            - p(3|1) = 0.6
            - p(end|1) = 1e-9
            - p(3|2) = 1.0
            - p(end|2) = 1e-9
            - p(1|3) = 1.0
            - p(end|3) = 1e-9

        The probabilities don't add up to 1 because of the 1e-9 transitions to end, but that
        doesn't really matter. Each state just needs some transition to the end state with a very
        small probability, so that reaching the end is possible from every state but is never
        selected by beam search without a constraint.

        Below is the beam search tracing for beam size 2. Any sequence below the
        line is not selected by beam search. The number that comes before the sequence
        is the probability of the sequence.

        Step 1
        1.0: [1]

        Step 2
        0.6: [1, 3]
        0.4: [1, 2]
        -----
        1e-9: [1, 2, end]

        Step 3
        0.6: [1, 3, 1]
        0.4: [1, 2, 3]
        -----
        0.6 * 1e-9: [1, 3, end]
        0.4 * 1e-9: [1, 2, end]

        Step 4
        0.4:  [1, 2, 3, 1]
        0.36: [1, 3, 1, 3]
        -----
        0.24:       [1, 3, 1, 2]
        0.6 * 1e-9: [1, 3, 1, end]
        0.4 * 1e-9: [1, 2, 3, end]

        Step 5
        0.36: [1, 3, 1, 3, 1]
        0.24: [1, 2, 3, 1, 3]
        -----
        0.16:        [1, 2, 3, 1, 2]
        0.4 * 1e-9:  [1, 2, 3, 1, end]
        0.36 * 1e-9: [1, 3, 1, 3, end]
        """
        step_function = get_step_function(
            repeated_ngram_transition_probabilities_0)
        self.beam_search.beam_size = 2
        self.beam_search.max_steps = 5
        expected_top_k = np.array([[1, 3, 1, 3, 1], [1, 2, 3, 1, 3]])
        expected_log_probs = np.log(np.array([0.36, 0.24]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )
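
    # For reference, the transitions in the docstring above, restated as a (hypothetical)
    # matrix over tokens {0 = start, 1, 2, 3, 4, 5 = end}, with rows giving the current token
    # and columns the next token:
    #     from 0: [0, 1.0, 0,   0,   0, 0    ]
    #     from 1: [0, 0,   0.4, 0.6, 0, 1e-9 ]
    #     from 2: [0, 0,   0,   1.0, 0, 1e-9 ]
    #     from 3: [0, 1.0, 0,   0,   0, 1e-9 ]
    # The actual fixture `repeated_ngram_transition_probabilities_0` is defined elsewhere in
    # this file; this is only a restatement of the docstring.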

    def test_repeated_ngram_blocking_end_to_end_unigrams(self):
        step_function = get_step_function(
            repeated_ngram_transition_probabilities_0)
        self.beam_search.beam_size = 2

        # Unigrams: On step 3, [1, 3, 1] will be blocked and [1, 3, end] will take its place
        self.beam_search.max_steps = 3
        self.beam_search.constraints = [
            RepeatedNGramBlockingConstraint(ngram_size=1)
        ]
        expected_top_k = np.array([[1, 2, 3], [1, 3, 5]])
        expected_log_probs = np.log(np.array([0.4, 0.6 * 1e-9]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )

        step_function = get_step_function(
            repeated_ngram_transition_probabilities_1)
        self.beam_search.max_steps = 5
        expected_top_k = np.array([[1, 2, 3, 4, 5], [1, 2, 4, 3, 5]])
        expected_log_probs = np.log(
            np.array(
                [0.4 * 0.3 * 0.3 * 0.2 * 0.1, 0.4 * 0.3 * 0.2 * 0.3 * 0.1]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )

    def test_repeated_ngram_blocking_end_to_end_bigrams(self):
        step_function = get_step_function(
            repeated_ngram_transition_probabilities_0)
        self.beam_search.beam_size = 2

        # Bigrams: On step 4, [1, 3, 1, 3] will be blocked and [1, 3, 1, 2] will take its place
        self.beam_search.max_steps = 4
        self.beam_search.constraints = [
            RepeatedNGramBlockingConstraint(ngram_size=2)
        ]
        expected_top_k = np.array([[1, 2, 3, 1], [1, 3, 1, 2]])
        expected_log_probs = np.log(np.array([0.4, 0.24]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )

    def test_repeated_ngram_blocking_end_to_end_trigrams(self):
        step_function = get_step_function(
            repeated_ngram_transition_probabilities_0)
        self.beam_search.beam_size = 2

        # Trigrams: On step 5, [1, 3, 1, 3, 1] will be blocked and [1, 2, 3, 1, 2] will take its place
        self.beam_search.max_steps = 5
        self.beam_search.constraints = [
            RepeatedNGramBlockingConstraint(ngram_size=3)
        ]
        expected_top_k = np.array([[1, 2, 3, 1, 3], [1, 2, 3, 1, 2]])
        expected_log_probs = np.log(np.array([0.24, 0.16]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )

    def test_repeated_ngram_blocking_end_indices(self):
        """
        Ensures that the ngram blocking does not mess up when one sequence is shorter
        than another, which would result in repeated "end" symbols.
        """
        # We block unigrams, but 5 (the end symbol) is repeated and it does not mess
        # up the sequence's probability
        step_function = get_step_function(
            repeated_ngram_transition_probabilities_0)
        self.beam_search.beam_size = 2
        self.beam_search.constraints = [
            RepeatedNGramBlockingConstraint(ngram_size=1)
        ]
        expected_top_k = np.array([[1, 3, 5, 5], [1, 2, 3, 5]])
        expected_log_probs = np.log(np.array([0.6 * 1e-9, 0.4 * 1e-9]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )
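
# A minimal sketch (not part of the original tests, and relying on the imports already used in
# this file) of how the pieces exercised above -- a Sampler, a FinalSequenceScorer, and a
# Constraint -- can be attached to a BeamSearch instance, mirroring the attribute assignments
# made by the tests.
def configure_beam_search_sketch(end_index: int) -> BeamSearch:
    beam_search = BeamSearch(end_index, max_steps=20, beam_size=4)
    # Sample per-node candidates from the 5 most likely tokens instead of taking them greedily.
    beam_search.sampler = TopKSampler(k=5, temperature=1.0)
    # Re-rank finished sequences by their length-normalized log probability.
    beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer()
    # Forbid any trigram from repeating within a predicted sequence.
    beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=3)]
    return beam_search
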
class AssociativeSeq2SeqHiddenDiff(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 target_embedder: TextFieldEmbedder,
                 source_encoder: Seq2VecEncoder,
                 target_encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True) -> None:
        super(AssociativeSeq2SeqHiddenDiff, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._source_encoder = source_encoder
        self._target_encoder = target_encoder

        self._encoder_output_dim = self._target_encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim
        target_embedding_dim = source_embedder.get_output_dim()

        if attention:
            self._attention = attention
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim

        else:
            self._attention = None
            self._decoder_input_dim = (target_embedding_dim +
                                       self._source_encoder.get_output_dim())

        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        self._target_embedder = target_embedder

        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(
        self,  # type: ignore
        ref_source_tokens: Dict[str, torch.LongTensor],
        instance_source_tokens: Dict[str, torch.LongTensor],
        ref_target_tokens: Dict[str, torch.LongTensor],
        instance_target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:

        state = self._encode(ref_target_tokens, ref_source_tokens,
                             instance_source_tokens)
        if instance_target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, instance_target_tokens)
        else:
            output_dict = {}

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if instance_target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, instance_target_tokens["tokens"])

        return output_dict

    @overrides
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
        self, ref_target_tokens: Dict[str, torch.Tensor],
        ref_source_tokens: Dict[str, torch.Tensor],
        instance_source_tokens: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_ref_target = self._target_embedder(ref_target_tokens)
        ref_target_mask = util.get_text_field_mask(ref_target_tokens)
        encoded_ref_target = self._target_encoder(embedded_ref_target,
                                                  ref_target_mask)

        embedded_ref_source = self._source_embedder(ref_source_tokens)
        ref_source_mask = util.get_text_field_mask(ref_source_tokens)
        encoded_ref_source = self._source_encoder(embedded_ref_source,
                                                  ref_source_mask)

        embedded_instance_source = self._source_embedder(
            instance_source_tokens)
        instance_source_mask = util.get_text_field_mask(instance_source_tokens)
        encoded_instance_source = self._source_encoder(
            embedded_instance_source, instance_source_mask)

        instance_ref_diff = encoded_ref_source - encoded_instance_source

        return {
            "source_mask": ref_target_mask,
            "encoder_outputs": encoded_ref_target,
            "instance_ref_diff": instance_ref_diff
        }

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"], state["source_mask"],
            self._target_encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(
        self,
        state: Dict[str, torch.Tensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass during training or do greedy search during prediction.

        Notes
        -----
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
        else:
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),
                                                fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if self.training and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # During training, with probability `_scheduled_sampling_ratio`, feed the model's
                # own predictions from the previous timestep back in instead of the gold tokens
                # (scheduled sampling).
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
            else:
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(
                input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            step_predictions.append(last_predictions.unsqueeze(1))

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits = torch.cat(step_logits, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        """
        Decode current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder._token_embedders['tokens'](
            last_predictions)

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_input = self._prepare_attended_input(
                decoder_hidden, encoder_outputs, source_mask)

            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            # shape: (group_size, target_embedding_dim)
            decoder_input = torch.cat(
                (embedded_input, state['instance_ref_diff']), -1)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context
        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        return output_projections, state

    def _prepare_attended_input(
            self,
            decoder_hidden_state: torch.LongTensor = None,
            encoder_outputs: torch.LongTensor = None,
            encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state."""
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(decoder_hidden_state, encoder_outputs,
                                        encoder_outputs_mask)
        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input

    @staticmethod
    def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))
        return all_metrics
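
# A small, self-contained illustration (not part of the original model, and assuming the
# `torch` and allennlp `util` imports used above) of the target/logit alignment described in
# the `_get_loss` docstring: logits for timestep i are scored against the gold token at
# timestep i + 1, and padding is masked out of the loss.
def shifted_loss_sketch() -> torch.Tensor:
    batch_size, num_decoding_steps, num_classes = 2, 4, 6
    # Decoder logits for steps 0..3 (the decoder inputs were <S>, w1, w2, w3).
    logits = torch.randn(batch_size, num_decoding_steps, num_classes)
    # Full target sequences, including the start token at position 0; the second one is padded.
    targets = torch.LongTensor([[1, 2, 3, 4, 5], [1, 2, 3, 5, 0]])
    target_mask = torch.LongTensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
    # Drop the start token so that logits[:, i] lines up with targets[:, i + 1].
    relevant_targets = targets[:, 1:].contiguous()
    relevant_mask = target_mask[:, 1:].contiguous()
    return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)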
Beispiel #27
0
class BeamSearchTest(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.end_index = transition_probabilities.size()[0] - 1
        self.beam_search = BeamSearch(self.end_index,
                                      max_steps=10,
                                      beam_size=3)

        # This is what the top k should look like for each item in the batch.
        self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5],
                                        [3, 4, 5, 5, 5]])

        # This is what the log probs should look like for each item in the batch.
        self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))

    def _check_results(
        self,
        batch_size: int = 5,
        expected_top_k: np.array = None,
        expected_log_probs: np.array = None,
        beam_search: BeamSearch = None,
        state: Dict[str, torch.Tensor] = None,
        take_step=take_step_with_timestep,
    ) -> None:
        expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
        expected_log_probs = (expected_log_probs if expected_log_probs
                              is not None else self.expected_log_probs)
        state = state or {}

        beam_search = beam_search or self.beam_search
        beam_size = beam_search.beam_size

        initial_predictions = torch.tensor([0] * batch_size)
        top_k, log_probs = beam_search.search(initial_predictions, state,
                                              take_step)  # type: ignore

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
        np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs)

    @pytest.mark.parametrize("step_function",
                             [take_step_with_timestep, take_step_no_timestep])
    def test_search(self, step_function):
        self._check_results(take_step=step_function)

    def test_finished_state(self):
        state = {}
        state["foo"] = torch.tensor([[1, 0, 1], [2, 0, 1], [0, 0, 1],
                                     [1, 1, 1], [0, 0, 0]])
        # shape: (batch_size, 3)

        expected_finished_state = {}
        expected_finished_state["foo"] = np.array([
            [1, 0, 1],
            [1, 0, 1],
            [1, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
        ])
        # shape: (batch_size x beam_size, 3)

        self._check_results(state=state)

        # check finished state.
        for key, array in expected_finished_state.items():
            np.testing.assert_allclose(state[key].numpy(), array)

    def test_diff_shape_state(self):
        state = {}
        state["decoder_hidden"] = torch.tensor([[1, 0, 1], [2, 0,
                                                            1], [0, 0, 1],
                                                [1, 1, 1], [0, 0, 0]])
        state["decoder_hidden"] = state["decoder_hidden"].unsqueeze(0).repeat(
            2, 1, 1)
        # shape: (2, batch_size, 3)

        seq = [
            [1, 0, 1],
            [1, 0, 1],
            [1, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
        ]
        seq = [seq] * 2
        expected_finished_state = {}
        expected_finished_state["decoder_hidden"] = np.array(seq)
        # shape: (2, batch_size x beam_size, 3)

        self._check_results(state=state)

        # check finished state.
        for key, array in expected_finished_state.items():
            np.testing.assert_allclose(state[key].numpy(), array)

    def test_batch_size_of_one(self):
        self._check_results(batch_size=1)

    def test_greedy_search(self):
        beam_search = BeamSearch(self.end_index, beam_size=1)
        expected_top_k = np.array([[1, 2, 3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_early_stopping(self):
        """
        Checks case where beam search will reach `max_steps` before finding end tokens.
        """
        beam_search = BeamSearch(self.end_index, beam_size=3, max_steps=3)
        expected_top_k = np.array([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_different_per_node_beam_size(self):
        # per_node_beam_size = 1
        beam_search = BeamSearch(self.end_index,
                                 beam_size=3,
                                 per_node_beam_size=1)
        self._check_results(beam_search=beam_search)

        # per_node_beam_size = 2
        beam_search = BeamSearch(self.end_index,
                                 beam_size=3,
                                 per_node_beam_size=2)
        self._check_results(beam_search=beam_search)

    def test_catch_bad_config(self):
        """
        If `per_node_beam_size` (which defaults to `beam_size`) is larger than
        the size of the target vocabulary, `BeamSearch.search` should raise
        a ConfigurationError.
        """
        beam_search = BeamSearch(self.end_index, beam_size=20)
        with pytest.raises(ConfigurationError):
            self._check_results(beam_search=beam_search)

    def test_warn_for_bad_log_probs(self):
        # The only valid next step from the initial predictions is the end index.
        # But with a beam size of 3, the call to `topk` to find the 3 most likely
        # next beams will result in 2 new beams that are invalid, in that they have probability of 0.
        # The beam search should warn us of this.
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        with pytest.warns(RuntimeWarning, match="Infinite log probabilities"):
            self.beam_search.search(initial_predictions, {},
                                    take_step_no_timestep)

    def test_empty_sequences(self):
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        beam_search = BeamSearch(self.end_index, beam_size=1)
        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = beam_search.search(
                initial_predictions, {}, take_step_with_timestep)
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()

    def test_top_p_search(self):
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep

        top_p, log_probs = BeamSearch.top_p_sampling(
            self.end_index,
            beam_size=beam_size).search(initial_predictions, {}, take_step)

        beam_size = beam_size or 1
        batch_size = 5

        # top_p should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_p.size())[:-1] == [batch_size, beam_size]

        assert ((0 <= top_p) & (top_p <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]

    def test_top_k_search(self):
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep

        top_k, log_probs = BeamSearch.top_k_sampling(
            self.end_index, k=1,
            beam_size=beam_size).search(initial_predictions, {}, take_step)

        beam_size = beam_size or 1
        batch_size = 5

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]

        assert ((0 <= top_k) & (top_k <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]

    def test_empty_p(self):
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        take_step = take_step_with_timestep

        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = BeamSearch.top_p_sampling(
                self.end_index, beam_size=1).search(initial_predictions, {},
                                                    take_step)
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()

    def test_empty_k(self):
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        take_step = take_step_with_timestep

        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = BeamSearch.top_k_sampling(
                self.end_index, beam_size=1).search(initial_predictions, {},
                                                    take_step)
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()

    @pytest.mark.parametrize(
        "k",
        [-1.0, 1.2, 1.1, "foo", float("inf")],
    )
    def test_k_val(self, k):
        with pytest.raises(ConfigurationError):
            initial_predictions = torch.tensor([0] * 5)
            take_step = take_step_with_timestep
            beam_size = 3
            top_k, log_probs = BeamSearch.top_k_sampling(
                self.end_index, k=k,
                beam_size=beam_size).search(initial_predictions, {}, take_step)

    @pytest.mark.parametrize(
        "p",
        [-1.0, 1.1, 2, "foo", float("inf")],
    )
    def test_p_val(self, p):
        with pytest.raises(ConfigurationError):
            initial_predictions = torch.tensor([0] * 5)
            take_step = take_step_with_timestep
            beam_size = 3
            top_p, log_probs = BeamSearch.top_p_sampling(
                self.end_index, p=p,
                beam_size=beam_size).search(initial_predictions, {}, take_step)

    def test_params_no_sampling(self):
        beam_search = BeamSearch.from_params(
            Params({
                "beam_size": 2,
                "end_index": 7
            }))
        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7
        assert beam_search.sampler is None

    def test_params_k_sampling(self):
        beam_search = BeamSearch.from_params(
            Params({
                "type": "top_k_sampling",
                "beam_size": 2,
                "end_index": 7,
                "k": 5,
            }))
        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7
        assert beam_search.sampler is not None

    def test_params_p_sampling(self):
        beam_search = BeamSearch.from_params(
            Params({
                "type": "top_p_sampling",
                "beam_size": 2,
                "end_index": 7,
                "p": 0.4,
            }))
        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7
        assert beam_search.sampler is not None
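
# The tests above pass step functions such as `take_step_with_timestep`, which are defined
# elsewhere in this file. As a rough, hypothetical sketch (assuming the torch/typing imports
# already present in this file) of the contract `BeamSearch.search` expects: given the last
# predictions of shape (group_size,) and a state dict, return the log probabilities of every
# class for the next step, shape (group_size, num_classes), plus the (possibly updated) state.
def example_take_step(
        last_predictions: torch.Tensor, state: Dict[str, torch.Tensor],
        timestep: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    # A toy transition matrix: row i holds p(next token = j | current token = i);
    # token 2 plays the role of the end token here.
    transitions = torch.tensor([[0.0, 0.9, 0.1],
                                [0.0, 0.1, 0.9],
                                [0.0, 0.0, 1.0]])
    log_probs = torch.log(transitions[last_predictions] + 1e-12)
    return log_probs, state

# Usage sketch:
#     BeamSearch(end_index=2, max_steps=5, beam_size=2).search(
#         torch.zeros(4, dtype=torch.long), {}, example_take_step)
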
class ImageCaptioning(Model):
    def __init__(self, vocab: Vocabulary, max_timesteps: int = 50, encoder_size: int = 14, encoder_dim: int = 512, 
                 embedding_dim: int = 64, attention_dim: int = 64, decoder_dim: int = 64, beam_size: int = 3, teacher_forcing: bool = True) -> None:
        super().__init__(vocab)
        
        self._max_timesteps = max_timesteps
        
        self._vocab_size = self.vocab.get_vocab_size()
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        # POSSIBLE CHANGE LATER
        self._pad_index = self.vocab.get_token_index('@@PADDING@@')
        
        self._encoder_size = encoder_size
        self._encoder_dim = encoder_dim
        self._embedding_dim = embedding_dim
        self._attention_dim = attention_dim
        self._decoder_dim = decoder_dim
        
        self._beam_size = beam_size
        self._teacher_forcing = teacher_forcing

        self._init_h = nn.Linear(self._encoder_dim, self._decoder_dim)
        self._init_c = nn.Linear(self._encoder_dim, self._decoder_dim)
        
        self._resnet = torchvision.models.resnet18()
        modules = list(self._resnet.children())[:-2]
        self._encoder = nn.Sequential(
            *modules,
            nn.AdaptiveAvgPool2d(self._encoder_size)
        )

        self._decoder = ImageCaptioningDecoder(self._vocab_size, self._encoder_dim, self._embedding_dim, self._attention_dim, self._decoder_dim)
        
        self.beam_search = BeamSearch(self._end_index, self._max_timesteps, self._beam_size)
        
        self._bleu = BLEU(exclude_indices={self._start_index, self._end_index, self._pad_index})
        self._exprate = Exprate(self._end_index)

    def _init_hidden(self, encoder: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        mean_encoder = encoder.mean(dim=1)
        
        # Shape: (batch_size, decoder_dim)
        initial_h = self._init_h(mean_encoder)
        # Shape: (batch_size, decoder_dim)
        initial_c = self._init_c(mean_encoder)

        return initial_h, initial_c
    
    def _decode(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Get data from state
        x = state['x']
        h = state['h']
        c = state['c']
        label = state['label']
        mask = state['mask']
        
        # Get actual size of current batch
        local_batch_size = x.shape[0]

        # Sort data to be able to only compute relevant parts of the batch at each timestep
        # Shape: (batch_size)
        lengths = mask.sum(dim=1)
        # Shape: (batch_size) (batch_size)
        sorted_lengths, indices = lengths.sort(dim=0, descending=True)
        # Computing last timestep isn't necessary with labels since last timestep is eos token or pad token 
        timesteps = sorted_lengths[0] - 1

        # Shape: (batch_size, height * width, encoder_dim)
        # Shape: (batch_size, decoder_dim)
        # Shape: (batch_size, decoder_dim)
        # Shape: (batch_size, timesteps)
        # Shape: (batch_size, timesteps)
        x = x[indices]
        h = h[indices]
        c = c[indices]
        label = label[indices]        
        mask = mask[indices]
        
        # Shape: (batch_size, 1)
        predicted_indices = torch.LongTensor([[self._start_index]] * local_batch_size).to(device).view(-1, 1)
        
        # Shape: (batch_size, timesteps, vocab_size)
        predictions = torch.zeros(local_batch_size, timesteps, self._vocab_size, device=device)
        attention_weights = torch.zeros(local_batch_size, timesteps, self._encoder_size * self._encoder_size, device=device)
        
        for t in range(timesteps):
            # Number of sequences still active (i.e. longer than t) at this timestep
            batch_offset = sum([l > t for l in sorted_lengths.tolist()])

            # Only compute data in valid timesteps
            # Shape: (batch_offset, height * width, encoder_dim)
            # Shape: (batch_offset, decoder_dim)
            # Shape: (batch_offset, decoder_dim)
            # Shape: (batch_offset, 1)
            x_t = x[:batch_offset]
            h_t = h[:batch_offset]
            c_t = c[:batch_offset]
            predicted_indices_t = predicted_indices[:batch_offset]
            
            # Decode timestep
            # Shape: (batch_offset, decoder_dim) (batch_offset, decoder_dim) (batch_offset, vocab_size), (batch_offset, height * width, 1)
            h, c, preds, attention_weight = self._decoder(x_t, h_t, c_t, predicted_indices_t)
            
            # Get the indices to feed into the decoder at the next timestep
            if self._teacher_forcing:
                # Teacher forcing: feed the gold token for the next timestep
                # Shape: (batch_offset, 1)
                predicted_indices = label[:batch_offset, t + 1].view(-1, 1)
            else:
                # Otherwise feed the decoder's own most likely prediction
                # Shape: (batch_offset, 1)
                predicted_indices = torch.argmax(preds, dim=1).view(-1, 1)
            
            # Save preds
            predictions[:batch_offset, t, :] = preds
            attention_weights[:batch_offset, t, :] = attention_weight.view(-1, self._encoder_size * self._encoder_size)
            
        # Update state and add logits
        state['x'] = x
        state['h'] = h
        state['c'] = c
        state['label'] = label
        state['mask'] = mask
        state['attention_weights'] = attention_weights
        state['logits'] = predictions
            
        return state
    
    def _beam_search_step(self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # Group_size is batch_size * beam_size except for first decoding timestep where it is batch_size
        # Shape: (group_size, decoder_dim) (group_size, decoder_dim) (group_size, vocab_size)
        h, c, predictions, _ = self._decoder(state['x'], state['h'], state['c'], last_predictions)

        # Update state
        # Shape: (group_size, decoder_dim)
        state['h'] = h
        # Shape: (group_size, decoder_dim)
        state['c'] = c
        
        # Run log_softmax over logit predictions
        # Shape: (group_size, vocab_size)
        log_preds = F.log_softmax(predictions, dim=1)

        return log_preds, state
    
    def _beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Get data from state
        x = state['x']
        h = state['h']
        c = state['c']
        
        # Get actual size of current batch
        local_batch_size = x.shape[0]

        # Beam search wants initial preds of shape: (batch_size)
        # Shape: (batch_size)
        initial_indices = torch.LongTensor([[self._start_index]] * local_batch_size).to(device).view(-1)
        
        state = {'x': x, 'h': h, 'c': c}
        
        # Timesteps returned aren't necessarily max_timesteps
        # Shape: (batch_size, beam_size, timesteps), (batch_size, beam_size)
        predictions, log_probabilities = self.beam_search.search(initial_indices, state, self._beam_search_step)
        
        # Only keep best predictions from beam search
        # Shape: (batch_size, timesteps)
        predictions = predictions[:, 0, :].view(local_batch_size, -1)
        
        return predictions
        
    @overrides
    def forward(self, img: torch.Tensor, label: Dict[str, torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        # Encode the image
        # Shape: (batch_size, encoder_dim, height, width)
        x = self._encoder(img)
        
        # Flatten image
        # Shape: (batch_size, height * width, encoder_dim)
        x = x.view(x.shape[0], -1, x.shape[1])

        state = {'x': x}
        # Compute loss on train and val
        if label is not None:
            # Initialize h and c
            # Shape: (batch_size, decoder_dim)
            state['h'], state['c'] = self._init_hidden(x)
            
            # Convert label dict to tensor since label isn't an input to the model and get mask
            # Shape: (batch_size, timesteps)
            state['mask'] = get_text_field_mask(label).to(device)
            # Shape: (batch_size, timesteps)
            state['label'] = label['tokens']

            # Decode encoded image and get loss on train and val
            state = self._decode(state)

            # Loss shouldn't be computed on start token
            state['mask'] = state['mask'][:, 1:].contiguous()
            state['target'] = state['label'][:, 1:].contiguous()
            
            # Compute cross entropy loss
            state['loss'] = sequence_cross_entropy_with_logits(state['logits'], state['target'], state['mask'])
            # Doubly stochastic attention regularization: encourage the attention weights at each
            # spatial location to sum to roughly 1 across timesteps
            state['loss'] += ((1 - torch.sum(state['attention_weights'], dim=1)) ** 2).mean()

        # Decode encoded image with beam search on val and test
        if not self.training:
            # (Re)initialize h and c
            state['h'], state['c'] = self._init_hidden(x)
            
            # Run beam search
            state['out'] = self._beam_search(state)
            
            # Compute validation scores
            if 'label' in state:
                self._bleu(state['out'], state['target'])
                self._exprate(state['out'], state['target'])
            
        # Set out to logits while training
        else:
            state['out'] = state['logits']
            
        # Create output_dict
        output_dict = {}
        output_dict['out'] = state['logits'] if self.training else state['out']
        
        if 'loss' in state:
            output_dict['loss'] = state['loss']

        return output_dict
    
    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}

        # Return Bleu score if possible
        if not self.training:
            metrics.update(self._bleu.get_metric(reset))
            metrics.update(self._exprate.get_metric(reset))
            
        return metrics
        
    def _trim_predictions(self, predictions: torch.Tensor) -> torch.Tensor:
        for b in range(predictions.shape[0]):
            # Shape: (timesteps)
            predicted_index = predictions[b]
            # Set last predicted index to eos token in case there are no predicted eos tokens
            predicted_index[-1] = self._end_index

            # Get index of first eos token
            # Shape: (timesteps)
            mask = predicted_index == self._end_index
            # Workaround for PyTorch not having an easy way to get the first nonzero index
            eos_token_idx = list(mask.cpu().numpy()).index(1)
            
            # Set prediction at eos token's timestep to eos token
            predictions[b, eos_token_idx] = self._end_index
            # Replace all timesteps after first eos token with pad token
            predictions[b, eos_token_idx + 1:] = self._pad_index

        return predictions

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Trim test preds to first eos token
        # Shape: (batch_size, timesteps)
        output_dict['out'] = self._trim_predictions(output_dict['out'])

        return output_dict
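
# A standalone, illustrative sketch (not part of the original model) of the doubly stochastic
# attention penalty added to the loss in `forward` above: for every spatial location, the
# attention weights summed over timesteps are pushed towards 1, encouraging the decoder to
# attend to the whole image over the course of decoding.
def doubly_stochastic_penalty_sketch() -> torch.Tensor:
    batch_size, timesteps, num_locations = 2, 7, 196  # 196 = 14 x 14 feature map
    attention_weights = torch.softmax(
        torch.randn(batch_size, timesteps, num_locations), dim=-1)
    # Sum over timesteps, compare to 1 at every location, and average the squared gap.
    return ((1 - attention_weights.sum(dim=1)) ** 2).mean()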
Beispiel #29
0
class CopyNet(Model):
    """
    This is an implementation of `CopyNet <https://arxiv.org/pdf/1603.06393>`_.
    CopyNet is a sequence-to-sequence encoder-decoder model with a copying mechanism
    that can copy tokens from the source sentence into the target sentence instead of
    generating all target tokens only from the target vocabulary.

    It is very similar to a typical seq2seq model used in neural machine translation
    tasks, for example, except that in addition to providing a "generation" score at each timestep
    for the tokens in the target vocabulary, it also provides a "copy" score for each
    token that appears in the source sentence. In other words, you can think of CopyNet
    as a seq2seq model with a dynamic target vocabulary that changes based on the tokens
    in the source sentence, allowing it to predict tokens that are out-of-vocabulary (OOV)
    with respect to the actual target vocab.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    attention : ``Attention``, required
        This is used to get a dynamic summary of encoder outputs at each timestep
        when producing the "generation" scores for the target vocab.
    beam_size : ``int``, required
        Beam width to use for beam search prediction.
    max_decoding_steps : ``int``, required
        Maximum sequence length of target predictions.
    target_embedding_dim : ``int``, optional (default = 30)
        The size of the embeddings for the target vocabulary.
    copy_token : ``str``, optional (default = '@COPY@')
        The token used to indicate that a target token was copied from the source.
        If this token is not already in your target vocabulary, it will be added.
    source_namespace : ``str``, optional (default = 'source_tokens')
        The namespace for the source vocabulary.
    target_namespace : ``str``, optional (default = 'target_tokens')
        The namespace for the target vocabulary.
    metric : ``Metric``, optional (default = BLEU)
        A metric to track on a validation set. Note that this metric must accept
        three arguments when called: a batched tensor of predicted token indices, a batched
        tensor of gold token indices, and a set of token indices to exclude when
        calculating n-grams (usually should be the start index, end index, and pad index).
    """

    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 beam_size: int,
                 max_decoding_steps: int,
                 target_embedding_dim: int = 30,
                 copy_token: str = "@COPY@",
                 source_namespace: str = "source_tokens",
                 target_namespace: str = "target_tokens",
                 metric: Metric = BLEU()) -> None:
        super(CopyNet, self).__init__(vocab)
        self._metric = metric
        self._source_namespace = source_namespace
        self._target_namespace = target_namespace
        self._src_start_index = self.vocab.get_token_index(START_SYMBOL, self._source_namespace)
        self._src_end_index = self.vocab.get_token_index(END_SYMBOL, self._source_namespace)
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
        self._oov_index = self.vocab.get_token_index(self.vocab._oov_token, self._target_namespace)  # pylint: disable=protected-access
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
        self._copy_index = self.vocab.get_token_index(copy_token, self._target_namespace)
        if self._copy_index == self._oov_index:
            raise ConfigurationError(f"Special copy token {copy_token} missing from target vocab namespace. "
                                     f"You can ensure this token is added to the target namespace with the "
                                     f"vocabulary parameter 'tokens_to_add'.")

        self._target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)

        # Encoding modules.
        self._source_embedder = source_embedder
        self._encoder = encoder

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        # We arbitrarily set the decoder's input dimension to be the same as the output dimension.
        self.encoder_output_dim = self._encoder.get_output_dim()
        self.decoder_output_dim = self.encoder_output_dim
        self.decoder_input_dim = self.decoder_output_dim

        target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)

        # The decoder input will be a function of the embedding of the previous predicted token,
        # an attended encoder hidden state called the "attentive read", and another
        # weighted sum of the encoder hidden state called the "selective read".
        # While the weights for the attentive read are calculated by an `Attention` module,
        # the weights for the selective read are simply the predicted probabilities
        # corresponding to each token in the source sentence from the previous timestep.
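        # For example (illustrative numbers only): with target_embedding_dim = 30 and
        # encoder_output_dim = 256, the concatenated decoder input below is
        # 30 + 2 * 256 = 542-dimensional before being projected to decoder_input_dim.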
        self._target_embedder = Embedding(target_vocab_size, target_embedding_dim)
        self._attention = attention
        self._input_projection_layer = Linear(
                target_embedding_dim + self.encoder_output_dim * 2,
                self.decoder_input_dim)

        # We then run the projected decoder input through an LSTM cell to produce
        # the next hidden state.
        self._decoder_cell = LSTMCell(self.decoder_input_dim, self.decoder_output_dim)

        # We create a "generation" score for each token in the target vocab
        # with a linear projection of the decoder hidden state.
        self._output_generation_layer = Linear(self.decoder_output_dim, target_vocab_size)

        # We create a "copying" score for each source token by applying a non-linearity
        # (tanh) to a linear projection of the encoded hidden state for that token,
        # and then taking the dot product of the result with the decoder hidden state.
        self._output_copying_layer = Linear(self.encoder_output_dim, self.decoder_output_dim)

        # At prediction time, we'll use a beam search to find the best target sequence.
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

    @overrides
    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                source_to_source: torch.Tensor,
                source_to_target: torch.Tensor,
                metadata: List[Dict[str, Any]],
                target_tokens: Dict[str, torch.LongTensor] = None,
                target_to_source: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Make a forward pass with decoder logic for producing the entire target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of `TextField.as_array()` applied on the source `TextField`. This will be
            passed through a `TextFieldEmbedder` and then through an encoder.
        source_to_source : ``torch.Tensor``, required
            Tensor containing indicators of which source tokens match each other.
            Has shape: `(batch_size, trimmed_source_length, trimmed_source_length)`.
        source_to_target : ``torch.Tensor``, required
            Tensor containing vocab index of each source token with respect to the
            target vocab namespace. Shape: `(batch_size, trimmed_source_length)`.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
            target tokens are also represented as a `TextField`.
        target_to_source : ``torch.Tensor``, optional (default = None)
            A sparse tensor of shape `(batch_size, target_sequence_length, source_sentence_length - 2)` that
            indicates which tokens in the source sentence match each token in the target sequence.
            The last dimension is `source_sentence_length - 2` because we exclude the
            START_SYMBOL and END_SYMBOL in the source sentence (the source sentence is guaranteed
            to contain the START_SYMBOL and END_SYMBOL).

        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens, source_to_source, source_to_target)

        if target_tokens:
            state = self._init_decoder_state(state)
            output_dict = self._forward_loop(target_tokens, target_to_source, state)
        else:
            output_dict = {}

        output_dict["metadata"] = metadata

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if self._metric and target_tokens:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                # shape: (batch_size, target_sequence_length)
                gold_tokens = self._gather_extended_gold_tokens(target_tokens["tokens"], target_to_source)
                self._metric(best_predictions, gold_tokens,
                             (self._pad_index, self._end_index, self._start_index))

        return output_dict

    def _gather_extended_gold_tokens(self,
                                     target_tokens: torch.LongTensor,
                                     target_to_source: torch.Tensor) -> torch.LongTensor:
        """
        Modify the gold target tokens relative to the extended vocabulary.

        For gold targets that are OOV but were copied from the source, the OOV index
        will be changed to the index of the first occurrence in the source sentence,
        offset by the size of the target vocabulary.

        Parameters
        ----------
        target_tokens : ``torch.LongTensor``
            Shape: `(batch_size, target_sequence_length)`.
        target_to_source : ``torch.Tensor``
            Shape: `(batch_size, target_sequence_length, trimmed_source_length)`.

        Returns
        -------
        torch.Tensor
            Modified `target_tokens` with OOV indices replaced by offset index
            of first match in source sentence.
        """
        # Only change indices for tokens that were OOV in target vocab but copied from source.
        # shape: (batch_size, target_sequence_length)
        oov = (target_tokens == self._oov_index)
        # shape: (batch_size, target_sequence_length)
        copied = (target_to_source.sum(-1) > 0)
        # shape: (batch_size, target_sequence_length)
        mask = (oov & copied).long()
        # shape: (batch_size, target_sequence_length)
        _, first_match = target_to_source.max(-1)
        # shape: (batch_size, target_sequence_length)
        new_target_tokens = target_tokens * (1 - mask) + (first_match.long() + self._target_vocab_size) * mask
        return new_target_tokens

    def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Initialize the encoded state to be passed to the first decoding time step.
        """
        batch_size, _ = state["source_mask"].size()

        # Initialize the decoder hidden state with the final output of the encoder,
        # and the decoder context with zeros.
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
                state["encoder_outputs"],
                state["source_mask"],
                self._encoder.is_bidirectional())
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self.decoder_output_dim)

        return state

    def _encode(self,
                source_tokens: Dict[str, torch.Tensor],
                source_to_source: torch.Tensor,
                source_to_target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Encode source input sentences.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)

        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)

        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)

        state = {
                "source_mask": source_mask,
                "encoder_outputs": encoder_outputs,
                "source_to_source": source_to_source,
                "source_to_target": source_to_target,
        }

        return state

    def _decoder_step(self,
                      last_predictions: torch.Tensor,
                      selective_weights: torch.Tensor,
                      state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs_mask = state["source_mask"].float()

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        # shape: (batch_size, max_input_sequence_length)
        attentive_weights = self._attention(
                state["decoder_hidden"], state["encoder_outputs"], encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attentive_read = util.weighted_sum(state["encoder_outputs"], attentive_weights)

        # shape: (batch_size, encoder_output_dim)
        selective_read = util.weighted_sum(state["encoder_outputs"][:, 1:-1], selective_weights)

        # shape: (group_size, target_embedding_dim + encoder_output_dim * 2)
        decoder_input = torch.cat((embedded_input, attentive_read, selective_read), -1)

        # shape: (group_size, decoder_input_dim)
        projected_decoder_input = self._input_projection_layer(decoder_input)

        state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
                projected_decoder_input,
                (state["decoder_hidden"], state["decoder_context"]))

        return state

    def _get_generation_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
        return self._output_generation_layer(state["decoder_hidden"])

    def _get_copy_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
        # shape: (batch_size, max_input_sequence_length - 2, encoder_output_dim)
        trimmed_encoder_outputs = state["encoder_outputs"][:, 1:-1]

        # shape: (batch_size, max_input_sequence_length - 2, decoder_output_dim)
        copy_projection = self._output_copying_layer(trimmed_encoder_outputs)

        # shape: (batch_size, max_input_sequence_length - 2, decoder_output_dim)
        copy_projection = torch.tanh(copy_projection)

        # shape: (batch_size, max_input_sequence_length - 2)
        copy_scores = copy_projection.bmm(state["decoder_hidden"].unsqueeze(-1)).squeeze(-1)

        return copy_scores

    def _get_ll_contrib(self,
                        generation_scores: torch.Tensor,
                        copy_scores: torch.Tensor,
                        target_tokens: torch.Tensor,
                        target_to_source: torch.Tensor,
                        copy_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the log-likelihood contribution from a single timestep.

        Parameters
        ----------
        generation_scores : ``torch.Tensor``
            Shape: `(batch_size, target_vocab_size)`
        copy_scores : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        target_tokens : ``torch.Tensor``
            Shape: `(batch_size,)`
        target_to_source : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        copy_mask : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Shape: `(batch_size,), (batch_size, max_input_sequence_length)`
        """
        _, target_size = generation_scores.size()

        # This mask zeroes out copy scores for source positions that are just padding.
        # We apply it to the concatenation of the generation scores and the copy
        # scores so that the softmax normalizes over the valid entries only.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores.new_full(generation_scores.size(), 1.0), copy_mask), dim=-1)

        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)

        # Normalize generation and copy scores.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        probs = util.masked_softmax(all_scores, mask)

        # Calculate the probability (normalized copy score) for each token in the source sentence
        # that matches the current target token. We end up summing the scores
        # for each occurrence of a matching token to get the actual score, but we also
        # need the un-summed probabilities to create the selective read state
        # during the next time step.
        # shape: (batch_size, trimmed_source_length)
        raw_selective_weights = probs[:, target_size:] * target_to_source.float()
        # shape: (batch_size,)
        sum_selective_weights = raw_selective_weights.sum(-1)
        # shape: (batch_size, trimmed_source_length)
        selective_weights = raw_selective_weights / (sum_selective_weights.unsqueeze(-1) + 1e-13)

        # This mask ensures that each item in the batch has a non-zero generation score
        # for this timestep only when the gold target token is not OOV or there are
        # no matching tokens in the source sentence.
        # shape: (batch_size,)
        gen_mask = ((target_tokens != self._oov_index) | (target_to_source.sum(-1) == 0)).float()

        # Now we get the generation score for the gold target token.
        # shape: (batch_size,)
        step_likelihood = probs.gather(1, target_tokens.unsqueeze(1)).squeeze(-1) * gen_mask

        # ... and add the copy score.
        # shape: (batch_size,)
        step_likelihood = step_likelihood + sum_selective_weights

        # shape: (batch_size,)
        step_log_likelihood = step_likelihood.log()

        return step_log_likelihood, selective_weights

    def _forward_loop(self,
                      target_tokens: Dict[str, torch.LongTensor],
                      target_to_source: torch.Tensor,
                      state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Calculate the loss against gold targets.
        """
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size, target_sequence_length = target_tokens["tokens"].size()

        # The last input from the target is either padding or the end symbol.
        # Either way, we don't have to process it.
        num_decoding_steps = target_sequence_length - 1

        # We use this to fill in the copy index when the previous input was copied.
        # shape: (batch_size,)
        copy_input_choices = source_mask.new_full((batch_size,), fill_value=self._copy_index)

        # shape: (batch_size, trimmed_source_length)
        copy_mask = source_mask[:, 1:-1].float()

        # We need to keep track of the probabilities assigned to tokens in the source
        # sentence that were copied during the previous timestep, since we use
        # those probabilities as weights when calculating the "selective read".
        # shape: (batch_size, trimmed_source_length)
        selective_weights = state["decoder_hidden"].new_zeros(copy_mask.size())

        step_log_likelihoods = []
        for timestep in range(num_decoding_steps):
            # shape: (batch_size,)
            input_choices = target_tokens["tokens"][:, timestep]

            # If the previous target token was copied, we use the special copy token.
            # But the final target token is always the end symbol, so we know
            # it was not copied.
            if timestep < num_decoding_steps - 1:
                # Get mask tensor indicating which instances were copied.
                # shape: (batch_size,)
                copied = (target_to_source[:, timestep, :].sum(-1) > 0).long()

                # shape: (batch_size,)
                input_choices = input_choices * (1 - copied) + copy_input_choices * copied

            # Update the decoder state by taking a step through the RNN.
            state = self._decoder_step(input_choices, selective_weights, state)

            # Get generation scores for each token in the target vocab.
            # shape: (batch_size, target_vocab_size)
            generation_scores = self._get_generation_scores(state)

            # Get copy scores for each token in the source sentence, excluding the start
            # and end tokens.
            # shape: (batch_size, max_input_sequence_length - 2)
            copy_scores = self._get_copy_scores(state)

            # shape: (batch_size,)
            step_target_tokens = target_tokens["tokens"][:, timestep + 1]

            # shape: (batch_size, max_input_sequence_length - 2)
            step_target_to_source = target_to_source[:, timestep + 1]

            step_log_likelihood, selective_weights = self._get_ll_contrib(
                    generation_scores,
                    copy_scores,
                    step_target_tokens,
                    step_target_to_source,
                    copy_mask)
            step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))

        # Gather step log-likelihoods.
        # shape: (batch_size, num_decoding_steps = target_sequence_length - 1)
        log_likelihoods = torch.cat(step_log_likelihoods, 1)

        # Get target mask to exclude likelihood contributions from timesteps after
        # the END token.
        # shape: (batch_size, target_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        # The first timestep is just the START token, which is not included in the likelihoods.
        # shape: (batch_size, num_decoding_steps)
        target_mask = target_mask[:, 1:].float()

        # Sum of step log-likelihoods.
        # shape: (batch_size,)
        log_likelihood = (log_likelihoods * target_mask).sum(dim=-1)

        # The loss is the negative log-likelihood, averaged over the batch.
        loss = - log_likelihood.sum() / batch_size

        return {"loss": loss}

    def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size, source_length = state["source_mask"].size()
        trimmed_source_length = source_length - 2

        # Initialize the copy scores to zero.
        state["copy_probs"] = state["decoder_hidden"].new_zeros((batch_size, trimmed_source_length))

        # shape: (batch_size,)
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
                start_predictions, state, self.take_search_step)

        output_dict = {
                "predicted_log_probs": log_probabilities,
                "predictions": all_top_k_predictions,
        }

        return output_dict

    def _get_input_and_selective_weights(self,
                                         last_predictions: torch.LongTensor,
                                         state: Dict[str, torch.Tensor]) -> Tuple[torch.LongTensor, torch.Tensor]:
        """
        Get input choices for the decoder and the selective copy weights.

        The decoder input choices are simply the `last_predictions`, except for
        target OOV predictions that were copied from source tokens, in which case
        the prediction will be changed to the COPY symbol in the target namespace.

        The selective weights are just the probabilities assigned to source
        tokens that were copied, normalized to sum to 1. If no source tokens were copied,
        the weights will be all zeros.

        Parameters
        ----------
        last_predictions : ``torch.LongTensor``
            Shape: `(group_size,)`
        state : ``Dict[str, torch.Tensor]``

        Returns
        -------
        Tuple[torch.LongTensor, torch.Tensor]
            `input_choices` (shape `(group_size,)`) and `selective_weights`
            (shape `(group_size, trimmed_source_length)`).
        """
        group_size, trimmed_source_length = state["source_to_target"].size()

        # This is a mask indicating which last predictions were copied from the
        # source AND are not in the target vocabulary (OOV).
        # (group_size,)
        only_copied_mask = (last_predictions >= self._target_vocab_size).long()

        # If the last prediction was in the target vocab or OOV but not copied,
        # we use that as input, otherwise we use the COPY token.
        # shape: (group_size,)
        copy_input_choices = only_copied_mask.new_full((group_size,), fill_value=self._copy_index)
        input_choices = last_predictions * (1 - only_copied_mask) + copy_input_choices * only_copied_mask

        # In order to get the `selective_weights`, we need to find out which predictions
        # were copied or copied AND generated, which is the case when a prediction appears
        # in both the source sentence and the target vocab. But whenever a prediction
        # is in the target vocab (even if it also appeared in the source sentence),
        # its index will be the corresponding target vocab index, not its index in
        # the source sentence offset by the target vocab size. So we first
        # use `state["source_to_target"]` to get an indicator of every source token
        # that matches the predicted target token.
        # shape: (group_size, trimmed_source_length)
        expanded_last_predictions = last_predictions.unsqueeze(-1).expand(group_size, trimmed_source_length)
        # shape: (group_size, trimmed_source_length)
        source_copied_and_generated = (state["source_to_target"] == expanded_last_predictions).long()

        # In order to get indicators for copied source tokens that are OOV with respect
        # to the target vocab, we'll make use of `state["source_to_source"]`.
        # First we adjust predictions relative to the start of the source tokens.
        # This makes sense because predictions for copied tokens are given by the index of the copied
        # token in the source sentence, offset by the size of the target vocabulary.
        # shape: (group_size,)
        adjusted_predictions = last_predictions - self._target_vocab_size
        # The adjusted indices for items that were not copied will be negative numbers,
        # and therefore invalid. So we zero them out.
        adjusted_predictions = adjusted_predictions * only_copied_mask
        # shape: (group_size, trimmed_source_length,  trimmed_source_length)
        source_to_source = state["source_to_source"]
        # Expand adjusted_predictions to match source_to_source shape.
        # shape: (group_size, trimmed_source_length, trimmed_source_length)
        adjusted_predictions = adjusted_predictions.unsqueeze(-1)\
            .unsqueeze(-1)\
            .expand(source_to_source.size())
        # The mask will contain indicators for source tokens that were copied
        # during the last timestep.
        # shape: (group_size, trimmed_source_length)
        source_only_copied = source_to_source.gather(-1, adjusted_predictions)[:, :, 0].long()
        # Since we zeroed out indices for predictions that were not copied,
        # we need to zero out all entries of this mask corresponding to those predictions.
        source_only_copied = source_only_copied * only_copied_mask.\
            unsqueeze(-1).\
            expand(source_only_copied.size())

        # shape: (group_size, trimmed_source_length)
        mask = source_only_copied | source_copied_and_generated

        # shape: (group_size, trimmed_source_length)
        raw_selective_weights = state["copy_probs"] * mask.float()

        # shape: (group_size, trimmed_source_length)
        selective_weights = raw_selective_weights / (raw_selective_weights.sum(dim=-1, keepdim=True) + 1e-13)

        return input_choices, selective_weights

    def _gather_final_probs(self,
                            generation_probs: torch.Tensor,
                            copy_probs: torch.Tensor,
                            state: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Combine copy probabilities with generation probabilities for matching tokens.

        Parameters
        ----------
        generation_probs : ``torch.Tensor``
            Shape: `(group_size, target_vocab_size)`
        copy_probs : ``torch.Tensor``
            Shape: `(group_size, trimmed_source_length)`
        state : ``Dict[str, torch.Tensor]``

        Returns
        -------
        torch.Tensor
            Shape: `(group_size, target_vocab_size + trimmed_source_length)`.
        """
        _, trimmed_source_length = state["source_to_target"].size()

        # list of tensors; the first has shape (group_size, target_vocab_size)
        modified_probs_list: List[torch.Tensor] = [generation_probs]
        for i in range(trimmed_source_length):
            # shape: (group_size,)
            copy_probs_slice = copy_probs[:, i]

            # `source_to_target` is a matrix of shape (group_size, trimmed_source_length)
            # where element (i, j) is the vocab index of the target token that matches the jth
            # source token in the ith group, if there is one, or the index of the OOV symbol otherwise.
            # We'll use this to add copy scores to corresponding generation scores.
            # shape: (group_size,)
            source_to_target_slice = state["source_to_target"][:, i]

            # The OOV index in the source_to_target_slice indicates that the source
            # token is not in the target vocab, so we don't want to add that copy score
            # to the OOV token.
            copy_probs_to_add_mask = (source_to_target_slice != self._oov_index).float()
            copy_probs_to_add = copy_probs_slice * copy_probs_to_add_mask
            generation_probs.scatter_add_(
                    -1, source_to_target_slice.unsqueeze(-1), copy_probs_to_add.unsqueeze(-1))

            # We have to combine copy scores for duplicate source tokens so that
            # we can find the overall most likely source token. So, if this is the first
            # occurrence of this particular source token, we add the probs from all other
            # occurrences, otherwise we zero it out since it was already accounted for.
            if i < (trimmed_source_length - 1):
                # Sum copy scores from future occurrences of the source token.
                # shape: (group_size, trimmed_source_length - i)
                source_future_occurences = state["source_to_source"][:, i, (i+1):]
                # shape: (group_size, trimmed_source_length - i)
                future_copy_probs = copy_probs[:, (i+1):] * source_future_occurences
                # shape: (group_size,)
                summed_future_copy_probs = future_copy_probs.sum(dim=-1)
                copy_probs_slice = copy_probs_slice + summed_future_copy_probs
            if i > 0:
                # Zero-out copy probs that we have already accounted for.
                # shape: (group_size, i)
                source_previous_occurences = state["source_to_source"][:, i, 0:i]
                # shape: (group_size,)
                duplicate_mask = (source_previous_occurences.sum(dim=-1) == 0).float()
                copy_probs_slice = copy_probs_slice * duplicate_mask

            # Finally, we zero-out copy scores that we added to the generation scores
            # above so that we don't double-count them.
            # shape: (group_size,)
            left_over_copy_probs = copy_probs_slice * (1.0 - copy_probs_to_add_mask)

            modified_probs_list.append(left_over_copy_probs.unsqueeze(-1))

        # shape: (group_size, target_vocab_size + trimmed_source_length)
        modified_probs = torch.cat(modified_probs_list, dim=-1)

        return modified_probs

    def take_search_step(self,
                         last_predictions: torch.Tensor,
                         state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.

        This function is what gets passed to the `BeamSearch.search` method. It takes
        predictions from the last timestep and the current state and outputs
        the log probabilities assigned to tokens for the next timestep, as well as the updated
        state.

        Since we are predicting tokens out of the extended vocab (target vocab + all unique
        tokens from the source sentence), this is a little more complicated than just
        making a forward pass through the model. The output log probs will have
        shape `(group_size, target_vocab_size + trimmed_source_length)` so that each
        token in the target vocab and source sentence is assigned a probability.

        Note that copy scores are assigned to each source token based on their position, not unique value.
        So if a token appears more than once in the source sentence, it will have more than one score.
        Further, if a source token is also part of the target vocab, its final score
        will be the sum of the generation and copy scores. Therefore, in order to
        get the score for all tokens in the extended vocab at this step,
        we have to combine copy scores for recurring source tokens and potentially
        add them to the generation scores for the matching token in the target vocab, if
        there is one.

        So we can break down the final log probs output as the concatenation of two
        matrices, A: `(group_size, target_vocab_size)`, and B: `(group_size, trimmed_source_length)`.
        Matrix A contains the sum of the generation score and copy scores (possibly 0)
        for each target token. Matrix B contains left-over copy scores for source tokens
        that do NOT appear in the target vocab, with zeros everywhere else. But since
        a source token may appear more than once in the source sentence, we also have to
        sum the scores for each appearance of each unique source token. So matrix B
        actually only has non-zero values at the first occurrence of each source token
        that is not in the target vocab.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            Shape: `(group_size,)`

        state : ``Dict[str, torch.Tensor]``
            Contains all state tensors necessary to produce generation and copy scores
            for next step.

        Notes
        -----
        `group_size` != `batch_size`. In fact, `group_size` = `batch_size * beam_size`.
        """
        _, trimmed_source_length = state["source_to_target"].size()

        # Get input to the decoder RNN and the selective weights. `input_choices`
        # is the result of replacing target OOV tokens in `last_predictions` with the
        # copy symbol. `selective_weights` consist of the normalized copy probabilities
        # assigned to the source tokens that were copied. If no tokens were copied,
        # the weights will be all zeros.
        # shape: (group_size,), (group_size, trimmed_source_length)
        input_choices, selective_weights = self._get_input_and_selective_weights(last_predictions, state)

        # Update the decoder state by taking a step through the RNN.
        state = self._decoder_step(input_choices, selective_weights, state)

        # Get the un-normalized generation scores for each token in the target vocab.
        # shape: (group_size, target_vocab_size)
        generation_scores = self._get_generation_scores(state)

        # Get the un-normalized copy scores for each token in the source sentence,
        # excluding the start and end tokens.
        # shape: (group_size, trimmed_source_length)
        copy_scores = self._get_copy_scores(state)

        # Concat un-normalized generation and copy scores.
        # shape: (group_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)

        # shape: (group_size, trimmed_source_length)
        copy_mask = state["source_mask"][:, 1:-1].float()

        # shape: (group_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores.new_full(generation_scores.size(), 1.0), copy_mask), dim=-1)

        # Normalize generation and copy scores.
        # shape: (group_size, target_vocab_size + trimmed_source_length)
        probs = util.masked_softmax(all_scores, mask)

        # shape: (group_size, target_vocab_size), (group_size, trimmed_source_length)
        generation_probs, copy_probs = probs.split([self._target_vocab_size, trimmed_source_length], dim=-1)

        # Update copy_probs needed for getting the `selective_weights` at the next timestep.
        state["copy_probs"] = copy_probs

        # We now have normalized generation and copy scores, but to produce the final
        # score for each token in the extended vocab, we have to go through and add
        # the copy scores to the generation scores of matching target tokens, and sum
        # the copy scores of duplicate source tokens.
        # shape: (group_size, target_vocab_size + trimmed_source_length)
        final_probs = self._gather_final_probs(generation_probs, copy_probs, state)

        return final_probs.log(), state

    def _get_predicted_tokens(self,
                              predicted_indices: numpy.ndarray,
                              batch_metadata: List[Any],
                              n_best: int = None) -> List[Union[List[List[str]], List[str]]]:
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        predicted_tokens: List[Union[List[List[str]], List[str]]] = []
        for top_k_predictions, metadata in zip(predicted_indices, batch_metadata):
            batch_predicted_tokens: List[List[str]] = []
            for indices in top_k_predictions[:n_best]:
                tokens: List[str] = []
                indices = list(indices)
                if self._end_index in indices:
                    indices = indices[:indices.index(self._end_index)]
                for index in indices:
                    if index >= self._target_vocab_size:
                        adjusted_index = index - self._target_vocab_size
                        token = metadata["source_tokens"][adjusted_index]
                    else:
                        token = self.vocab.get_token_from_index(index, self._target_namespace)
                    tokens.append(token)
                batch_predicted_tokens.append(tokens)
            if n_best == 1:
                predicted_tokens.append(batch_predicted_tokens[0])
            else:
                predicted_tokens.append(batch_predicted_tokens)
        return predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Finalize predictions.

        After a beam search, the predicted indices correspond to tokens in the target vocabulary
        OR tokens in source sentence. Here we gather the actual tokens corresponding to
        the indices.
        """
        predicted_tokens = self._get_predicted_tokens(output_dict["predictions"],
                                                      output_dict["metadata"])
        output_dict["predicted_tokens"] = predicted_tokens
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if self._metric and not self.training:
            all_metrics.update(self._metric.get_metric(reset=reset))
        return all_metrics
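
# A small self-contained illustration (toy vocab and sentence, not taken from the
# class above) of the extended-vocabulary indexing that `decode` and
# `_get_predicted_tokens` rely on: an index below the target vocab size refers to a
# generated target-vocab token, while target_vocab_size + j refers to the j-th token
# of the trimmed source sentence that was copied.
target_vocab = ["@@PADDING@@", "@start@", "@end@", "the", "cat"]
target_vocab_size = len(target_vocab)
source_tokens = ["the", "zyzzyva", "sat"]  # "zyzzyva" and "sat" are OOV


def token_for_extended_index(index: int) -> str:
    if index >= target_vocab_size:
        # Copied token: shift the index back into the source sentence.
        return source_tokens[index - target_vocab_size]
    # Generated token: ordinary target-vocab lookup.
    return target_vocab[index]


assert token_for_extended_index(3) == "the"                           # generated
assert token_for_extended_index(target_vocab_size + 1) == "zyzzyva"   # copied
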
Beispiel #30
0
class BeamSearchTest(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.end_index = transition_probabilities.size()[0] - 1
        self.beam_search = BeamSearch(self.end_index,
                                      max_steps=10,
                                      beam_size=3)

        # This is what the top k should look like for each item in the batch.
        self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5],
                                        [3, 4, 5, 5, 5]])

        # This is what the log probs should look like for each item in the batch.
        self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))

    def _check_results(
        self,
        batch_size: int = 5,
        expected_top_k: np.array = None,
        expected_log_probs: np.array = None,
        beam_search: BeamSearch = None,
        state: Dict[str, torch.Tensor] = None,
    ) -> None:
        expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
        expected_log_probs = (expected_log_probs if expected_log_probs
                              is not None else self.expected_log_probs)
        state = state or {}

        beam_search = beam_search or self.beam_search
        beam_size = beam_search.beam_size

        initial_predictions = torch.tensor([0] * batch_size)
        top_k, log_probs = beam_search.search(initial_predictions, state,
                                              take_step)  # type: ignore

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
        np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs)

    def test_search(self):
        self._check_results()

    def test_finished_state(self):
        state = {}
        state["foo"] = torch.tensor([[1, 0, 1], [2, 0, 1], [0, 0, 1],
                                     [1, 1, 1], [0, 0, 0]])
        # shape: (batch_size, 3)

        expected_finished_state = {}
        expected_finished_state["foo"] = np.array([
            [1, 0, 1],
            [1, 0, 1],
            [1, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
        ])
        # shape: (batch_size x beam_size, 3)

        self._check_results(state=state)

        # check finished state.
        for key, array in expected_finished_state.items():
            np.testing.assert_allclose(state[key].numpy(), array)

    def test_batch_size_of_one(self):
        self._check_results(batch_size=1)

    def test_greedy_search(self):
        beam_search = BeamSearch(self.end_index, beam_size=1)
        expected_top_k = np.array([[1, 2, 3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_early_stopping(self):
        """
        Checks case where beam search will reach `max_steps` before finding end tokens.
        """
        beam_search = BeamSearch(self.end_index, beam_size=3, max_steps=3)
        expected_top_k = np.array([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_different_per_node_beam_size(self):
        # per_node_beam_size = 1
        beam_search = BeamSearch(self.end_index,
                                 beam_size=3,
                                 per_node_beam_size=1)
        self._check_results(beam_search=beam_search)

        # per_node_beam_size = 2
        beam_search = BeamSearch(self.end_index,
                                 beam_size=3,
                                 per_node_beam_size=2)
        self._check_results(beam_search=beam_search)

    def test_catch_bad_config(self):
        """
        If `per_node_beam_size` (which defaults to `beam_size`) is larger than
        the size of the target vocabulary, `BeamSearch.search` should raise
        a ConfigurationError.
        """
        beam_search = BeamSearch(self.end_index, beam_size=20)
        with pytest.raises(ConfigurationError):
            self._check_results(beam_search=beam_search)

    def test_warn_for_bad_log_probs(self):
        # The only valid next step from the initial predictions is the end index.
        # But with a beam size of 3, the call to `topk` to find the 3 most likely
        # next beams will result in 2 new beams that are invalid, in that they have a probability of 0.
        # The beam search should warn us of this.
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        with pytest.warns(RuntimeWarning, match="Infinite log probabilities"):
            self.beam_search.search(initial_predictions, {}, take_step)

    def test_empty_sequences(self):
        initial_predictions = torch.LongTensor(
            [self.end_index - 1, self.end_index - 1])
        beam_search = BeamSearch(self.end_index, beam_size=1)
        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = beam_search.search(
                initial_predictions, {}, take_step)
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()
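
# The tests above rely on a module-level `transition_probabilities` tensor and a
# `take_step` function that are not shown in this snippet. A minimal sketch of what
# they could look like is given below, following the BeamSearch step-function
# contract (last_predictions, state) -> (log_probs, new_state); the exact values in
# the original test fixture may differ.
from typing import Dict, Tuple

import torch

# Row i is the distribution over the next token when the last token was i;
# token 5 plays the role of the end index.
transition_probabilities = torch.tensor(
    [[0.0, 0.4, 0.3, 0.2, 0.1, 0.0],   # start (0) -> favours 1, then 2, ...
     [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],   # 1 -> 2
     [0.0, 0.0, 0.0, 1.0, 0.0, 0.0],   # 2 -> 3
     [0.0, 0.0, 0.0, 0.0, 1.0, 0.0],   # 3 -> 4
     [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],   # 4 -> 5 (end)
     [0.0, 0.0, 0.0, 0.0, 0.0, 1.0]])  # 5 -> 5 (stay at end)


def take_step(last_predictions: torch.Tensor,
              state: Dict[str, torch.Tensor]
              ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    # Look up the next-token distribution for each sequence in the group and
    # return log probabilities; the state is passed through unchanged.
    log_probs = torch.log(transition_probabilities[last_predictions])
    return log_probs, state
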
Beispiel #31
0
class CopyNetSeq2Seq(Model):
    def __init__(
        self,
        vocab: Vocabulary,
        attention: Attention,
        beam_size: int,
        max_decoding_steps: int,
        target_embedding_dim: int = 30,
        copy_token: str = "@COPY@",
        source_namespace: str = "bert",
        target_namespace: str = "target_tokens",
        tensor_based_metric: Metric = None,
        token_based_metric: Metric = None,
        initializer: InitializerApplicator = InitializerApplicator(),
    ) -> None:
        super().__init__(vocab)
        self._source_namespace = source_namespace
        self._target_namespace = target_namespace
        self._src_start_index = self.vocab.get_token_index(
            START_SYMBOL, self._source_namespace)
        self._src_end_index = self.vocab.get_token_index(
            END_SYMBOL, self._source_namespace)
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._oov_index = self.vocab.get_token_index(self.vocab._oov_token,
                                                     self._target_namespace)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)
        self._copy_index = self.vocab.add_token_to_namespace(
            copy_token, self._target_namespace)

        self._tensor_based_metric = tensor_based_metric or BLEU(
            exclude_indices={
                self._pad_index, self._end_index, self._start_index
            })
        self._token_based_metric = token_based_metric

        self._target_vocab_size = self.vocab.get_vocab_size(
            self._target_namespace)

        # Encoding modules.
        bert_token_embedding = PretrainedBertEmbedder('bert-base-uncased',
                                                      requires_grad=True)

        self._source_embedder = bert_token_embedding
        self._encoder = PassThroughEncoder(
            input_dim=self._source_embedder.get_output_dim())

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        # We arbitrarily set the decoder's input dimension to be the same as the output dimension.
        self.encoder_output_dim = self._encoder.get_output_dim()
        self.decoder_output_dim = self.encoder_output_dim
        self.decoder_input_dim = self.decoder_output_dim

        target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)

        # The decoder input will be a function of the embedding of the previous predicted token,
        # an attended encoder hidden state called the "attentive read", and another
        # weighted sum of the encoder hidden state called the "selective read".
        # While the weights for the attentive read are calculated by an `Attention` module,
        # the weights for the selective read are simply the predicted probabilities
        # corresponding to each token in the source sentence that matches the target
        # token from the previous timestep.
        self._target_embedder = Embedding(target_vocab_size,
                                          target_embedding_dim)
        self._attention = attention
        self._input_projection_layer = Linear(
            target_embedding_dim + self.encoder_output_dim * 2,
            self.decoder_input_dim)

        # We then run the projected decoder input through an LSTM cell to produce
        # the next hidden state.
        self._decoder_cell = LSTMCell(self.decoder_input_dim,
                                      self.decoder_output_dim)

        # We create a "generation" score for each token in the target vocab
        # with a linear projection of the decoder hidden state.
        self._output_generation_layer = Linear(self.decoder_output_dim,
                                               target_vocab_size)

        # We create a "copying" score for each source token by applying a non-linearity
        # (tanh) to a linear projection of the encoded hidden state for that token,
        # and then taking the dot product of the result with the decoder hidden state.
        self._output_copying_layer = Linear(self.encoder_output_dim,
                                            self.decoder_output_dim)

        # At prediction time, we'll use a beam search to find the best target sequence.
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        initializer(self)

    @overrides
    def forward(
        self,  # type: ignore
        source_tokens: Dict[str, torch.LongTensor],
        source_token_ids: torch.Tensor,
        source_to_target: torch.Tensor,
        metadata: List[Dict[str, Any]],
        target_tokens: Dict[str, torch.LongTensor] = None,
        target_token_ids: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Make a forward pass with decoder logic for producing the entire target sequence.
        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of `TextField.as_array()` applied on the source `TextField`. This will be
            passed through a `TextFieldEmbedder` and then through an encoder.
        source_token_ids : ``torch.Tensor``, required
            Tensor containing IDs that indicate which source tokens match each other.
            Has shape: `(batch_size, trimmed_source_length)`.
        source_to_target : ``torch.Tensor``, required
            Tensor containing vocab index of each source token with respect to the
            target vocab namespace. Shape: `(batch_size, trimmed_source_length)`.
        metadata : ``List[Dict[str, Any]]``, required
            Metadata field that contains the original source tokens with key 'source_tokens'
            and any other meta fields. When 'target_tokens' is also passed, the metadata
            should also contain the original target tokens with key 'target_tokens'.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
            target tokens are also represented as a `TextField` which must contain a "tokens"
            key that uses single ids.
        target_token_ids : ``torch.Tensor``, optional (default = None)
            A tensor of shape `(batch_size, target_sequence_length)` which indicates which
            tokens in the target sequence match tokens in the source sequence.
        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens)
        state["source_token_ids"] = source_token_ids
        state["source_to_target"] = source_to_target

        if target_tokens:
            state = self._init_decoder_state(state)
            output_dict = self._forward_loss(target_tokens, target_token_ids,
                                             state)
        else:
            output_dict = {}

        output_dict["metadata"] = metadata

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens:
                if self._tensor_based_metric is not None:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    # shape: (batch_size, target_sequence_length)
                    gold_tokens = self._gather_extended_gold_tokens(
                        target_tokens["tokens"], source_token_ids,
                        target_token_ids)
                    self._tensor_based_metric(best_predictions,
                                              gold_tokens)  # type: ignore
                if self._token_based_metric is not None:
                    predicted_tokens = self._get_predicted_tokens(
                        output_dict["predictions"], metadata, n_best=1)
                    self._token_based_metric(  # type: ignore
                        predicted_tokens,
                        [x["target_tokens"] for x in metadata])

        return output_dict

    def _gather_extended_gold_tokens(
        self,
        target_tokens: torch.Tensor,
        source_token_ids: torch.Tensor,
        target_token_ids: torch.Tensor,
    ) -> torch.LongTensor:
        """
        Modify the gold target tokens relative to the extended vocabulary.
        For gold targets that are OOV but were copied from the source, the OOV index
        will be changed to the index of the first occurrence in the source sentence,
        offset by the size of the target vocabulary.
        Parameters
        ----------
        target_tokens : ``torch.Tensor``
            Shape: `(batch_size, target_sequence_length)`.
        source_token_ids : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`.
        target_token_ids : ``torch.Tensor``
            Shape: `(batch_size, target_sequence_length)`.
        Returns
        -------
        torch.Tensor
            Modified `target_tokens` with OOV indices replaced by offset index
            of first match in source sentence.
        """
        batch_size, target_sequence_length = target_tokens.size()
        trimmed_source_length = source_token_ids.size(1)
        # Only change indices for tokens that were OOV in target vocab but copied from source.
        # shape: (batch_size, target_sequence_length)
        oov = target_tokens == self._oov_index
        # shape: (batch_size, target_sequence_length, trimmed_source_length)
        expanded_source_token_ids = source_token_ids.unsqueeze(1).expand(
            batch_size, target_sequence_length, trimmed_source_length)
        # shape: (batch_size, target_sequence_length, trimmed_source_length)
        expanded_target_token_ids = target_token_ids.unsqueeze(-1).expand(
            batch_size, target_sequence_length, trimmed_source_length)
        # shape: (batch_size, target_sequence_length, trimmed_source_length)
        matches = expanded_source_token_ids == expanded_target_token_ids
        # shape: (batch_size, target_sequence_length)
        copied = matches.sum(-1) > 0
        # shape: (batch_size, target_sequence_length)
        mask = (oov & copied).long()
        # shape: (batch_size, target_sequence_length)
        first_match = ((matches.cumsum(-1) == 1) * matches).to(
            torch.uint8).argmax(-1)
        # shape: (batch_size, target_sequence_length)
        new_target_tokens = (
            target_tokens * (1 - mask) +
            (first_match.long() + self._target_vocab_size) * mask)
        return new_target_tokens

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Initialize the encoded state to be passed to the first decoding time step.
        """
        batch_size, _ = state["source_mask"].size()

        # Initialize the decoder hidden state with the final output of the encoder,
        # and the decoder context with zeros.
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"], state["source_mask"],
            self._encoder.is_bidirectional())
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self.decoder_output_dim)

        return state

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Encode source input sentences.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder.forward(
            source_tokens['bert'], source_tokens['bert-offsets'])
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}

    def _decoder_step(
        self,
        last_predictions: torch.Tensor,
        selective_weights: torch.Tensor,
        state: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        # shape: (group_size, max_input_sequence_length)
        encoder_outputs_mask = state["source_mask"].float()
        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)
        # shape: (group_size, max_input_sequence_length)
        attentive_weights = self._attention(state["decoder_hidden"],
                                            state["encoder_outputs"],
                                            encoder_outputs_mask)
        # shape: (group_size, encoder_output_dim)
        attentive_read = util.weighted_sum(state["encoder_outputs"],
                                           attentive_weights)
        # shape: (group_size, encoder_output_dim)
        selective_read = util.weighted_sum(state["encoder_outputs"][:, 1:-1],
                                           selective_weights)
        # shape: (group_size, target_embedding_dim + encoder_output_dim * 2)
        decoder_input = torch.cat(
            (embedded_input, attentive_read, selective_read), -1)
        # shape: (group_size, decoder_input_dim)
        projected_decoder_input = self._input_projection_layer(decoder_input)

        state["decoder_hidden"], state["decoder_context"] = self._decoder_cell(
            projected_decoder_input,
            (state["decoder_hidden"], state["decoder_context"]))
        return state

    def _get_generation_scores(self,
                               state: Dict[str, torch.Tensor]) -> torch.Tensor:
        return self._output_generation_layer(state["decoder_hidden"])

    def _get_copy_scores(self, state: Dict[str, torch.Tensor]) -> torch.Tensor:
        # shape: (batch_size, max_input_sequence_length - 2, encoder_output_dim)
        trimmed_encoder_outputs = state["encoder_outputs"][:, 1:-1]
        # shape: (batch_size, max_input_sequence_length - 2, decoder_output_dim)
        copy_projection = self._output_copying_layer(trimmed_encoder_outputs)
        # shape: (batch_size, max_input_sequence_length - 2, decoder_output_dim)
        copy_projection = torch.tanh(copy_projection)
        # shape: (batch_size, max_input_sequence_length - 2)
        copy_scores = copy_projection.bmm(
            state["decoder_hidden"].unsqueeze(-1)).squeeze(-1)
        return copy_scores

    def _get_ll_contrib(
        self,
        generation_scores: torch.Tensor,
        generation_scores_mask: torch.Tensor,
        copy_scores: torch.Tensor,
        target_tokens: torch.Tensor,
        target_to_source: torch.Tensor,
        copy_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the log-likelihood contribution from a single timestep.
        Parameters
        ----------
        generation_scores : ``torch.Tensor``
            Shape: `(batch_size, target_vocab_size)`
        generation_scores_mask : ``torch.Tensor``
            Shape: `(batch_size, target_vocab_size)`. This is just a tensor of 1's.
        copy_scores : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        target_tokens : ``torch.Tensor``
            Shape: `(batch_size,)`
        target_to_source : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        copy_mask : ``torch.Tensor``
            Shape: `(batch_size, trimmed_source_length)`
        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Shape: `(batch_size,)` and `(batch_size, trimmed_source_length)`
        """
        _, target_size = generation_scores.size()

        # The point of this mask is to just mask out all source token scores
        # that just represent padding. We apply the mask to the concatenation
        # of the generation scores and the copy scores to normalize the scores
        # correctly during the softmax.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
        # Normalize generation and copy scores.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        log_probs = util.masked_log_softmax(all_scores, mask)
        # Calculate the log probability (`copy_log_probs`) for each token in the source sentence
        # that matches the current target token. We use the sum of these copy probabilities
        # for matching tokens in the source sentence to get the total probability
        # for the target token. We also need to normalize the individual copy probabilities
        # to create `selective_weights`, which are used in the next timestep to create
        # a selective read state.
        # shape: (batch_size, trimmed_source_length)
        copy_log_probs = log_probs[:, target_size:] + (
            target_to_source.float() + 1e-45).log()
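        # Note (not in the original comments): adding `(target_to_source.float() + 1e-45).log()`
        # is a log-space masking trick. Matching positions get log(1 + 1e-45) ~= 0 added, so
        # they are unchanged, while non-matching positions get log(1e-45) ~= -103.6 added,
        # which effectively zeroes their probability without producing -inf.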
        # Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
        # we use a non-log softmax to get the normalized non-log copy probabilities.
        selective_weights = util.masked_softmax(log_probs[:, target_size:],
                                                target_to_source)
        # This mask ensures that each item in the batch has a non-zero generation probability
        # for this timestep only when the gold target token is not OOV or there are no
        # matching tokens in the source sentence.
        # shape: (batch_size, 1)
        gen_mask = ((target_tokens != self._oov_index) |
                    (target_to_source.sum(-1) == 0)).float()
        log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
        # Now we get the generation score for the gold target token.
        # shape: (batch_size, 1)
        generation_log_probs = log_probs.gather(
            1, target_tokens.unsqueeze(1)) + log_gen_mask
        # ... and add the copy score to get the step log likelihood.
        # shape: (batch_size, 1 + trimmed_source_length)
        combined_gen_and_copy = torch.cat(
            (generation_log_probs, copy_log_probs), dim=-1)
        # shape: (batch_size,)
        step_log_likelihood = util.logsumexp(combined_gen_and_copy)

        return step_log_likelihood, selective_weights

    def _forward_loss(
        self,
        target_tokens: Dict[str, torch.LongTensor],
        target_token_ids: torch.Tensor,
        state: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Calculate the loss against gold targets.
        """
        batch_size, target_sequence_length = target_tokens["tokens"].size()

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # The last input from the target is either padding or the end symbol.
        # Either way, we don't have to process it.
        num_decoding_steps = target_sequence_length - 1
        # We use this to fill in the copy index when the previous input was copied.
        # shape: (batch_size,)
        copy_input_choices = source_mask.new_full((batch_size, ),
                                                  fill_value=self._copy_index)
        # shape: (batch_size, trimmed_source_length)
        copy_mask = source_mask[:, 1:-1].float()
        # We need to keep track of the probabilities assigned to tokens in the source
        # sentence that were copied during the previous timestep, since we use
        # those probabilities as weights when calculating the "selective read".
        # shape: (batch_size, trimmed_source_length)
        selective_weights = state["decoder_hidden"].new_zeros(copy_mask.size())

        # Indicates which tokens in the source sentence match the current target token.
        # shape: (batch_size, trimmed_source_length)
        target_to_source = state["source_token_ids"].new_zeros(
            copy_mask.size())

        # This is just a tensor of ones which we use repeatedly in `self._get_ll_contrib`,
        # so we create it once here to avoid doing it over-and-over.
        generation_scores_mask = state["decoder_hidden"].new_full(
            (batch_size, self._target_vocab_size), fill_value=1.0)

        step_log_likelihoods = []
        for timestep in range(num_decoding_steps):
            # shape: (batch_size,)
            input_choices = target_tokens["tokens"][:, timestep]
            # If the previous target token was copied, we use the special copy token.
            # But the end target token will always be THE end token, so we know
            # it was not copied.
            if timestep < num_decoding_steps - 1:
                # Get mask tensor indicating which instances were copied.
                # shape: (batch_size,)
                copied = ((input_choices == self._oov_index) &
                          (target_to_source.sum(-1) > 0)).long()
                # shape: (batch_size,)
                input_choices = input_choices * (
                    1 - copied) + copy_input_choices * copied
                # shape: (batch_size, trimmed_source_length)
                target_to_source = (
                    state["source_token_ids"] ==
                    target_token_ids[:, timestep + 1].unsqueeze(-1))
            # Update the decoder state by taking a step through the RNN.
            state = self._decoder_step(input_choices, selective_weights, state)
            # Get generation scores for each token in the target vocab.
            # shape: (batch_size, target_vocab_size)
            generation_scores = self._get_generation_scores(state)
            # Get copy scores for each token in the source sentence, excluding the start
            # and end tokens.
            # shape: (batch_size, trimmed_source_length)
            copy_scores = self._get_copy_scores(state)
            # shape: (batch_size,)
            step_target_tokens = target_tokens["tokens"][:, timestep + 1]
            step_log_likelihood, selective_weights = self._get_ll_contrib(
                generation_scores,
                generation_scores_mask,
                copy_scores,
                step_target_tokens,
                target_to_source,
                copy_mask,
            )
            step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))

        # Gather step log-likelihoods.
        # shape: (batch_size, num_decoding_steps = target_sequence_length - 1)
        log_likelihoods = torch.cat(step_log_likelihoods, 1)
        # Get target mask to exclude likelihood contributions from timesteps after
        # the END token.
        # shape: (batch_size, target_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)
        # The first timestep is just the START token, which is not included in the likelihoods.
        # shape: (batch_size, num_decoding_steps)
        target_mask = target_mask[:, 1:].float()
        # Sum of step log-likelihoods.
        # shape: (batch_size,)
        log_likelihood = (log_likelihoods * target_mask).sum(dim=-1)
        # The loss is the negative log-likelihood, averaged over the batch.
        loss = -log_likelihood.sum() / batch_size

        return {"loss": loss}

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size, source_length = state["source_mask"].size()
        trimmed_source_length = source_length - 2
        # Initialize the copy log-probs to the log of a near-zero value, since nothing has been copied yet.
        state["copy_log_probs"] = (state["decoder_hidden"].new_zeros(
            (batch_size, trimmed_source_length)) + 1e-45).log()
        # shape: (batch_size,)
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)
        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_search_step)
        return {
            "predicted_log_probs": log_probabilities,
            "predictions": all_top_k_predictions
        }

    def _get_input_and_selective_weights(
        self, last_predictions: torch.LongTensor,
        state: Dict[str,
                    torch.Tensor]) -> Tuple[torch.LongTensor, torch.Tensor]:
        """
        Get input choices for the decoder and the selective copy weights.
        The decoder input choices are simply the `last_predictions`, except for
        target OOV predictions that were copied from source tokens, in which case
        the prediction will be changed to the COPY symbol in the target namespace.
        The selective weights are just the probabilities assigned to source
        tokens that were copied, normalized to sum to 1. If no source tokens were copied,
        the weights will be all zeros.
        Parameters
        ----------
        last_predictions : ``torch.LongTensor``
            Shape: `(group_size,)`
        state : ``Dict[str, torch.Tensor]``
        Returns
        -------
        Tuple[torch.LongTensor, torch.Tensor]
            `input_choices` (shape `(group_size,)`) and `selective_weights`
            (shape `(group_size, trimmed_source_length)`).
        """
        group_size, trimmed_source_length = state["source_to_target"].size()

        # This is a mask indicating which last predictions were copied from
        # the source AND are not in the target vocabulary (OOV).
        # (group_size,)
        only_copied_mask = (last_predictions >= self._target_vocab_size).long()

        # If the last prediction was in the target vocab or OOV but not copied,
        # we use that as input, otherwise we use the COPY token.
        # shape: (group_size,)
        copy_input_choices = only_copied_mask.new_full(
            (group_size, ), fill_value=self._copy_index)
        input_choices = (last_predictions * (1 - only_copied_mask) +
                         copy_input_choices * only_copied_mask)
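        # Illustrative note (not in the original): this is an arithmetic select. For example,
        # with last_predictions = [7, 53], target_vocab_size = 50 and copy_index = 4, the mask
        # is [0, 1], so input_choices becomes [7, 4] (the copied prediction is replaced by the
        # COPY symbol).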

        # In order to get the `selective_weights`, we need to find out which predictions
        # were copied or copied AND generated, which is the case when a prediction appears
        # in both the source sentence and the target vocab. But whenever a prediction
        # is in the target vocab (even if it also appeared in the source sentence),
        # its index will be the corresponding target vocab index, not its index in
        # the source sentence offset by the target vocab size. So we first
        # use `state["source_to_target"]` to get an indicator of every source token
        # that matches the predicted target token.
        # shape: (group_size, trimmed_source_length)
        expanded_last_predictions = last_predictions.unsqueeze(-1).expand(
            group_size, trimmed_source_length)
        # shape: (group_size, trimmed_source_length)
        source_copied_and_generated = (
            state["source_to_target"] == expanded_last_predictions).long()

        # In order to get indicators for copied source tokens that are OOV with respect
        # to the target vocab, we'll make use of `state["source_token_ids"]`.
        # First we adjust predictions relative to the start of the source tokens.
        # This makes sense because predictions for copied tokens are given by the index of the copied
        # token in the source sentence, offset by the size of the target vocabulary.
        # shape: (group_size,)
        adjusted_predictions = last_predictions - self._target_vocab_size
        # The adjusted indices for items that were not copied will be negative numbers,
        # and therefore invalid. So we zero them out.
        adjusted_predictions = adjusted_predictions * only_copied_mask
        # shape: (group_size, trimmed_source_length)
        source_token_ids = state["source_token_ids"]
        # shape: (group_size, trimmed_source_length)
        adjusted_prediction_ids = source_token_ids.gather(
            -1, adjusted_predictions.unsqueeze(-1))
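        # Illustrative example (not in the original): continuing with target_vocab_size = 50,
        # a copied prediction 53 is adjusted to source position 3, and the gather above pulls
        # out the token id at that position so that every source position sharing that id can
        # be marked as copied below.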
        # This mask will contain indicators for source tokens that were copied
        # during the last timestep.
        # shape: (group_size, trimmed_source_length)
        source_only_copied = (
            source_token_ids == adjusted_prediction_ids).long()
        # Since we zeroed out indices for predictions that were not copied,
        # we need to zero out all entries of this mask corresponding to those predictions.
        source_only_copied = source_only_copied * only_copied_mask.unsqueeze(
            -1)

        # shape: (group_size, trimmed_source_length)
        mask = source_only_copied | source_copied_and_generated
        # shape: (group_size, trimmed_source_length)
        selective_weights = util.masked_softmax(state["copy_log_probs"], mask)

        return input_choices, selective_weights

    def _gather_final_log_probs(
        self,
        generation_log_probs: torch.Tensor,
        copy_log_probs: torch.Tensor,
        state: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Combine copy probabilities with generation probabilities for matching tokens.
        Parameters
        ----------
        generation_log_probs : ``torch.Tensor``
            Shape: `(group_size, target_vocab_size)`
        copy_log_probs : ``torch.Tensor``
            Shape: `(group_size, trimmed_source_length)`
        state : ``Dict[str, torch.Tensor]``
        Returns
        -------
        torch.Tensor
            Shape: `(group_size, target_vocab_size + trimmed_source_length)`.
        """
        _, trimmed_source_length = state["source_to_target"].size()
        source_token_ids = state["source_token_ids"]

        # shape: [(batch_size, *)]
        modified_log_probs_list: List[torch.Tensor] = []
        for i in range(trimmed_source_length):
            # shape: (group_size,)
            copy_log_probs_slice = copy_log_probs[:, i]
            # `source_to_target` is a matrix of shape (group_size, trimmed_source_length)
            # where element (i, j) is the vocab index of the target token that matches the jth
            # source token in the ith group, if there is one, or the index of the OOV symbol otherwise.
            # We'll use this to add copy scores to corresponding generation scores.
            # shape: (group_size,)
            source_to_target_slice = state["source_to_target"][:, i]
            # The OOV index in the source_to_target_slice indicates that the source
            # token is not in the target vocab, so we don't want to add that copy score
            # to the OOV token.
            copy_log_probs_to_add_mask = (source_to_target_slice !=
                                          self._oov_index).float()
            copy_log_probs_to_add = (
                copy_log_probs_slice +
                (copy_log_probs_to_add_mask + 1e-45).log())
            # shape: (batch_size, 1)
            copy_log_probs_to_add = copy_log_probs_to_add.unsqueeze(-1)
            # shape: (batch_size, 1)
            selected_generation_log_probs = generation_log_probs.gather(
                1, source_to_target_slice.unsqueeze(-1))
            combined_scores = util.logsumexp(
                torch.cat(
                    (selected_generation_log_probs, copy_log_probs_to_add),
                    dim=1))
            generation_log_probs = generation_log_probs.scatter(
                -1, source_to_target_slice.unsqueeze(-1),
                combined_scores.unsqueeze(-1))
            # We have to combine copy scores for duplicate source tokens so that
            # we can find the overall most likely source token. So, if this is the first
            # occurrence of this particular source token, we add the log_probs from all other
            # occurrences, otherwise we zero it out since it was already accounted for.
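            # Illustrative note (not in the original): if a source token that is NOT in the
            # target vocab appears at positions 0 and 2, both copy scores end up logsumexp-ed
            # into the slot for position 0 and the slot for position 2 is masked to ~0; if the
            # token IS in the target vocab, each copy score is folded into the matching
            # generation slot instead and its left-over slot is zeroed out.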
            if i < (trimmed_source_length - 1):
                # Sum copy scores from future occurrences of the source token.
                # shape: (group_size, trimmed_source_length - i)
                source_future_occurences = (
                    source_token_ids[:, (i + 1):] ==
                    source_token_ids[:, i].unsqueeze(-1)).float()
                # shape: (group_size, trimmed_source_length - i)
                future_copy_log_probs = (
                    copy_log_probs[:, (i + 1):] +
                    (source_future_occurences + 1e-45).log())
                # shape: (group_size, 1 + trimmed_source_length - i)
                combined = torch.cat((copy_log_probs_slice.unsqueeze(-1),
                                      future_copy_log_probs),
                                     dim=-1)
                # shape: (group_size,)
                copy_log_probs_slice = util.logsumexp(combined)
            if i > 0:
                # Remove copy log_probs that we have already accounted for.
                # shape: (group_size, i)
                source_previous_occurences = (
                    source_token_ids[:, 0:i] ==
                    source_token_ids[:, i].unsqueeze(-1))
                # shape: (group_size,)
                duplicate_mask = (source_previous_occurences.sum(
                    dim=-1) == 0).float()
                copy_log_probs_slice = copy_log_probs_slice + (duplicate_mask +
                                                               1e-45).log()

            # Finally, we zero-out copy scores that we added to the generation scores
            # above so that we don't double-count them.
            # shape: (group_size,)
            left_over_copy_log_probs = (
                copy_log_probs_slice +
                (1.0 - copy_log_probs_to_add_mask + 1e-45).log())
            modified_log_probs_list.append(
                left_over_copy_log_probs.unsqueeze(-1))
        modified_log_probs_list.insert(0, generation_log_probs)

        # shape: (group_size, target_vocab_size + trimmed_source_length)
        modified_log_probs = torch.cat(modified_log_probs_list, dim=-1)

        return modified_log_probs

    def take_search_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take step during beam search.
        This function is what gets passed to the `BeamSearch.search` method. It takes
        predictions from the last timestep and the current state and outputs
        the log probabilities assigned to tokens for the next timestep, as well as the updated
        state.
        Since we are predicting tokens out of the extended vocab (target vocab + all unique
        tokens from the source sentence), this is a little more complicated than just
        making a forward pass through the model. The output log probs will have
        shape `(group_size, target_vocab_size + trimmed_source_length)` so that each
        token in the target vocab and source sentence is assigned a probability.
        Note that copy scores are assigned to each source token based on its position,
        not its unique value.
        So if a token appears more than once in the source sentence, it will have more than one score.
        Further, if a source token is also part of the target vocab, its final score
        will be the sum of the generation and copy scores. Therefore, in order to
        get the score for all tokens in the extended vocab at this step,
        we have to combine copy scores for re-occurring source tokens and potentially
        add them to the generation scores for the matching token in the target vocab, if
        there is one.
        So we can break down the final log probs output as the concatenation of two
        matrices, A: `(group_size, target_vocab_size)`, and B: `(group_size, trimmed_source_length)`.
        Matrix A contains the sum of the generation score and copy scores (possibly 0)
        for each target token. Matrix B contains left-over copy scores for source tokens
        that do NOT appear in the target vocab, with zeros everywhere else. But since
        a source token may appear more than once in the source sentence, we also have to
        sum the scores for each appearance of each unique source token. So matrix B
        actually only has non-zero values at the first occurrence of each source token
        that is not in the target vocab.
        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            Shape: `(group_size,)`
        state : ``Dict[str, torch.Tensor]``
            Contains all state tensors necessary to produce generation and copy scores
            for next step.
        Notes
        -----
        `group_size` != `batch_size`. In fact, `group_size` = `batch_size * beam_size`.
        """
        _, trimmed_source_length = state["source_to_target"].size()

        # Get input to the decoder RNN and the selective weights. `input_choices`
        # is the result of replacing target OOV tokens in `last_predictions` with the
        # copy symbol. `selective_weights` consist of the normalized copy probabilities
        # assigned to the source tokens that were copied. If no tokens were copied,
        # the weights will be all zeros.
        # shape: (group_size,), (group_size, trimmed_source_length)
        input_choices, selective_weights = self._get_input_and_selective_weights(
            last_predictions, state)
        # Update the decoder state by taking a step through the RNN.
        state = self._decoder_step(input_choices, selective_weights, state)
        # Get the un-normalized generation scores for each token in the target vocab.
        # shape: (group_size, target_vocab_size)
        generation_scores = self._get_generation_scores(state)
        # Get the un-normalized copy scores for each token in the source sentence,
        # excluding the start and end tokens.
        # shape: (group_size, trimmed_source_length)
        copy_scores = self._get_copy_scores(state)
        # Concat un-normalized generation and copy scores.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
        # shape: (group_size, trimmed_source_length)
        copy_mask = state["source_mask"][:, 1:-1].float()
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores.new_full(generation_scores.size(),
                                                     1.0), copy_mask),
                         dim=-1)
        # Normalize generation and copy scores.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        log_probs = util.masked_log_softmax(all_scores, mask)
        # shape: (group_size, target_vocab_size), (group_size, trimmed_source_length)
        generation_log_probs, copy_log_probs = log_probs.split(
            [self._target_vocab_size, trimmed_source_length], dim=-1)
        # Update copy_probs needed for getting the `selective_weights` at the next timestep.
        state["copy_log_probs"] = copy_log_probs
        # We now have normalized generation and copy scores, but to produce the final
        # score for each token in the extended vocab, we have to go through and add
        # the copy scores to the generation scores of matching target tokens, and sum
        # the copy scores of duplicate source tokens.
        # shape: (group_size, target_vocab_size + trimmed_source_length)
        final_log_probs = self._gather_final_log_probs(generation_log_probs,
                                                       copy_log_probs, state)

        return final_log_probs, state

    def _get_predicted_tokens(
        self,
        predicted_indices: Union[torch.Tensor, np.ndarray],
        batch_metadata: List[Any],
        n_best: int = None,
    ) -> List[Union[List[List[str]], List[str]]]:
        """
        Convert predicted indices into tokens.
        If `n_best = 1`, the result type will be `List[List[str]]`. Otherwise the result
        type will be `List[List[List[str]]]`.
        """
        if not isinstance(predicted_indices, np.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        predicted_tokens: List[Union[List[List[str]], List[str]]] = []
        for top_k_predictions, metadata in zip(predicted_indices,
                                               batch_metadata):
            batch_predicted_tokens: List[List[str]] = []
            for indices in top_k_predictions[:n_best]:
                tokens: List[str] = []
                indices = list(indices)
                if self._end_index in indices:
                    indices = indices[:indices.index(self._end_index)]
                for index in indices:
                    if index >= self._target_vocab_size:
                        adjusted_index = index - self._target_vocab_size
                        token = metadata["source_tokens"][adjusted_index]
                    else:
                        token = self.vocab.get_token_from_index(
                            index, self._target_namespace)
                    tokens.append(token)
                batch_predicted_tokens.append(tokens)
            if n_best == 1:
                predicted_tokens.append(batch_predicted_tokens[0])
            else:
                predicted_tokens.append(batch_predicted_tokens)
        return predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Finalize predictions.
        After a beam search, the predicted indices correspond to tokens in the target vocabulary
        OR tokens in source sentence. Here we gather the actual tokens corresponding to
        the indices.
        """
        predicted_tokens = self._get_predicted_tokens(
            output_dict["predictions"], output_dict["metadata"])
        output_dict["predicted_tokens"] = predicted_tokens
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(
                        reset=reset)  # type: ignore
                )
            if self._token_based_metric is not None:
                all_metrics.update(
                    self._token_based_metric.get_metric(
                        reset=reset))  # type: ignore
        return all_metrics
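
# A minimal, self-contained sketch (not part of the original example) of how `decode` /
# `_get_predicted_tokens` interpret extended-vocab indices: indices below the target vocab
# size are ordinary vocabulary entries, while indices at or above it point back into the
# source sentence, offset by the vocab size. All names and values below are illustrative only.
target_vocab = ["@@PADDING@@", "@start@", "@end@", "the", "cat"]  # assumed toy target vocab
source_tokens = ["the", "velociraptor", "slept"]                  # from metadata["source_tokens"]
target_vocab_size = len(target_vocab)
end_index = 2

predicted_indices = [3, target_vocab_size + 1, end_index]  # "the", copy of source position 1, end
tokens = []
for index in predicted_indices:
    if index == end_index:          # stop at the end symbol, as in `_get_predicted_tokens`
        break
    if index >= target_vocab_size:  # copied token: look it up in the source sentence
        tokens.append(source_tokens[index - target_vocab_size])
    else:                           # generated token: look it up in the target vocab
        tokens.append(target_vocab[index])
print(tokens)  # ['the', 'velociraptor']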
Example #32
class PointerGeneratorNetwork(Model):
    """
    Based on https://arxiv.org/pdf/1704.04368.pdf
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 max_decoding_steps: int,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 projection_dim: int = None,
                 use_coverage: bool = False,
                 coverage_shift: float = 0.,
                 coverage_loss_weight: float = None,
                 embed_attn_to_output: bool = False) -> None:
        super(PointerGeneratorNetwork, self).__init__(vocab)

        self._target_namespace = target_namespace
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     target_namespace)
        self._unk_index = self.vocab.get_token_index(DEFAULT_OOV_TOKEN,
                                                     target_namespace)
        self._vocab_size = self.vocab.get_vocab_size(target_namespace)
        assert self._vocab_size > 2, \
            "Target vocabulary is empty. Make sure 'target_namespace' option of the model is correct."

        # Encoder
        self._source_embedder = source_embedder
        self._encoder = encoder
        self._encoder_output_dim = self._encoder.get_output_dim()

        # Decoder
        self._target_embedding_dim = (target_embedding_dim
                                      or source_embedder.get_output_dim())
        self._num_classes = self.vocab.get_vocab_size(target_namespace)
        self._target_embedder = Embedding(self._num_classes,
                                          self._target_embedding_dim)

        self._decoder_input_dim = self._encoder_output_dim + self._target_embedding_dim
        self._decoder_output_dim = self._encoder_output_dim
        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)

        self._projection_dim = (projection_dim
                                or self._source_embedder.get_output_dim())
        hidden_projection_dim = self._decoder_output_dim if not embed_attn_to_output else self._decoder_output_dim * 2
        self._hidden_projection_layer = Linear(hidden_projection_dim,
                                               self._projection_dim)
        self._output_projection_layer = Linear(self._projection_dim,
                                               self._num_classes)

        self._p_gen_layer = Linear(
            self._decoder_output_dim * 3 + self._decoder_input_dim, 1)
        self._attention = attention
        self._use_coverage = use_coverage
        self._coverage_loss_weight = coverage_loss_weight
        self._eps = 1e-31
        self._embed_attn_to_output = embed_attn_to_output
        self._coverage_shift = coverage_shift

        # Metrics
        self._p_gen_sum = 0.0
        self._p_gen_iterations = 0
        self._coverage_loss_sum = 0.0
        self._coverage_iterations = 0

        # Decoding
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size or 1)

    def forward(self,
                source_tokens: Dict[str, torch.LongTensor],
                source_token_ids: torch.Tensor,
                source_to_target: torch.LongTensor,
                target_tokens: Dict[str, torch.LongTensor] = None,
                target_token_ids: torch.Tensor = None,
                metadata=None) -> Dict[str, torch.Tensor]:
        state = self._encode(source_tokens)
        target_tokens_tensor = (target_tokens["tokens"].long()
                                if target_tokens else None)
        extra_zeros, modified_source_tokens, modified_target_tokens = self._prepare(
            source_to_target, source_token_ids, target_tokens_tensor,
            target_token_ids)

        state["tokens"] = modified_source_tokens
        state["extra_zeros"] = extra_zeros

        output_dict = {}
        if target_tokens:
            state["target_tokens"] = modified_target_tokens
            state = self._init_decoder_state(state)
            output_dict = self._forward_loop(state, target_tokens)
        output_dict["metadata"] = metadata
        output_dict["source_to_target"] = source_to_target

        if not self.training:
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)

        return output_dict

    def _prepare(self,
                 source_tokens: torch.LongTensor,
                 source_token_ids: torch.Tensor,
                 target_tokens: torch.LongTensor = None,
                 target_token_ids: torch.Tensor = None):
        batch_size = source_tokens.size(0)
        source_max_length = source_tokens.size(1)

        tokens = source_tokens
        token_ids = source_token_ids.long()

        # Concat target tokens if exist
        if target_tokens is not None:
            tokens = torch.cat((tokens, target_tokens), 1)
            token_ids = torch.cat((token_ids, target_token_ids.long()), 1)

        is_unk = torch.eq(tokens, self._unk_index).long()
        # Create tensor with ids of unknown tokens only.
        # Those ids are batch-local.
        unk_only = token_ids * is_unk

        # Recalculate batch-local ids to range [1, count_of_unique_unk_tokens].
        # All known tokens have zero id.
        unk_token_nums = token_ids.new_zeros((batch_size, token_ids.size(1)))
        for i in range(batch_size):
            unique = torch.unique(unk_only[i, :],
                                  return_inverse=True,
                                  sorted=True)[1]
            unk_token_nums[i, :] = unique

        # Replace the DEFAULT_OOV_TOKEN id with new batch-local ids starting from vocab_size.
        # For example, if the vocabulary size is 50000, the first unique unknown token will have
        # index 50000, the second index 50001, and so on.
        tokens = tokens - tokens * is_unk + (self._vocab_size -
                                             1) * is_unk + unk_token_nums
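        # Illustrative example (not in the original): with vocab_size = 50000 and source ids
        # [12, <unk>, 7, <unk>] where the two unknown tokens are distinct, the unknown positions
        # receive batch-local numbers 1 and 2 (ordered by their token ids), so the line above
        # rewrites them to 50000 and 50001 while known tokens keep their original ids.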

        modified_target_tokens = None
        modified_source_tokens = tokens
        if target_tokens is not None:
            # Remove target unknown tokens that do not exist in source tokens
            max_source_num = torch.max(tokens[:, :source_max_length], dim=1)[0]
            vocab_size = max_source_num.new_full((1, ), self._vocab_size - 1)
            max_source_num = torch.max(max_source_num,
                                       other=vocab_size).unsqueeze(1).expand(
                                           (-1, tokens.size(1)))
            unk_target_tokens_mask = torch.gt(tokens, max_source_num).long()
            tokens = tokens - tokens * unk_target_tokens_mask + self._unk_index * unk_target_tokens_mask
            modified_target_tokens = tokens[:, source_max_length:]
            modified_source_tokens = tokens[:, :source_max_length]

        # Count unique unknown source tokens to create enough zeros for final distribution
        source_unk_count = torch.max(unk_token_nums[:, :source_max_length])
        extra_zeros = tokens.new_zeros((batch_size, source_unk_count),
                                       dtype=torch.float32)
        return extra_zeros, modified_source_tokens, modified_target_tokens

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder.forward(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder.forward(embedded_input, source_mask)

        return {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,
        }

    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"], state["source_mask"],
            self._encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output

        encoder_outputs = state["encoder_outputs"]
        state["decoder_context"] = encoder_outputs.new_zeros(
            batch_size, self._decoder_output_dim)
        if self._embed_attn_to_output:
            state["attn_context"] = encoder_outputs.new_zeros(
                encoder_outputs.size(0), encoder_outputs.size(2))
        if self._use_coverage:
            state["coverage"] = encoder_outputs.new_zeros(
                batch_size, encoder_outputs.size(1))
        return state

    def _prepare_output_projections(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]
        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]
        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]
        # shape: (group_size, decoder_output_dim)
        attn_context = state.get("attn_context", None)

        is_unk = (last_predictions >= self._vocab_size).long()
        last_predictions_fixed = last_predictions - last_predictions * is_unk + self._unk_index * is_unk
        embedded_input = self._target_embedder(last_predictions_fixed)

        coverage = state.get("coverage", None)

        def get_attention_context(decoder_hidden_inner):
            if coverage is None:
                attention_scores = self._attention(decoder_hidden_inner,
                                                   encoder_outputs,
                                                   source_mask)
            else:
                attention_scores = self._attention(decoder_hidden_inner,
                                                   encoder_outputs,
                                                   source_mask, coverage)
            attention_context = util.weighted_sum(encoder_outputs,
                                                  attention_scores)
            return attention_scores, attention_context

        if not self._embed_attn_to_output:
            attn_scores, attn_context = get_attention_context(decoder_hidden)
            decoder_input = torch.cat((attn_context, embedded_input), -1)
            decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input, (decoder_hidden, decoder_context))
            projection = self._hidden_projection_layer(decoder_hidden)
        else:
            decoder_input = torch.cat((attn_context, embedded_input), -1)
            decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input, (decoder_hidden, decoder_context))
            attn_scores, attn_context = get_attention_context(decoder_hidden)
            projection = self._hidden_projection_layer(
                torch.cat((attn_context, decoder_hidden), -1))

        output_projections = self._output_projection_layer(projection)
        if self._use_coverage:
            state["coverage"] = coverage + attn_scores
        state["decoder_input"] = decoder_input
        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context
        state["attn_scores"] = attn_scores
        state["attn_context"] = attn_context

        return output_projections, state

    def _get_final_dist(self, state: Dict[str, torch.Tensor],
                        output_projections):
        attn_dist = state["attn_scores"]
        tokens = state["tokens"]
        extra_zeros = state["extra_zeros"]
        attn_context = state["attn_context"]
        decoder_input = state["decoder_input"]
        decoder_hidden = state["decoder_hidden"]
        decoder_context = state["decoder_context"]

        decoder_state = torch.cat((decoder_hidden, decoder_context), 1)
        p_gen = self._p_gen_layer(
            torch.cat((attn_context, decoder_state, decoder_input), 1))
        p_gen = torch.sigmoid(p_gen)
        self._p_gen_sum += torch.mean(p_gen).item()
        self._p_gen_iterations += 1

        vocab_dist = F.softmax(output_projections, dim=-1)

        vocab_dist = vocab_dist * p_gen
        attn_dist = attn_dist * (1.0 - p_gen)
        if extra_zeros.size(1) != 0:
            vocab_dist = torch.cat((vocab_dist, extra_zeros), 1)
        final_dist = vocab_dist.scatter_add(1, tokens, attn_dist)
        normalization_factor = final_dist.sum(1, keepdim=True)
        final_dist = final_dist / normalization_factor

        return final_dist

    def _forward_loop(
        self,
        state: Dict[str, torch.Tensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]
        batch_size = source_mask.size(0)

        num_decoding_steps = self._max_decoding_steps
        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]
            _, target_sequence_length = targets.size()
            num_decoding_steps = target_sequence_length - 1

        if self._use_coverage:
            coverage_loss = source_mask.new_zeros(1, dtype=torch.float32)

        last_predictions = state["tokens"].new_full(
            (batch_size, ), fill_value=self._start_index)
        step_proba: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if (self.training
                    and torch.rand(1).item() < self._scheduled_sampling_ratio):
                input_choices = last_predictions
            elif not target_tokens:
                input_choices = last_predictions
            else:
                input_choices = targets[:, timestep]

            if self._use_coverage:
                old_coverage = state["coverage"]

            output_projections, state = self._prepare_output_projections(
                input_choices, state)
            final_dist = self._get_final_dist(state, output_projections)
            step_proba.append(final_dist)
            last_predictions = torch.max(final_dist, 1)[1]
            step_predictions.append(last_predictions.unsqueeze(1))

            if self._use_coverage:
                step_coverage_loss = torch.sum(
                    torch.min(state["attn_scores"], old_coverage), 1)
                coverage_loss = coverage_loss + step_coverage_loss
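                # Descriptive note (not in the original comments): following See et al. (2017),
                # the step coverage loss is sum_i min(attention_i, coverage_i), which penalises
                # attending again to source positions that are already well covered.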

        # shape: (batch_size, num_decoding_steps)
        predictions = torch.cat(step_predictions, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_classes, num_decoding_steps)
            num_classes = step_proba[0].size(1)
            proba = step_proba[0].new_zeros(
                (batch_size, num_classes, len(step_proba)))
            for i, p in enumerate(step_proba):
                proba[:, :, i] = p

            loss = self._get_loss(proba, state["target_tokens"], self._eps)
            if self._use_coverage:
                coverage_loss = torch.mean(coverage_loss / num_decoding_steps)
                self._coverage_loss_sum += coverage_loss.item()
                self._coverage_iterations += 1
                modified_coverage_loss = relu(
                    coverage_loss -
                    self._coverage_shift) + self._coverage_shift - 1.0
                loss = loss + self._coverage_loss_weight * modified_coverage_loss
            output_dict["loss"] = loss

        return output_dict

    @staticmethod
    def _get_loss(proba: torch.LongTensor, targets: torch.LongTensor,
                  eps: float) -> torch.Tensor:
        targets = targets[:, 1:]
        proba = torch.log(proba + eps)
        loss = torch.nn.NLLLoss(ignore_index=0)(proba, targets)
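        # Note (not in the original comments): NLLLoss accepts log-prob input of shape
        # (batch, num_classes, steps) with targets of shape (batch, steps), which is why
        # `_forward_loop` stacks the per-step distributions along the last dimension;
        # ignore_index=0 is assumed to correspond to padding.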
        return loss

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["tokens"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)
        final_dist = self._get_final_dist(state, output_projections)
        log_probabilities = torch.log(final_dist + self._eps)
        return log_probabilities, state

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, np.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        all_meta = output_dict["metadata"]
        all_source_to_target = output_dict["source_to_target"]
        for (indices, metadata), source_to_target in zip(
                zip(predicted_indices, all_meta), all_source_to_target):
            all_predicted_tokens.append(
                self._decode_sample(indices, metadata, source_to_target))
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _decode_sample(self, indices, metadata, source_to_target):
        predicted_tokens = []
        # Beam search gives us the top k results for each source sentence in the batch
        # but we just want the single best.
        if len(indices.shape) > 1:
            indices = indices[0]
        indices = list(indices)
        # Collect indices till the first end_symbol
        if self._end_index in indices:
            indices = indices[:indices.index(self._end_index)]
        # Get all unknown tokens from source
        original_source_tokens = metadata["source_tokens"]
        unk_tokens = list()
        for i, token_vocab_index in enumerate(source_to_target):
            if token_vocab_index != self._unk_index:
                continue
            token = original_source_tokens[i]
            if token in unk_tokens:
                continue
            unk_tokens.append(token)

        for token_vocab_index in indices:
            if token_vocab_index < self._vocab_size:
                token = self.vocab.get_token_from_index(
                    token_vocab_index, namespace=self._target_namespace)
            else:
                unk_number = token_vocab_index - self._vocab_size
                token = unk_tokens[unk_number]
            predicted_tokens.append(token)
        return predicted_tokens

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        if not self._use_coverage:
            return {}
        avg_coverage_loss = self._coverage_loss_sum / self._coverage_iterations if self._coverage_iterations != 0 else 0.0
        avg_p_gen = self._p_gen_sum / self._p_gen_iterations if self._p_gen_iterations != 0 else 0.0
        metrics = {"coverage_loss": avg_coverage_loss, "p_gen": avg_p_gen}
        if reset:
            self._p_gen_sum = 0.0
            self._p_gen_iterations = 0
            self._coverage_loss_sum = 0.0
            self._coverage_iterations = 0
        return metrics
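
# A minimal, self-contained sketch (not part of the original example) of the pointer-generator
# mixing performed in `_get_final_dist` above: the vocabulary distribution is weighted by p_gen,
# extra slots are appended for batch-local OOV tokens, and attention mass over source positions
# is scattered into the extended vocabulary with weight (1 - p_gen). All values are illustrative.
import torch

vocab_dist = torch.tensor([[0.5, 0.3, 0.2]])   # (batch=1, vocab_size=3)
attn_dist = torch.tensor([[0.6, 0.4]])         # (batch=1, source_length=2)
tokens = torch.tensor([[1, 3]])                # source positions map to ids 1 and 3 (3 = OOV slot)
extra_zeros = torch.zeros(1, 1)                # one batch-local OOV id
p_gen = torch.tensor([[0.7]])

vocab_dist = torch.cat((vocab_dist * p_gen, extra_zeros), dim=1)
final_dist = vocab_dist.scatter_add(1, tokens, attn_dist * (1.0 - p_gen))
final_dist = final_dist / final_dist.sum(1, keepdim=True)
print(final_dist)  # tensor([[0.3500, 0.3900, 0.1400, 0.1200]]): the copied OOV id gets 0.4 * 0.3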
Example #33
class SequenceTransformer(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 target_embedder: Embedding,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 decoding_dim: int,
                 feedforward_hidden_dim: int,
                 num_layers: int,
                 num_attention_heads: int,
                 use_positional_encoding: bool = True,
                 positional_encoding_max_steps: int = 5000,
                 dropout_prob: float = 0.1,
                 residual_dropout_prob: float = 0.2,
                 attention_dropout_prob: float = 0.2,
                 beam_size: int = 1,
                 target_namespace: str = "tokens",
                 label_smoothing_ratio: Optional[float] = None,
                 initializer: Optional[InitializerApplicator] = None) -> None:
        super(SequenceTransformer, self).__init__(vocab)

        self._target_namespace = target_namespace
        self._label_smoothing_ratio = label_smoothing_ratio
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)
        self._token_based_metric = TokenSequenceAccuracy()

        # Beam Search
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Encoder
        self._encoder = encoder

        # Vocabulary and embedder
        self._source_embedder = source_embedder
        self._target_embedder = target_embedder

        target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)
        assert target_vocab_size == self._target_embedder.num_embeddings

        target_embedding_dim = self._target_embedder.get_output_dim()

        self._decoding_dim = decoding_dim
        # Sequence Decoder Features
        self._output_projection_layer = Linear(self._decoding_dim,
                                               target_vocab_size)

        self._decoder = Decoder(
            num_layers=num_layers,
            decoding_dim=decoding_dim,
            target_embedding_dim=target_embedding_dim,
            feedforward_hidden_dim=feedforward_hidden_dim,
            num_attention_heads=num_attention_heads,
            use_positional_encoding=use_positional_encoding,
            positional_encoding_max_steps=positional_encoding_max_steps,
            dropout_prob=dropout_prob,
            residual_dropout_prob=residual_dropout_prob,
            attention_dropout_prob=attention_dropout_prob)

        # Parameter checks and cleanup
        if (self._target_embedder.get_output_dim()
                != self._decoder.target_embedding_dim):
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input."
            )
        if self._encoder.get_output_dim() != self._decoder.get_output_dim():
            raise ConfigurationError(
                f"Encoder output dimension {self._encoder.get_output_dim()} should be"
                f" equal to decoder dimension {self._decoder.get_output_dim()}."
            )

        if initializer:
            initializer(self)

        # Print the model
        print(self)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.
        # Parameters
        last_predictions : `torch.Tensor`
            A tensor of shape `(group_size,)`, which gives the indices of the predictions
            during the last time step.
        state : `Dict[str, torch.Tensor]`
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape `(group_size, *)`, where `*` can be any other number
            of dimensions.
        # Returns
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of `(log_probabilities, updated_state)`, where `log_probabilities`
            is a tensor of shape `(group_size, num_classes)` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while `updated_state` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.
        Notes
        -----
            We treat the inputs as a batch, even though `group_size` is not necessarily
            equal to `batch_size`, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._decoder_step(last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    @overrides
    def forward(
        self,  # type: ignore
        source_tokens: Dict[str, torch.LongTensor],
        metadata: List[Dict[str, Any]],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Make forward pass with decoder logic for producing the entire target sequence.
        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        metadata: List[Dict[str, Any]]
            Additional information for prediction
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `TextField.as_array()` applied on the target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.
        Returns
        -------
        Dict[str, torch.Tensor]
        """
        state = self._encode(source_tokens)

        if target_tokens:
            # state = self._decoder.init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens)
        else:
            output_dict = {}

        if not self.training:
            # state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if target_tokens:
                # List of length batch_size; each entry is the best predicted token sequence.
                predicted_tokens = self.decode(output_dict)["predicted_tokens"]

                self._token_based_metric(
                    predicted_tokens, [x["target_tokens"] for x in metadata])

        return output_dict

    @overrides
    def decode(self,
               output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens  # type: ignore
        return output_dict

    def _encode(
            self,
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
            Make forward pass on the encoder.
            # Parameters
            source_tokens : `Dict[str, torch.Tensor]`
               The output of `TextField.as_array()` applied on the source `TextField`. This will be
               passed through a `TextFieldEmbedder` and then through an encoder.
            # Returns
            Dict[str, torch.Tensor]
                Map consisting of the key `source_mask` with the mask over the
                `source_tokens` text field,
                and the key `encoder_outputs` with the output tensor from
                forward pass on the encoder.
            """
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}

    def _forward_loop(
        self,
        state: Dict[str, torch.Tensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:

        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = target_tokens["tokens"]

        _, target_sequence_length = targets.size()

        # Prepare embeddings for targets. They will be used as gold embeddings during decoder training
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self._target_embedder(targets)

        # shape: (batch_size, max_target_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        _, decoder_output = self._decoder(
            previous_state=state,
            previous_steps_predictions=target_embedding[:, :-1, :],
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_mask=target_mask[:, :-1])

        # shape: (group_size, max_target_sequence_length, num_classes)
        logits = self._output_projection_layer(decoder_output).type(
            torch.FloatTensor)

        # Compute loss.
        loss = self._get_loss(logits, targets, target_mask)
        output_dict = {"loss": loss}

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepares inputs for beam search, runs the search, and returns its results.
        """
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities = self._beam_search.search(
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        }
        return output_dict

    def _decoder_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Decode the current state and last prediction to produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.
        Inputs are the same as for `take_step()`.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, steps_count, decoder_output_dim)
        previous_steps_predictions = state.get("previous_steps_predictions")

        # shape: (batch_size, 1, target_embedding_dim)
        last_predictions_embeddings = self._target_embedder(
            last_predictions).unsqueeze(1)

        if (previous_steps_predictions is None
                or previous_steps_predictions.shape[-1] == 0):
            # There are no previous steps, except for the start vectors in `last_predictions`.
            # shape: (group_size, 1, target_embedding_dim)
            previous_steps_predictions = last_predictions_embeddings
        else:
            # shape: (group_size, steps_count, target_embedding_dim)
            previous_steps_predictions = torch.cat(
                [previous_steps_predictions, last_predictions_embeddings], 1)

        decoder_state, decoder_output = self._decoder(
            previous_state=state,
            encoder_outputs=encoder_outputs,
            source_mask=source_mask,
            previous_steps_predictions=previous_steps_predictions,
        )
        state["previous_steps_predictions"] = previous_steps_predictions

        # Update state with new decoder state, override previous state
        state.update(decoder_state)

        if self._decoder.decodes_parallel:
            decoder_output = decoder_output[:, -1, :]

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_output)

        return output_projections, state

    def _get_loss(self, logits: torch.FloatTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.
        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.
        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous().to(logits.device)

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous().to(logits.device)

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not self.training:
            all_metrics.update(
                self._token_based_metric.get_metric(reset=reset))
        return all_metrics
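
# The step-function contract documented in `take_step` above is easiest to see on toy
# data. The sketch below is hypothetical and not part of any model in this file: a
# fixed 4-class "transition" table stands in for the decoder, and the state dict
# carries one dummy tensor to show how per-beam state is threaded through
# `BeamSearch.search`. It assumes the two-argument step signature used throughout
# these examples.
import torch
import torch.nn.functional as F
from allennlp.nn.beam_search import BeamSearch

# Hypothetical scores for the next token given the last one; index 3 acts as the end symbol.
TRANSITIONS = torch.tensor([[0.0, 2.0, 1.0, 0.5],
                            [0.0, 0.5, 2.0, 1.0],
                            [0.0, 0.5, 0.5, 3.0],
                            [0.0, 0.0, 0.0, 4.0]])


def toy_step(last_predictions, state):
    # shape: (group_size, num_classes) -- the same contract `take_step` satisfies.
    class_log_probabilities = F.log_softmax(TRANSITIONS[last_predictions], dim=-1)
    state["steps"] = state["steps"] + 1    # any per-beam bookkeeping lives in `state`
    return class_log_probabilities, state


beam_search = BeamSearch(end_index=3, max_steps=5, beam_size=2)
start_predictions = torch.tensor([1, 2])       # a batch of two start tokens
start_state = {"steps": torch.zeros(2, 1)}     # every state tensor is (batch_size, *)
top_k_predictions, log_probabilities = beam_search.search(
    start_predictions, start_state, toy_step)
print(top_k_predictions.shape)    # (batch_size, beam_size, num_decoding_steps)
print(log_probabilities.shape)    # (batch_size, beam_size)
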
Beispiel #34
0
class Event2Mind(Model):
    """
    This ``Event2Mind`` class is a :class:`Model` which takes an event
    sequence, encodes it, and then uses the encoded representation to decode
    several mental state sequences.

    It is based on `the paper by Rashkin et al.
    <https://www.semanticscholar.org/paper/Event2Mind/b89f8a9b2192a8f2018eead6b135ed30a1f2144d>`_

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (``tokens``) or the target tokens can have a different namespace, in which case it needs to
        be specified as ``target_namespace``.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences.
    embedding_dropout: float, required
        The amount of dropout to apply after the source tokens have been embedded.
    encoder : ``Seq2VecEncoder``, required
        The encoder of the "encoder/decoder" model.
    max_decoding_steps : int, required
        Length of decoded sequences.
    beam_size : int, optional (default = 10)
        The width of the beam search.
    target_names: ``List[str]``, optional, (default = ['xintent', 'xreact', 'oreact'])
        Names of the target fields matching those in the ``Instance`` objects.
    target_namespace : str, optional (default = 'tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : int, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 embedding_dropout: float,
                 encoder: Seq2VecEncoder,
                 max_decoding_steps: int,
                 beam_size: int = 10,
                 target_names: List[str] = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None) -> None:
        super().__init__(vocab)
        target_names = target_names or ["xintent", "xreact", "oreact"]

        # Note: The original tweaks the embeddings for "personx" to be the mean
        # across the embeddings for "he", "she", "him" and "her". Similarly for
        # "personx's" and so forth. We could consider that here as a well.
        self._source_embedder = source_embedder
        self._embedding_dropout = nn.Dropout(embedding_dropout)
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
        # Warning: The different decoders share a vocabulary! This may be
        # counterintuitive, but consider the case of xreact and oreact. A
        # reaction of "happy" could easily apply to both the subject of the
        # event and others. This could become less appropriate as more decoders
        # are added.
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with that of the final hidden states of the encoder.
        self._decoder_output_dim = self._encoder.get_output_dim()
        target_embedding_dim = target_embedding_dim or self._source_embedder.get_output_dim()

        self._states = ModuleDict()
        for name in target_names:
            self._states[name] = StateDecoder(
                    num_classes,
                    target_embedding_dim,
                    self._decoder_output_dim
            )

        self._beam_search = BeamSearch(
                self._end_index,
                beam_size=beam_size,
                max_steps=max_decoding_steps
        )

    def _update_recall(self,
                       all_top_k_predictions: torch.Tensor,
                       target_tokens: Dict[str, torch.LongTensor],
                       target_recall: UnigramRecall) -> None:
        targets = target_tokens["tokens"]
        target_mask = get_text_field_mask(target_tokens)
        # See comment in _get_loss.
        # TODO(brendanr): Do we need contiguous here?
        relevant_targets = targets[:, 1:].contiguous()
        relevant_mask = target_mask[:, 1:].contiguous()
        target_recall(
                all_top_k_predictions,
                relevant_targets,
                relevant_mask,
                self._end_index
        )

    def _get_num_decoding_steps(self,
                                target_tokens: Optional[Dict[str, torch.LongTensor]]) -> int:
        if target_tokens:
            targets = target_tokens["tokens"]
            target_sequence_length = targets.size()[1]
            # The last input from the target is either padding or the end
            # symbol.  Either way, we don't have to process it. (To be clear,
            # we do still output and compare against the end symbol, but there
            # is no need to take the end symbol as input to the decoder.)
            return target_sequence_length - 1
        else:
            return self._max_decoding_steps

    @overrides
    def forward(self,  # type: ignore
                source: Dict[str, torch.LongTensor],
                **target_tokens: Dict[str, Dict[str, torch.LongTensor]]) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Decoder logic for producing the target sequences.

        Parameters
        ----------
        source : ``Dict[str, torch.LongTensor]``
            The output of ``TextField.as_array()`` applied on the source
            ``TextField``. This will be passed through a ``TextFieldEmbedder``
            and then through an encoder.
        target_tokens : ``Dict[str, Dict[str, torch.LongTensor]]``:
            Dictionary from name to output of ``TextField.as_array()`` applied
            on target ``TextField``. We assume that the target tokens are also
            represented as a ``TextField``.
        """
        # (batch_size, input_sequence_length, embedding_dim)
        embedded_input = self._embedding_dropout(self._source_embedder(source))
        source_mask = get_text_field_mask(source)
        # (batch_size, encoder_output_dim)
        final_encoder_output = self._encoder(embedded_input, source_mask)
        output_dict = {}

        # Perform greedy search so we can get the loss.
        if target_tokens:
            if target_tokens.keys() != self._states.keys():
                target_only = target_tokens.keys() - self._states.keys()
                states_only = self._states.keys() - target_tokens.keys()
                raise Exception("Mismatch between target_tokens and self._states. Keys in " +
                                f"targets only: {target_only} Keys in states only: {states_only}")
            total_loss = 0
            for name, state in self._states.items():
                loss = self.greedy_search(
                        final_encoder_output=final_encoder_output,
                        target_tokens=target_tokens[name],
                        target_embedder=state.embedder,
                        decoder_cell=state.decoder_cell,
                        output_projection_layer=state.output_projection_layer
                )
                total_loss += loss
                output_dict[f"{name}_loss"] = loss

            # Use mean loss (instead of the sum of the losses) to be comparable to the paper.
            output_dict["loss"] = total_loss / len(self._states)

        # Perform beam search to obtain the predictions.
        if not self.training:
            batch_size = final_encoder_output.size()[0]
            for name, state in self._states.items():
                start_predictions = final_encoder_output.new_full(
                        (batch_size,), fill_value=self._start_index, dtype=torch.long)
                start_state = {"decoder_hidden": final_encoder_output}

                # (batch_size, 10, num_decoding_steps)
                all_top_k_predictions, log_probabilities = self._beam_search.search(
                        start_predictions, start_state, state.take_step)

                if target_tokens:
                    self._update_recall(all_top_k_predictions, target_tokens[name], state.recall)
                output_dict[f"{name}_top_k_predictions"] = all_top_k_predictions
                output_dict[f"{name}_top_k_log_probabilities"] = log_probabilities

        return output_dict

    def greedy_search(self,
                      final_encoder_output: torch.LongTensor,
                      target_tokens: Dict[str, torch.LongTensor],
                      target_embedder: Embedding,
                      decoder_cell: GRUCell,
                      output_projection_layer: Linear) -> torch.FloatTensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the cross entropy between this sequence and ``target_tokens``.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_tokens : ``Dict[str, torch.LongTensor]``, required
            The output of ``TextField.as_array()`` applied on some target ``TextField``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._get_num_decoding_steps(target_tokens)
        targets = target_tokens["tokens"]
        decoder_hidden = final_encoder_output
        step_logits = []
        for timestep in range(num_decoding_steps):
            # See https://github.com/allenai/allennlp/issues/1134.
            input_choices = targets[:, timestep]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            # list of (batch_size, 1, num_classes)
            step_logits.append(output_projections.unsqueeze(1))
        # (batch_size, num_decoding_steps, num_classes)
        logits = torch.cat(step_logits, 1)
        target_mask = get_text_field_mask(target_tokens)
        return self._get_loss(logits, targets, target_mask)

    def greedy_predict(self,
                       final_encoder_output: torch.LongTensor,
                       target_embedder: Embedding,
                       decoder_cell: GRUCell,
                       output_projection_layer: Linear) -> torch.Tensor:
        """
        Greedily produces a sequence using the provided ``decoder_cell``.
        Returns the predicted sequence.

        Parameters
        ----------
        final_encoder_output : ``torch.LongTensor``, required
            Vector produced by ``self._encoder``.
        target_embedder : ``Embedding``, required
            Used to embed the target tokens.
        decoder_cell: ``GRUCell``, required
            The recurrent cell used at each time step.
        output_projection_layer: ``Linear``, required
            Linear layer mapping to the desired number of classes.
        """
        num_decoding_steps = self._max_decoding_steps
        decoder_hidden = final_encoder_output
        batch_size = final_encoder_output.size()[0]
        predictions = [final_encoder_output.new_full(
                (batch_size,), fill_value=self._start_index, dtype=torch.long
        )]
        for _ in range(num_decoding_steps):
            input_choices = predictions[-1]
            decoder_input = target_embedder(input_choices)
            decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
            # (batch_size, num_classes)
            output_projections = output_projection_layer(decoder_hidden)
            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, 1)
            predictions.append(predicted_classes)
        all_predictions = torch.cat([ps.unsqueeze(1) for ps in predictions], 1)
        # Drop start symbol and return.
        return all_predictions[:, 1:]

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.FloatTensor:
        """
        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        relevant_targets = targets[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()  # (batch_size, num_decoding_steps)
        loss = sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
        return loss

    def decode_all(self, predicted_indices: torch.Tensor) -> List[List[str]]:
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                                for x in indices]
            all_predicted_tokens.append(predicted_tokens)
        return all_predicted_tokens

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, List[List[str]]]:
        """
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds fields for the tokens to the ``output_dict``.
        """
        for name in self._states:
            top_k_predicted_indices = output_dict[f"{name}_top_k_predictions"][0]
            output_dict[f"{name}_top_k_predicted_tokens"] = [self.decode_all(top_k_predicted_indices)]

        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics = {}
        # Recall@10 needs beam search which doesn't happen during training.
        if not self.training:
            for name, state in self._states.items():
                all_metrics[name] = state.recall.get_metric(reset=reset)
        return all_metrics
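
# Finally, a small self-contained sketch of the argmax-and-feed-back loop that
# `greedy_predict` implements above. All sizes and modules here are made up
# (randomly initialised `Embedding`, `GRUCell`, and `Linear` layers stand in for a
# trained state decoder); it only illustrates the shapes and how each predicted
# class becomes the next decoder input.
import torch
import torch.nn.functional as F
from torch.nn import Embedding, GRUCell, Linear

# Hypothetical sizes; a real model would use its trained decoder components instead.
batch_size, hidden_dim, num_classes, max_steps, start_index = 2, 8, 6, 4, 2
target_embedder = Embedding(num_classes, hidden_dim)
decoder_cell = GRUCell(hidden_dim, hidden_dim)
output_projection_layer = Linear(hidden_dim, num_classes)

# Stand-in for `final_encoder_output`.
decoder_hidden = torch.randn(batch_size, hidden_dim)
predictions = [torch.full((batch_size,), start_index, dtype=torch.long)]
for _ in range(max_steps):
    decoder_input = target_embedder(predictions[-1])           # embed last prediction
    decoder_hidden = decoder_cell(decoder_input, decoder_hidden)
    class_probabilities = F.softmax(output_projection_layer(decoder_hidden), dim=-1)
    predictions.append(class_probabilities.argmax(dim=-1))     # greedy choice feeds back

# (batch_size, max_steps) after dropping the start symbol, as in `greedy_predict`.
print(torch.stack(predictions[1:], dim=1))
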