# Example no. 1
 def setUp(self):
     """Create the shared BLEU fixture: equal unigram/bigram weights, index 0 excluded."""
     bigram_weights = (0.5, 0.5)
     self.metric = BLEU(exclude_indices={0}, ngram_weights=bigram_weights)
# Example no. 2
class VAE(Model):
    # Sequence VAE: the variational encoder yields a prior ``p_z`` and a posterior ``q_z``, a
    # latent ``z`` is drawn and handed to the decoder, and the training loss is the decoder's
    # reconstruction loss plus ``kl_weight * KL(q_z || p_z)``.
    # NOTE(review): this copy looks like a scrape with lines dropped -- the docstring below has
    # lost its triple quotes, ``__init__`` has no ``self``, and several statements further down
    # are truncated. It does not parse as-is; restore the missing lines before relying on it.
    This ``VAE`` class is a :class:`Model` which implements a simple VAE as first described
    in (Bowman et al., 2015).
    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    variational_encoder : ``VariationalEncoder``, required
        The encoder model of which to pass the source tokens
    decoder : ``Model``, required
        The variational decoder model of which to pass the the latent variable
    latent_dim : ``int``, required
        The dimention of the latent, z vector. This is not necessarily the same size as the encoder
        output dim
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    # NOTE(review): the text above documents ``latent_dim`` and ``decoder : Model``, but the
    # signature below takes ``decoder: Decoder`` plus undocumented ``kl_weight`` / ``temperature``;
    # the latent size is read from ``variational_encoder.latent_dim`` instead. ``initializer`` is
    # accepted but never applied in the body shown.
    def __init__(
        vocab: Vocabulary,
        variational_encoder: VariationalEncoder,
        decoder: Decoder,
        kl_weight: LossWeight,
        temperature: float = 1.0,
        initializer: InitializerApplicator = InitializerApplicator()
    ) -> None:
        super(VAE, self).__init__(vocab)

        self._encoder = variational_encoder
        self._decoder = decoder

        # The latent size is owned by the encoder rather than passed in separately.
        self._latent_dim = variational_encoder.latent_dim

        self._encoder_output_dim = self._encoder.get_encoder_output_dim()

        # Special-token ids, looked up without an explicit namespace; BLEU ignores all three.
        self._start_index = self.vocab.get_token_index(START_SYMBOL)
        self._end_index = self.vocab.get_token_index(END_SYMBOL)
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token)  # pylint: disable=protected-access
        # NOTE(review): this BLEU(...) call is never closed -- a ``})`` line appears to be missing.
        self._bleu = BLEU(exclude_indices={
            self._pad_index, self._end_index, self._start_index

        # Running average of the KL term. NOTE(review): forward() as shown never updates it.
        self._kl_metric = Average()
        # Exposes ``.get()``; that value multiplies the KL term in forward() and is reported as "klw".
        self.kl_weight = kl_weight

        # Forwarded to the encoder's ``reparametrize`` call in forward().
        self._temperature = temperature

    def forward(
        self,  # type: ignore
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # NOTE(review): the next line is docstring text whose triple quotes were lost.
        Make forward pass for both training/validation/test time.
        # The encoder returns a dict holding the prior and posterior distributions over z.
        encoder_outs = self._encoder(source_tokens)
        p_z = encoder_outs['prior']
        q_z = encoder_outs['posterior']
        kl_weight = self.kl_weight.get()
        # NOTE(review): an ``if <condition>:`` / ``else:`` pair appears to have been dropped around
        # the two over-indented assignments below (syntax error as-is). One branch takes a
        # reparameterised posterior sample, the other calls the encoder's ``reparametrize`` with
        # the temperature -- confirm which condition selects which.
            z = q_z.rsample()
            z = self._encoder.reparametrize(p_z, q_z, self._temperature)

        batch_size = z.size(0)
        # KL(q_z || p_z) summed over all elements, then averaged over the batch.
        kld = kl_divergence(q_z, p_z).sum() / batch_size
        # Before decoding, "predictions" simply echoes the source token ids.
        output_dict = {'z': z, 'predictions': source_tokens['tokens']}

        # Without targets there is nothing to reconstruct: return z and the source ids.
        if not target_tokens:
            return output_dict

        # Do Decoding
        # The decoder's dict is merged in; it is expected to supply 'loss' (read next) and
        # 'predictions' (read for BLEU below).
        output_dict.update(self._decoder(z, target_tokens))
        rec_loss = output_dict['loss']

        # Total loss = reconstruction loss + weighted KL term.
        kl_loss = kld * kl_weight
        output_dict['loss'] = rec_loss + kl_loss

        # NOTE(review): the condition after ``if not`` was lost (syntax error as-is); the body
        # scores the decoder's predictions against the target ids with BLEU.
        if not
            best_predictions = output_dict["predictions"]
            self._bleu(best_predictions, target_tokens["tokens"])

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        # Reports the current KL weight ("klw"), the running KL average ("kl") and presumably
        # BLEU when it is enabled.
        # NOTE(review): truncated below -- ``if self._bleu and not`` has lost its condition and
        # body, and the ``{'kl': ...}`` dict has lost its ``all_metrics.update(`` opener.
        all_metrics: Dict[str, float] = {}
        if self._bleu and not
        all_metrics.update({'klw': float(self.kl_weight.get())})
            {'kl': float(self._kl_metric.get_metric(reset=reset))})
        return all_metrics

    def generate(self, num_to_sample: int = 1):
        # Sample ``num_to_sample`` latents from a standard normal N(0, I) prior of size
        # ``latent_dim`` (placed on the model's prediction device), decode them, and convert the
        # decoder output into tokens via decode().
        cuda_device = self._get_prediction_device()
        prior_mean = nn_util.move_to_device(
            torch.zeros((num_to_sample, self._latent_dim)), cuda_device)
        prior_stddev = torch.ones_like(prior_mean)
        prior = Normal(prior_mean, prior_stddev)
        latent = prior.sample()
        # Presumably a dict with a "predictions" entry, since decode() reads that key.
        generated = self._decoder.generate(latent)

        return self.decode(generated)

    # simple_seq2seq's decode
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        # NOTE(review): the following lines are docstring text whose triple quotes were lost.
        Finalize predictions.
        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        predicted_indices = output_dict["predictions"]
        # Accept either a tensor or an already-converted numpy array.
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            # NOTE(review): this list is never closed and never appended to
            # ``all_predicted_tokens``; as written "predicted_tokens" would always be empty.
            predicted_tokens = [
                self.vocab.get_token_from_index(x) for x in indices
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict
# Example no. 3
class BleuTest(AllenNlpTestCase):
    """Tests for ``BLEU``: token masking, n-gram counting, and the final score."""

    def setUp(self):
        # Run the base-class setup as well (AllenNlpTestCase presumably prepares shared fixtures
        # that its own tearDown expects to exist).
        super().setUp()
        # Equal weights for unigram and bigram precision; index 0 is excluded from all counts.
        self.metric = BLEU(ngram_weights=(0.5, 0.5), exclude_indices={0})

    def test_get_valid_tokens_mask(self):
        """Positions holding an excluded index (0) are masked out, all others kept."""
        tensor = torch.tensor([[1, 2, 3, 0],
                               [0, 1, 1, 0]])
        result = self.metric._get_valid_tokens_mask(tensor)
        result = result.long().numpy()
        check = np.array([[1, 1, 1, 0],
                          [0, 1, 1, 0]])
        np.testing.assert_array_equal(result, check)

    def test_ngrams(self):
        """``_ngrams`` yields each contiguous n-gram that contains no excluded index."""
        tensor = torch.tensor([1, 2, 3, 1, 2, 0])

        # Unigrams.
        counts = Counter(self.metric._ngrams(tensor, 1))
        unigram_check = {(1,): 2, (2,): 2, (3,): 1}
        assert counts == unigram_check

        # Bigrams.
        counts = Counter(self.metric._ngrams(tensor, 2))
        bigram_check = {(1, 2): 2, (2, 3): 1, (3, 1): 1}
        assert counts == bigram_check

        # Trigrams.
        counts = Counter(self.metric._ngrams(tensor, 3))
        trigram_check = {(1, 2, 3): 1, (2, 3, 1): 1, (3, 1, 2): 1}
        assert counts == trigram_check

        # ngram size too big, no ngrams produced.
        counts = Counter(self.metric._ngrams(tensor, 7))
        assert counts == {}

    def test_bleu_computed_correctly(self):
        """Clipped matches, totals, brevity penalty and the final BLEU on a small batch."""

        # shape: (batch_size, max_sequence_length)
        predictions = torch.tensor([[1, 0, 0],
                                    [1, 1, 0],
                                    [1, 1, 1]])

        # shape: (batch_size, max_gold_sequence_length)
        gold_targets = torch.tensor([[2, 0, 0],
                                     [1, 0, 0],
                                     [1, 1, 2]])

        self.metric(predictions, gold_targets)

        assert self.metric._prediction_lengths == 6
        assert self.metric._reference_lengths == 5

        # Number of unigrams in predicted sentences that match gold sentences
        # (but not more than maximum occurrence of gold unigram within batch).
        assert self.metric._precision_matches[1] == (
                0 +  # no matches in first sentence.
                1 +  # one clipped match in second sentence.
                2    # two clipped matches in third sentence.
        )

        # Total number of predicted unigrams.
        assert self.metric._precision_totals[1] == (
                1 +
                2 +
                3
        )

        # Number of bigrams in predicted sentences that match gold sentences
        # (but not more than maximum occurrence of gold bigram within batch).
        assert self.metric._precision_matches[2] == (
                0 +  # first prediction has no bigrams.
                0 +  # (1, 1) does not occur in the one-token reference.
                1    # (1, 1) occurs twice in the prediction but once in the reference.
        )

        # Total number of predicted bigrams.
        assert self.metric._precision_totals[2] == (
                0 +
                1 +
                2
        )

        # Brevity penalty should be 1.0
        assert self.metric._get_brevity_penalty() == 1.0

        # Unigram precision 3/6 and bigram precision 1/3, equally weighted.
        bleu = self.metric.get_metric(reset=True)["BLEU"]
        check = math.exp(0.5 * (math.log(3) - math.log(6)) +
                         0.5 * (math.log(1) - math.log(3)))
        np.testing.assert_approx_equal(bleu, check)

    def test_bleu_computed_with_zero_counts(self):
        """With nothing accumulated, BLEU is 0 rather than an error."""
        assert self.metric.get_metric()["BLEU"] == 0
# Example no. 4
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention,
                 schema_path: str = None,
                 missing_alignment_int: int = 0,
                 indexfield_padding_index: int = -1,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 emb_dropout: float = 0.0,
                 dec_dropout: float = 0.0,
                 attn_loss_lambda: float = 0.5,
                 token_based_metric: Metric = None) -> None:
        # Build an attention seq2seq model (source embedder -> Seq2Seq encoder -> attention +
        # LSTMCell decoder, beam search at prediction time), track an attention-supervision loss
        # as a metric, and add SQL-specific metrics when ``schema_path`` is given.
        # NOTE(review): this copy looks scraped with lines dropped -- several calls below are left
        # unclosed and two ``else:`` lines are missing, so it does not parse as-is.
        super(AttnSupSeq2Seq, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        # Stored for use elsewhere; their exact meaning is defined by code outside this view.
        self._indexfield_padding_index = indexfield_padding_index
        self._missing_alignment_int = missing_alignment_int

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        # NOTE(review): both lookups below are truncated (presumably completed with
        # ``self._target_namespace)`` as in the pad-index lookup further down).
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
        self._end_index = self.vocab.get_token_index(END_SYMBOL,

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)  # pylint: disable=protected-access
            # NOTE(review): the BLEU(...) call is unclosed and the ``else:`` before
            # ``self._bleu = None`` is missing; without it BLEU would always be replaced by None.
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index
            self._bleu = None

        # NOTE(review): an ``else:`` appears to be missing below; as written a caller-supplied
        # ``token_based_metric`` is always replaced by ``TokenSequenceAccuracy()``.
        if token_based_metric:
            self._token_based_metric = token_based_metric
            self._token_based_metric = TokenSequenceAccuracy()
        # log attention supervision CE loss as a metric
        self._attn_sup_loss = Average()
        self._sql_metrics = schema_path is not None
        if self._sql_metrics:
            # SQL specific metrics: match between the templates free of schema constants,
            # and match between the schema constants
            # NOTE(review): both constructor calls below are truncated (arguments and closing
            # parens lost).
            self._schema_free_match = GlobalTemplAccuracy(
            self._kb_match = KnowledgeBaseConstsAccuracy(

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        # NOTE(review): BeamSearch(...) call truncated; presumably max_steps/beam_size follow.
        self._beam_search = BeamSearch(self._end_index,

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder
        self._emb_dropout = Dropout(p=emb_dropout)
        self._dec_dropout = Dropout(p=dec_dropout)
        # Presumably weights the attention-supervision loss; not used in this constructor.
        self._attn_loss_lambda = attn_loss_lambda
        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention
        # NOTE(review): flips a private flag on the attention module, presumably so it returns
        # unnormalised scores -- confirm against the Attention implementation.
        self._attention._normalize = False

        # Dense embedding of vocab words in the target space.
        # NOTE(review): ``get_output_dim(`` is never closed (closing paren lost).
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim(
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        # A weighted average over encoder outputs will be concatenated to the previous target embedding
        # to form the input to the decoder at each time step.
        self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        # NOTE(review): LSTMCell(...) call truncated; hidden size presumably _decoder_output_dim.
        self._decoder_cell = LSTMCell(self._decoder_input_dim,

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        # NOTE(review): Linear(...) call truncated; output size presumably num_classes.
        self._output_projection_layer = Linear(self._decoder_output_dim,
# Example no. 5
class Bart(Model):
    # Wraps Hugging Face ``BartForConditionalGeneration``: a teacher-forced cross-entropy loss in
    # one code path and beam-search decoding (scored with ROUGE and BLEU) in the other.
    # NOTE(review): this copy looks scraped with lines dropped -- the docstrings have lost their
    # triple quotes and many statements below are truncated, so it does not parse as-is.
    BART model from the paper "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation,
    Translation, and Comprehension" ( The Bart model here uses a language
    modeling head and thus can be used for text generation.

    # Parameters

    model_name : `str`, required
        Name of the pre-trained BART model to use. Available options can be found in
    vocab : `Vocabulary`, required
        Vocabulary containing source and target vocabularies.
    beam_search : `Lazy[BeamSearch]`, optional (default = `Lazy(BeamSearch)`)
        This is used to during inference to select the tokens of the decoded output sequence.
    indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
        Indexer to be used for converting decoded sequences of ids to to sequences of tokens.
    encoder : `Seq2SeqEncoder`, optional (default = `None`)
        Encoder to used in BART. By default, the original BART encoder is used.

    # NOTE(review): the signature below is incomplete -- ``self``, the ``**kwargs`` read further
    # down, the closing ``) -> None:`` and presumably a ``super().__init__(vocab)`` call are missing.
    def __init__(
        model_name: str,
        vocab: Vocabulary,
        beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
        indexer: PretrainedTransformerIndexer = None,
        encoder: Seq2SeqEncoder = None,
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        self._indexer = indexer or PretrainedTransformerIndexer(model_name, namespace="tokens")

        self._start_id = self.bart.config.bos_token_id  # CLS
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        # At prediction time, we'll use a beam search to find the best target sequence.
        # For backwards compatibility, check if beam_size or max_decoding_steps were passed in as
        # kwargs. If so, update the BeamSearch object before constructing and raise a DeprecationWarning
        # NOTE(review): the parenthesised message below is never closed.
        deprecation_warning = (
            "The parameter {} has been deprecated."
            " Provide this parameter as argument to beam_search instead."
        beam_search_extras = {}
        if "beam_size" in kwargs:
            beam_search_extras["beam_size"] = kwargs["beam_size"]
            warnings.warn(deprecation_warning.format("beam_size"), DeprecationWarning)
        if "max_decoding_steps" in kwargs:
            beam_search_extras["max_steps"] = kwargs["max_decoding_steps"]
            warnings.warn(deprecation_warning.format("max_decoding_steps"), DeprecationWarning)
        # NOTE(review): construct(...) call truncated (closing paren lost).
        self._beam_search = beam_search.construct(
            end_index=self._end_id, vocab=self.vocab, **beam_search_extras

        # Special ids are excluded from both generation metrics.
        self._rouge = ROUGE(exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            # A replacement encoder must keep BART's hidden size on both input and output.
            # NOTE(review): the assert and the wrapper call below are truncated.
            assert (
                encoder.get_input_dim() == encoder.get_output_dim() == self.bart.config.hidden_size
            self.bart.model.encoder = _BartEncoderWrapper(

    def forward(
        self, source_tokens: TextFieldTensors, target_tokens: TextFieldTensors = None
    ) -> Dict[str, torch.Tensor]:
        # NOTE(review): the docstring text below has lost its triple quotes.
        Performs the forward step of Bart.

        # Parameters

        source_tokens : `TextFieldTensors`, required
            The source tokens for the encoder. We assume they are stored under the `tokens` key.
        target_tokens : `TextFieldTensors`, optional (default = `None`)
            The target tokens for the decoder. We assume they are stored under the `tokens` key. If no target
            tokens are given, the source tokens are shifted to the right by 1.

        # Returns

        `Dict[str, torch.Tensor]`
            During training, this dictionary contains the `decoder_logits` of shape `(batch_size,
            max_target_length, target_vocab_size)` and the `loss`. During inference, it contains `predictions`
            of shape `(batch_size, max_decoding_steps)` and `log_probabilities` of shape `(batch_size,)`.

        inputs = source_tokens
        targets = target_tokens
        input_ids, input_mask = inputs["tokens"]["token_ids"], inputs["tokens"]["mask"]

        outputs = {}

        # If no targets are provided, then shift input to right by 1. Bart already does this internally
        # but it does not use them for loss calculation.
        if targets is not None:
            target_ids, target_mask = targets["tokens"]["token_ids"], targets["tokens"]["mask"]
            # NOTE(review): an ``else:`` (at the ``if`` level) seems to be missing before the next
            # two lines; as written they always replace the target ids/mask with the shifted source.
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

            # Teacher-forced pass: the decoder reads target[:, :-1] and the loss is taken against
            # target[:, 1:]. NOTE(review): this call and the loss call below are truncated
            # (remaining arguments and closing parens lost).
            bart_outputs = self.bart(
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),
            outputs["decoder_logits"] = bart_outputs.logits

            # The BART paper mentions label smoothing of 0.1 for sequence generation tasks
            outputs["loss"] = sequence_cross_entropy_with_logits(
                cast(torch.LongTensor, target_ids[:, 1:].contiguous()),
                cast(torch.BoolTensor, target_mask[:, 1:].contiguous()),
            # NOTE(review): the beam-search code from here on appears to belong to a separate branch
            # whose ``else:``/``if`` line was dropped; the torch.tensor(...) arguments are missing too.
            # Use decoder start id and start of sentence to start decoder
            initial_decoder_ids = torch.tensor(
            ).repeat(input_ids.shape[0], 1)

            # NOTE(review): this dict is never closed, and the right-hand side of ``beam_result =``
            # (the search call that receives the following line as arguments) is missing.
            inital_state = {
                "input_ids": input_ids,
                "input_mask": input_mask,
            beam_result =
                initial_decoder_ids, inital_state, self.take_step

            # As used here, beam_result[0] holds candidate sequences and beam_result[1] their
            # log-probabilities; keep the highest-scoring candidate per batch item.
            # NOTE(review): the ``(`` opened on the next statement is never closed.
            predictions = beam_result[0]
            max_pred_indices = (
                beam_result[1].argmax(dim=-1).view(-1, 1, 1).expand(-1, -1, predictions.shape[-1])
            predictions = predictions.gather(dim=1, index=max_pred_indices).squeeze(dim=1)

            # Score the selected predictions against the target ids.
            self._rouge(predictions, target_ids)
            self._bleu(predictions, target_ids)

            outputs["predictions"] = predictions
            # Log-probability of the selected candidate. NOTE(review): expression unclosed.
            outputs["log_probabilities"] = (
                beam_result[1].gather(dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1)


        return outputs

    # Flatten a per-layer decoder cache (4 tensors per layer) into a dict keyed
    # "decoder_cache_{layer}_{index}" so it can be carried in the beam-search state dict.
    # NOTE(review): there is no ``self`` parameter -- a ``@staticmethod`` decorator appears missing.
    def _decoder_cache_to_dict(decoder_cache: DecoderCacheType) -> Dict[str, torch.Tensor]:
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            # Each layer caches the key and value tensors for its self-attention and cross-attention.
            # Hence the `layer_cache` tuple has 4 elements.
            assert len(layer_cache) == 4
            for tensor_index, tensor in enumerate(layer_cache):
                key = f"decoder_cache_{layer_index}_{tensor_index}"
                cache_dict[key] = tensor
        return cache_dict

    # Inverse of _decoder_cache_to_dict: rebuild the 4-tuple for every decoder layer.
    # NOTE(review): the tuple below is never closed and never appended to ``decoder_cache``; as
    # written ``decoder_cache`` would stay empty and the assert would fail.
    def _dict_to_decoder_cache(self, cache_dict: Dict[str, torch.Tensor]) -> DecoderCacheType:
        decoder_cache = []
        for layer_index in range(len(self.bart.model.decoder.layers)):
            base_key = f"decoder_cache_{layer_index}_"
            layer_cache = (
                cache_dict[base_key + "0"],
                cache_dict[base_key + "1"],
                cache_dict[base_key + "2"],
                cache_dict[base_key + "3"],
        assert decoder_cache
        return tuple(decoder_cache)

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], step: int
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # NOTE(review): the docstring text below has lost its triple quotes.
        Take step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions
        step : `int`
            The time step in beam search decoding.

        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)` and
            an updated state dictionary.
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        # Every state entry other than the inputs and encoder output is a flattened cache tensor.
        # NOTE(review): the dict comprehension below is never closed.
        decoder_cache = None
        decoder_cache_dict = {
            k: state[k].contiguous()
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        if len(decoder_cache_dict) != 0:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        # After the first step the stored encoder output is reused instead of re-encoding input_ids.
        # NOTE(review): the self.bart(...) call below is truncated.
        encoder_outputs = (state["encoder_states"],) if "encoder_states" in state else None
        outputs = self.bart(
            input_ids=state["input_ids"] if encoder_outputs is None else None,

        # Next-token log-probabilities come from the last decoder position.
        logits = outputs.logits[:, -1, :]
        log_probabilities = F.log_softmax(logits, dim=-1)

        # NOTE(review): ``decoder_cache_dict`` is rebuilt but never written back into ``state`` in
        # this copy (presumably a ``state.update(...)`` line was dropped), so the cache would not
        # persist across steps.
        decoder_cache = outputs.past_key_values
        if decoder_cache is not None:
            decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache)

        state["encoder_states"] = outputs.encoder_last_hidden_state

        return log_probabilities, state

    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        # Adds per-instance token lists ("predicted_tokens") and decoded strings ("predicted_text",
        # special tokens skipped) to ``output_dict``.
        # NOTE(review): the docstring below has lost its quotes and summary line, and the two
        # indexer/tokenizer calls further down are truncated.

        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor should have
            shape `(batch_size, max_sequence_length)`

        # Returns

        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list of lists of

        predictions = output_dict["predictions"]
        predicted_tokens = [None] * predictions.shape[0]
        for i in range(predictions.shape[0]):
            predicted_tokens[i] = self._indexer.indices_to_tokens(
                {"token_ids": predictions[i].tolist()},
        output_dict["predicted_tokens"] = predicted_tokens  # type: ignore
        output_dict["predicted_text"] = self._indexer._tokenizer.batch_decode(
            predictions.tolist(), skip_special_tokens=True

        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        # NOTE(review): body truncated after ``if not``; presumably ROUGE/BLEU values are merged here.
        metrics: Dict[str, float] = {}
        if not
        return metrics

    # Name of the predictor to use for this model by default.
    default_predictor = "seq2seq"
# Example no. 6
class MaskedCopyNet(Model):

    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention = None,
                 mask_embedder: TextFieldEmbedder = None,
                 mask_attention: Attention = None,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True) -> None:
        """Build an encoder-decoder with an optional second ("mask") token stream.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Vocabulary; special-token and target-size lookups use ``target_namespace``.
        embedder : ``TextFieldEmbedder``
            Embeds source tokens; its output dim is also used as the target embedding dim.
        encoder : ``Seq2SeqEncoder``
            Encodes the embedded source; its output dim sets the decoder hidden size.
        max_decoding_steps : ``int``
            Step limit for beam search and for target-free decoding.
        attention : ``Attention``, optional
            When given, the decoder input also receives an encoder-output-sized vector.
        mask_embedder : ``TextFieldEmbedder``, optional
            Embeds the optional mask tokens.
        mask_attention : ``Attention``, optional
            When given, the decoder input grows by ``mask_embedder.get_output_dim()``, so
            ``mask_embedder`` must then be provided as well.
        beam_size : ``int``, optional
            Beam width at prediction time; ``None`` means 1.
        target_namespace : ``str``
            Vocabulary namespace of the target tokens.
        scheduled_sampling_ratio : ``float``
            Stored for the decoding loop.
        use_bleu : ``bool``
            Track BLEU (padding, start and end tokens excluded) when True.
        """
        # The base Model initialiser must run first: it stores ``self.vocab`` (read just below),
        # and as a torch Module it must be initialised before any sub-module attribute is assigned.
        super(MaskedCopyNet, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={pad_index, self._end_index, self._start_index})
        else:
            # Previously this assignment ran unconditionally and always discarded the BLEU metric.
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._embedder = embedder
        self._mask_embedder = mask_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention
        self._mask_attention = mask_attention

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = self._embedder.get_output_dim()

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        if self._attention:
            # If using attention, a weighted average over encoder outputs will be concatenated
            # to the previous target embedding to form the input to the decoder at each
            # time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        if self._mask_attention:
            # The mask stream contributes a vector of the mask embedder's size.
            self._decoder_input_dim += self._mask_embedder.get_output_dim()

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)

    def take_step(self,
                  last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Advance decoding by one step (the callback used by beam search).

        Projects the decoder output for ``last_predictions`` onto the target vocabulary and
        returns ``(log_probabilities, state)``; log-probabilities have shape
        ``(group_size, num_classes)`` and ``state`` is whatever the projection step returned.
        """
        logits, next_state = self._prepare_output_projections(last_predictions, state)
        return F.log_softmax(logits, dim=-1), next_state

    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                target_tokens: Dict[str, torch.LongTensor] = None,
                mask_tokens: Dict[str, torch.LongTensor] = None,
                **kwargs) -> Dict[str, torch.Tensor]:
        # Encode the source (and optional mask tokens), run teacher-forced decoding when targets
        # are given, and in a second branch run beam search and optionally score it with BLEU.
        # Extra keyword fields are accepted and discarded.
        del kwargs
        assert mask_tokens is not None or self._mask_embedder is None, \
            'You must pass `mask_tokens` when `mask_embedder` is not None'
        state = self.encode(source_tokens, mask_tokens)

        # NOTE(review): an ``else:`` is missing before ``output_dict = {}``; as written the dict
        # returned by _forward_loop (including its loss) is always discarded.
        if target_tokens:
            state = self.init_decoder_state(state)
            output_dict = self._forward_loop(state, target_tokens)
            output_dict = {}

        # NOTE(review): the condition after ``if not`` was lost (syntax error as-is). Also,
        # ``predictions`` is never merged into ``output_dict`` although the BLEU code below reads
        # ``output_dict["predictions"]`` -- presumably an ``output_dict.update(predictions)`` line
        # was dropped.
        if not
            state = self.init_decoder_state(state)
            predictions = self.beam_search(state)
            if target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, target_tokens["tokens"])

        return output_dict

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert every predicted index sequence into target-namespace tokens.

        ``output_dict["predictions"]`` (tensor or numpy array) is iterated as
        ``[batch_size, k, num_decoding_steps]``: ``k`` candidate sequences per instance. Each
        candidate is cut at the first end symbol and mapped to tokens. The nested result is
        stored under ``"predicted_tokens"`` and ``output_dict`` is returned.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for candidates in predicted_indices:
            curr_predictions = []
            for ind in candidates:
                ind = list(ind)
                # Collect indices till the first end_symbol
                if self._end_index in ind:
                    ind = ind[:ind.index(self._end_index)]
                predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                                    for x in ind]
                # Previously the tokens were built and then dropped, leaving the output empty.
                curr_predictions.append(predicted_tokens)
            all_predicted_tokens.append(curr_predictions)
        output_dict["predicted_tokens"] = all_predicted_tokens  # [batch_size, k, num_decoding_steps]
        return output_dict

    def encode(self, source_tokens: Dict[str, torch.Tensor],
               mask_tokens: Dict[str, torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        # Embed and encode the source tokens into the initial state dict. When both mask tokens and
        # a mask embedder are present, their mask and raw embeddings (not passed through the
        # encoder) are added under "mask_*" keys.
        # NOTE(review): both dict literals below are truncated (closing ``}`` and presumably a
        # ``state.update({`` opener were lost), so this does not parse as-is.
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        state = {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs

        if mask_tokens is not None and self._mask_embedder is not None:
            embedded_input = self._mask_embedder(mask_tokens)
            masker_mask = util.get_text_field_mask(mask_tokens)
                    "mask_source_mask": masker_mask,
                    "mask_encoder_outputs": embedded_input
        return state

    def init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Seed the decoder: hidden state = final encoder output, context = zeros of size
        # decoder_output_dim. Mutates and returns ``state``.
        # NOTE(review): the get_final_encoder_states( call below has lost its arguments and closing
        # paren; presumably it takes state["encoder_outputs"] and state["source_mask"].
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(self,
                      state: Dict[str, torch.Tensor],
                      target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
        # Greedy step-by-step decoding with teacher forcing / scheduled sampling. Returns per-step
        # argmax "predictions" and, when targets are given, the sequence "loss".
        # NOTE(review): this copy looks scraped with lines dropped -- see the notes below.
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
            # NOTE(review): an ``else:`` (at the ``if`` level) is missing before the next line; as
            # written num_decoding_steps is always overwritten with _max_decoding_steps.
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            # NOTE(review): the first operand of ``and`` was lost (``if and`` is a syntax error) --
            # presumably a training-mode check. The comment below describes the overall policy;
            # this branch feeds back the model's own previous prediction.
            if and torch.rand(1).item() < self._scheduled_sampling_ratio:
                # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                # during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
                # NOTE(review): an ``else:`` is missing before the teacher-forcing assignment
                # below; as written it sits in the ``not target_tokens`` branch, where ``targets``
                # is undefined.
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            # NOTE(review): the ``step_logits.append(...)`` this comment describes is missing.

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes

            # NOTE(review): a ``step_predictions.append(...)`` call appears to be missing here.

        # shape: (batch_size, num_decoding_steps)
        # NOTE(review): truncated -- presumably a concatenation of step_predictions along dim 1.
        predictions =, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            # NOTE(review): truncated -- presumably a concatenation of step_logits along dim 1.
            logits =, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Decode with beam search, starting every sequence from ``self._start_index``.

        ``self.take_step`` is passed as the step function. Returns a dict with
        ``"class_log_probabilities"`` ``(batch_size, beam_size)`` and ``"predictions"``
        ``(batch_size, beam_size, num_decoding_steps)``, per the shape comments below.

        NOTE(review): in this copy the right-hand side of the assignment below has lost its
        callee (presumably a ``BeamSearch.search`` call), and the returned dict literal has no
        closing brace. The method does not parse as-is.
        """
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)

        all_top_k_predictions, log_probabilities =
                start_predictions, state, self.take_step)

        output_dict = {
                "class_log_probabilities": log_probabilities,
                "predictions": all_top_k_predictions,
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Run one decoder step and project the new hidden state onto the target vocabulary.

        ``last_predictions`` are embedded with ``self._embedder`` under
        ``self._target_namespace``. If ``self._attention`` is set, an attended summary of
        ``encoder_outputs`` is combined with the embedding; otherwise, presumably, the embedding
        alone is used (the ``else:`` header is missing here). If both ``self._mask_attention``
        and ``self._mask_embedder`` are set, an attended summary of
        ``state["mask_encoder_outputs"]`` is appended as well.

        ``self._decoder_cell`` then produces the new ``decoder_hidden`` / ``decoder_context``,
        which are written back into ``state`` in place. ``self._output_projection_layer`` maps
        the new hidden state to ``(group_size, num_classes)`` scores.

        Returns ``(output_projections, state)``.

        NOTE(review): several lines are missing in this copy. These include the ``torch.cat``
        calls, the ``else:`` header, the arguments of the mask-attention call, and the first
        argument of the decoder-cell call.
        """
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._embedder({self._target_namespace: last_predictions})

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask)

            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            decoder_input =, embedded_input), -1)
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        if self._mask_attention and self._mask_embedder:
            mask_encoder_outputs = state["mask_encoder_outputs"]
            mask_source_mask = state["mask_source_mask"]
            mask_attended_input = self._prepare_mask_attended_input(
            decoder_input =, mask_attended_input), -1)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
                (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        return output_projections, state

    def _prepare_attended_input(self,
                                decoder_hidden_state: torch.LongTensor = None,
                                encoder_outputs: torch.LongTensor = None,
                                encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Summarise ``encoder_outputs`` with attention keyed on the decoder hidden state.

        The mask is cast to float before it is handed to ``self._attention`` together with the
        decoder state and the encoder outputs. The attention weights that come back are used
        for a weighted sum (``util.weighted_sum``) of the encoder outputs, which is returned.
        """
        # Cast first: the attention module expects a floating-point mask.
        float_mask = encoder_outputs_mask.float()
        attention_weights = self._attention(decoder_hidden_state, encoder_outputs, float_mask)
        return util.weighted_sum(encoder_outputs, attention_weights)

    def _prepare_mask_attended_input(self,
                                     decoder_hidden_state: torch.LongTensor = None,
                                     mask_encoder_outputs: torch.LongTensor = None,
                                     mask_encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
        """Summarise the mask-encoder outputs with ``self._mask_attention``.

        This mirrors ``_prepare_attended_input`` but uses the secondary ("mask") inputs. The mask
        is cast to float, attention weights are computed from the decoder hidden state, and the
        weighted sum of ``mask_encoder_outputs`` is returned.
        """
        mask_as_float = mask_encoder_outputs_mask.float()
        mask_weights = self._mask_attention(decoder_hidden_state, mask_encoder_outputs, mask_as_float)
        summary = util.weighted_sum(mask_encoder_outputs, mask_weights)
        return summary

    @staticmethod
    def _get_loss(logits: torch.LongTensor,
                  targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """Masked sequence cross entropy between decoder ``logits`` and the shifted ``targets``.

        ``_forward_loop`` calls this as ``self._get_loss(logits, targets, target_mask)``. The
        signature has no ``self``, so it must be a ``staticmethod``; as a plain method the bound
        call would pass four positional arguments and raise ``TypeError``.

        The first column of ``targets`` and ``target_mask`` is dropped, so the logits produced at
        step ``i`` are scored against the target token at step ``i + 1``. That first column is
        presumably the START symbol, since decoding begins from ``self._start_index``.
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the model's metrics as a ``Dict[str, float]``.

        NOTE(review): the condition below is cut off in this copy (it ends after ``not``) and its
        body is missing, so the method does not parse as-is. Judging by the ``self._bleu`` test,
        it presumably merges the BLEU metric into ``all_metrics``, honouring ``reset``. Confirm
        against the upstream source.
        """
        all_metrics: Dict[str, float] = {}
        if self._bleu and not
        return all_metrics
Esempio n. 7
# Sequence-to-sequence model (encoder + LSTMCell decoder with attention) trained with an auxiliary
# attention-supervision loss. Per the comments in ``_get_attn_sup_loss``, the decoder's raw
# attention logits at each target step are pushed towards the source position given by
# ``alignment_sequence``. The total training loss is the token cross entropy plus
# ``attn_loss_lambda`` times that term.
# NOTE(review): this copy of the class is damaged and does not parse as-is. Many lines are
# missing or cut off, including ``else:`` headers, ``torch.cat``/``append`` calls, closing
# brackets, some ``self`` parameters, and the quotes around most docstrings. Restore it from
# the upstream source before relying on it; the comments below describe only what the
# surviving lines show.
class AttnSupSeq2Seq(Model):
    Adaptation of the ``SimpleSeq2Seq`` class in allennlp_models, with auxiliary attention-supervision loss

    vocab : ``Vocabulary``, required
        Vocabulary containing source and target vocabularies. They may be under the same namespace
        (`tokens`) or the target tokens can have a different namespace, in which case it needs to
        be specified as `target_namespace`.
    source_embedder : ``TextFieldEmbedder``, required
        Embedder for source side sequences
    encoder : ``Seq2SeqEncoder``, required
        The encoder of the "encoder/decoder" model
    max_decoding_steps : ``int``
        Maximum length of decoded sequences.
    target_namespace : ``str``, optional (default = 'target_tokens')
        If the target side vocabulary is different from the source side's, you need to specify the
        target's namespace here. If not, we'll assume it is "tokens", which is also the default
        choice for the source side, and this might cause them to share vocabularies.
    target_embedding_dim : ``int``, optional (default = source_embedding_dim)
        You can specify an embedding dimensionality for the target side. If not, we'll use the same
        value as the source embedder's.
    attention : ``Attention``, optional (default = None)
        If you want to use attention to get a dynamic summary of the encoder outputs at each step
        of decoding, this is the function used to compute similarity between the decoder hidden
        state and encoder outputs.
    attention_function: ``SimilarityFunction``, optional (default = None)
        This is if you want to use the legacy implementation of attention. This will be deprecated
        since it consumes more memory than the specialized attention modules.
    beam_size : ``int``, optional (default = None)
        Width of the beam for beam search. If not specified, greedy decoding is used.
    scheduled_sampling_ratio : ``float``, optional (default = 0.)
        At each timestep during training, we sample a random number between 0 and 1, and if it is
        not less than this value, we use the ground truth labels for the whole batch. Else, we use
        the predictions from the previous time step for the whole batch. If this value is 0.0
        (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
        using target side ground truth labels.  See the following paper for more information:
        `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
        2015 <>`_.
    use_bleu : ``bool``, optional (default = True)
        If True, the BLEU metric will be calculated during validation.
    # Constructor arguments not described in the text above (inferred from their use in this class):
    #   schema_path -- when not None, the SQL template / knowledge-base constant metrics are enabled.
    #   missing_alignment_int, indexfield_padding_index -- values of ``alignment_sequence`` that
    #       ``_get_alignment_mask`` masks out.
    #   emb_dropout -- dropout applied to the encoder outputs in ``_encode``.
    #   dec_dropout -- dropout applied to the decoder input at every step.
    #   attn_loss_lambda -- weight of the attention-supervision loss in the total loss.
    #   token_based_metric -- sequence metric used instead of ``TokenSequenceAccuracy``.
    # Note that ``attention`` is required here (no default), unlike the text above, and
    # ``attention_function`` is not a parameter at all.
    def __init__(self,
                 vocab: Vocabulary,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 max_decoding_steps: int,
                 attention: Attention,
                 schema_path: str = None,
                 missing_alignment_int: int = 0,
                 indexfield_padding_index: int = -1,
                 beam_size: int = None,
                 target_namespace: str = "tokens",
                 target_embedding_dim: int = None,
                 scheduled_sampling_ratio: float = 0.,
                 use_bleu: bool = True,
                 emb_dropout: float = 0.0,
                 dec_dropout: float = 0.0,
                 attn_loss_lambda: float = 0.5,
                 token_based_metric: Metric = None) -> None:
        super(AttnSupSeq2Seq, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio
        self._indexfield_padding_index = indexfield_padding_index
        self._missing_alignment_int = missing_alignment_int

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
        self._end_index = self.vocab.get_token_index(END_SYMBOL,

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                   self._target_namespace)  # pylint: disable=protected-access
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index
            self._bleu = None

        if token_based_metric:
            self._token_based_metric = token_based_metric
            self._token_based_metric = TokenSequenceAccuracy()
        # Running average of the attention-supervision cross-entropy loss, reported as a metric.
        self._attn_sup_loss = Average()
        self._sql_metrics = schema_path is not None
        if self._sql_metrics:
            # SQL specific metrics: match between the templates free of schema constants,
            # and match between the schema constants
            self._schema_free_match = GlobalTemplAccuracy(
            self._kb_match = KnowledgeBaseConstsAccuracy(

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder
        self._emb_dropout = Dropout(p=emb_dropout)
        self._dec_dropout = Dropout(p=dec_dropout)
        self._attn_loss_lambda = attn_loss_lambda
        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._encoder = encoder

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        self._attention = attention
        # Make the attention module return un-normalised scores. This relies on its private
        # ``_normalize`` flag. ``_prepare_attended_input`` applies ``masked_softmax`` itself and
        # keeps the raw logits for the attention-supervision loss.
        self._attention._normalize = False

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim(
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        self._encoder_output_dim = self._encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim

        # A weighted average over encoder outputs will be concatenated to the previous target embedding
        # to form the input to the decoder at each time step.
        self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        self._decoder_cell = LSTMCell(self._decoder_input_dim,

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,

    # Beam-search step function, passed to the beam search in ``_forward_beam_search``. It runs
    # one decoder step via ``_prepare_output_projections`` and returns log-softmax scores over the
    # target vocabulary plus the updated state. The attention logits are discarded here.
    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        Take a decoding step. This is called by the beam search class.

        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        # shape: (group_size, num_classes)
        _, output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    # Override of ``Model.forward_on_instances``. It batches ``instances``, runs ``forward`` then
    # ``decode`` under ``torch.no_grad()``, and converts tensor outputs to numpy. It then splits
    # every output into per-instance dicts and copies each instance's input fields into its dict.
    # NOTE(review): despite the text below, there is no ``cuda_device`` parameter (the device
    # comes from ``self._get_prediction_device()``), and this class's ``decode`` is not a no-op.
    # The size-mismatch branches and the ``try:`` around the field copy are cut off in this copy.
    def forward_on_instances(
            self, instances: List[Instance]) -> List[Dict[str, numpy.ndarray]]:
        Takes a list of  :class:``s, converts that text into
        arrays using this model's :class:`Vocabulary`, passes those arrays through
        :func:`self.forward()` and :func:`self.decode()` (which by default does nothing)
        and returns the result.  Before returning the result, we convert any
        ``torch.Tensors`` into numpy arrays and separate the
        batched output into a list of individual dicts per instance. Note that typically
        this will be faster on a GPU (and conditionally, on a CPU) than repeated calls to

        instances : List[Instance], required
            The instances to run the model on.
        cuda_device : int, required
            The GPU device to use.  -1 means use the CPU.

        A list of the models output for each instance.
        batch_size = len(instances)
        with torch.no_grad():
            cuda_device = self._get_prediction_device()
            dataset = Batch(instances)
            model_input = util.move_to_device(dataset.as_tensor_dict(),
            outputs = self.decode(self(**model_input))

            instance_separated_output: List[Dict[str, numpy.ndarray]] = [
                {} for _ in dataset.instances
            for name, output in list(outputs.items()):
                if isinstance(output, torch.Tensor):
                    # NOTE(markn): This is a hack because 0-dim pytorch tensors are not iterable.
                    # This occurs with batch size 1, because we still want to include the loss in that case.
                    if output.dim() == 0:
                        output = output.unsqueeze(0)

                    if output.size(0) != batch_size:
                    output = output.detach().cpu().numpy()
                elif len(output) != batch_size:
                for instance_output, batch_element in zip(
                        instance_separated_output, output):
                    instance_output[name] = batch_element

            for instance_output, instance_input in zip(
                    instance_separated_output, instances):
                for field in instance_input.fields:
                        instance_output[field] = instance_input.fields[
                    except Exception as e:

            return instance_separated_output

    # Encodes the source, then takes one or both of two paths:
    #   * With ``target_tokens``: initialises the decoder and runs ``_forward_loop``, which
    #     computes the token loss plus the attention-supervision loss. ``alignment_sequence`` is
    #     squeezed unconditionally here, so it must be supplied together with the targets.
    #   * Under the truncated ``if not`` below (presumably ``if not self.training``): runs beam
    #     search and, when targets are given, updates BLEU on the best beam and the SQL / token
    #     metrics.
    # NOTE(review): the line merging the beam-search ``predictions`` into ``output_dict`` appears
    # to be missing, as does the ``try:`` matching the ``except Exception:`` near the end. That
    # handler's body is also missing; check that it does not silently swallow errors. The text
    # below types ``alignment_sequence`` as a dict, although the signature says ``torch.Tensor``.
    def forward(
            self,  # type: ignore
            source_tokens: Dict[str, torch.LongTensor],
            target_tokens: Dict[str, torch.LongTensor] = None,
            alignment_sequence: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        Make foward pass with decoder logic for producing the entire target sequence.

        source_tokens : ``Dict[str, torch.LongTensor]``
           The output of `TextField.as_array()` applied on the source `TextField`. This will be
           passed through a `TextFieldEmbedder` and then through an encoder.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
           Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
           target tokens are also represented as a `TextField`.
        alignment_sequence : ``Dict[str, torch.LongTensor]``, optional (default = None)
            Output of `Textfield.as_array()` applied on alignemnet `TextField`.
        Dict[str, torch.Tensor]
        state = self._encode(source_tokens)

        if target_tokens:
            state = self._init_decoder_state(state)

            # Remove the trailing dimension (from ListField[ListField[IndexField]]).
            alignment_sequence = alignment_sequence.squeeze(-1)

            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, target_tokens,
            output_dict = {}

        if not
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            if target_tokens:
                if self._bleu:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]
                    self._bleu(best_predictions, target_tokens["tokens"])

                predicted_tokens = self.decode(output_dict)["predicted_tokens"]
                target_tokens_str = self.decode_target_tokens(target_tokens)

                if self._token_based_metric:
                if self._sql_metrics:
                    self._kb_match(predicted_tokens, target_tokens_str)

        # In case of attention coverage mechanism, reset the coverage vector after every batch...
        except Exception:

        return output_dict

    # Converts the gold ``target_tokens['tokens']`` indices into lists of token strings, one per
    # instance, looked up in ``self._target_namespace``. Each list is cut at the first END symbol
    # (that slice is cut off in this copy) and loses everything up to and including the first
    # START symbol.
    # NOTE(review): the index-to-token call, the closing bracket of the list, and the
    # ``target_tokens_output.append(...)`` appear to be missing. As written, the returned list
    # would stay empty.
    def decode_target_tokens(self, target_tokens):
        target_indices = target_tokens['tokens'].detach().cpu().numpy()
        target_tokens_output = []
        for i in range(target_indices.shape[0]):
            cur_target_indices = target_indices[i]
            cur_target_indices = list(cur_target_indices)
            if self._end_index in cur_target_indices:
                cur_target_indices = cur_target_indices[:cur_target_indices.
            if self._start_index in cur_target_indices:
                cur_target_indices = cur_target_indices[
                    cur_target_indices.index(self._start_index) + 1:]
            target_tokens_str = [
                    x, namespace=self._target_namespace)
                for x in cur_target_indices

        return target_tokens_output

    # NOTE(review): ``all_predicted_tokens`` is never appended to in this copy; the append and the
    # closing bracket of ``predicted_tokens`` appear to be missing. As written,
    # ``output_dict["predicted_tokens"]`` would always be an empty list.
    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                    x, namespace=self._target_namespace) for x in indices
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    # Embeds and encodes ``source_tokens``, returning the initial state dict with ``source_mask``
    # and ``encoder_outputs``. Note that ``_emb_dropout`` is applied to the encoder outputs, not
    # to the embeddings.
    # NOTE(review): the ``self`` parameter and the closing brace of the returned dict appear to
    # be missing in this copy.
    def _encode(
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._encoder(embedded_input, source_mask)
        encoder_outputs = self._emb_dropout(encoder_outputs)
        return {
            "source_mask": source_mask,
            "encoder_outputs": encoder_outputs,

    # Seeds the decoder LSTM state in ``state`` (in place). ``decoder_hidden`` comes from the
    # encoder's final states via ``util.get_final_encoder_states``, whose call is cut off below.
    # ``decoder_context`` is zeros of shape (batch_size, decoder_output_dim).
    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"], state["source_mask"],
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self._decoder_output_dim)
        return state

    # Step-by-step decoding. The step count is target length - 1 when ``target_tokens`` is given,
    # otherwise presumably ``self._max_decoding_steps`` (the ``else:`` header is missing). Each
    # step feeds the gold token, or with probability ``_scheduled_sampling_ratio`` the previous
    # greedy prediction. It collects greedy predictions and the per-step attention logits. With
    # targets, it returns ``loss = token_loss + self._attn_loss_lambda * attn_sup_loss``.
    # NOTE(review): the ``self`` parameter, the ``else:`` headers, the first operand of the
    # sampling condition, the per-step ``append`` calls, and the ``torch.cat`` calls appear to be
    # missing in this copy.
    def _forward_loop(
            state: Dict[str, torch.Tensor],
            target_tokens: Dict[str, torch.LongTensor] = None,
            alignment_sequence: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        Make forward pass during training or do greedy search during prediction.

        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        step_attn_weights: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # Scheduled sampling: with probability _scheduled_sampling_ratio feed back the
                # model's own previous prediction instead of the gold token.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            # shape: (batch_size, input_max_size)
            input_weights, output_projections, state = self._prepare_output_projections(
                input_choices, state)


            # list of tensors, shape: (batch_size, 1, num_classes)

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes


        # shape: (batch_size, num_decoding_steps)
        predictions =, 1)

        # shape: (batch_size, num_decoding_steps, max_input_sequence_length)
        attention_input_weights =[:-1], 1)

        output_dict = {
            "predictions": predictions,
            'attention_input_weights': attention_input_weights

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits =, 1)

            # shape: (batch_size, num_decoding_steps, max_input_sequence_length)
            alignment_mask = self._get_alignment_mask(alignment_sequence)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)

            attn_sup_loss = self._get_attn_sup_loss(attention_input_weights,

            output_dict["loss"] = loss + self._attn_loss_lambda * attn_sup_loss

        return output_dict

    # Runs the beam search from the START symbol with ``take_step`` as the step function.
    # Returns ``class_log_probabilities`` (batch_size, beam_size) and ``predictions``
    # (batch_size, beam_size, num_decoding_steps).
    # NOTE(review): the callee on the assignment line (presumably ``self._beam_search.search``)
    # and the closing brace of the dict appear to be missing in this copy.
    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities =
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        return output_dict

    # One decoder step. It embeds ``last_predictions`` with ``_target_embedder`` and attends over
    # the encoder outputs. The attended vector is combined with the embedding (that
    # concatenation call is cut off in this copy), then ``_dec_dropout`` and ``_decoder_cell``
    # run and the new hidden state is projected to vocabulary scores. Returns
    # ``(attention_logits, output_projections, state)``, with ``decoder_hidden`` /
    # ``decoder_context`` updated in ``state``.
    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        Decode current state and last prediction to produce produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        # NOTE: ``input_weights`` below are the un-normalised attention logits returned by
        # ``_prepare_attended_input``, not softmax weights.
        # shape: (group_size, encoder_output_dim)
        attended_input, input_weights = self._prepare_attended_input(
            decoder_hidden, encoder_outputs, source_mask)

        # shape: (group_size, decoder_output_dim + target_embedding_dim)
        decoder_input =, embedded_input), -1)
        decoder_input = self._dec_dropout(decoder_input)

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(

        return input_weights, output_projections, state

    # NOTE(review): the ``self`` parameter appears to be missing in this copy (the body uses
    # ``self._attention``), and the ``self._attention(...)`` call below is cut off.
    def _prepare_attended_input(
        decoder_hidden_state: torch.LongTensor = None,
        encoder_outputs: torch.LongTensor = None,
        encoder_outputs_mask: torch.LongTensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention over encoder outputs and decoder state.

        Returns ``(attended_input, input_logits)``. The first element is the weighted sum of
        ``encoder_outputs`` under ``masked_softmax`` of the logits. The second element is the
        raw, un-normalised attention logits, kept for the attention-supervision loss.
        """
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()
        # shape: (batch_size, max_input_sequence_length)
        input_logits = self._attention(decoder_hidden_state, encoder_outputs,
        # the attention mechanism returns the logits that are necessary for attention supervision loss,
        # so we normalize it here
        input_weights = masked_softmax(input_logits, encoder_outputs_mask)
        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input, input_logits

    # Attention-supervision loss. Per the comments below, this is a sequence cross entropy that
    # treats the per-step attention logits as scores over source positions and
    # ``alignment_sequence`` as the gold position, presumably masked by ``alignment_mask`` (the
    # call is cut off in this copy).
    # NOTE(review): this has no ``self`` parameter yet is called as a method, so a
    # ``@staticmethod`` is presumably missing. It also overwrites -1 entries of
    # ``alignment_sequence`` *in place*. The tensor passed from ``forward`` is a ``squeeze`` view
    # of the batch input, so the caller's data is modified too. The -1 is hard-coded instead of
    # using ``_indexfield_padding_index``.
    def _get_attn_sup_loss(attn_weights: torch.Tensor,
                           alignment_mask: torch.Tensor,
                           alignment_sequence: torch.Tensor) -> torch.Tensor:
        Compute the attention supervision CE loss.
        For each step, take the index of the aligned
        # shape: (batch_size, max_decoding_steps, max_input_seq_length)
        attn_weights = attn_weights.float()

        alignment_sequence[alignment_sequence == -1] = 0
        # for each attn_weights[batch_index, step_index, :] I want to choose the index of
        # alignment_sequence[batch_index, step_index]
        return util.sequence_cross_entropy_with_logits(attn_weights,

    # Boolean mask that is True where ``alignment_sequence`` is neither the IndexField padding
    # value (``_indexfield_padding_index``) nor the "no alignment" marker
    # (``_missing_alignment_int``).
    # NOTE(review): the comparisons are elementwise, so the mask has the shape of
    # ``alignment_sequence``. After the ``squeeze(-1)`` in ``forward`` that is presumably
    # (batch_size, num_steps), not the 3-D shape stated in the text below. Confirm.
    def _get_alignment_mask(self, alignment_sequence):
        The alignment mask includes the target mask + mask on steps that don't have alignment
        shape: batch_size, max_steps, max_input
        pad_mask = alignment_sequence != self._indexfield_padding_index
        missing_mask = alignment_sequence != self._missing_alignment_int

        return pad_mask * missing_mask

    # Token-level cross entropy between ``logits`` and ``targets`` shifted by one step (see the
    # text below).
    # NOTE(review): there is no ``self`` parameter, yet it is called as ``self._get_loss(...)``,
    # so a ``@staticmethod`` decorator is presumably missing. The final call is also cut off.
    def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits,

    # Reports the model's metrics. Only the ``attn_sup_loss`` average line survives (partly) in
    # this copy; the outer condition and the BLEU / SQL branches are cut off.
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        all_metrics: Dict[str, float] = {}
        if not
            if self._bleu:
            if self._sql_metrics:
            all_metrics['attn_sup_loss'] = self._attn_sup_loss.get_metric(
        return all_metrics
Esempio n. 8
class MyTransformer(Model):
    """Sequence-to-sequence model built around ``Transformer``.

    Source tokens are embedded and positionally encoded (``_encode``), run
    through ``self._transformer.encoder``, and decoded with
    ``self._transformer.decoder``. Decoder outputs are projected onto the
    target vocabulary by ``self._output_projection_layer``. Tensors passed
    to the transformer are sequence-first, because ``_encode`` swaps the
    first two axes. When targets are given, BLEU (optional) and
    ``SequenceAccuracy`` are updated.

    NOTE(review): this excerpt is missing many lines: ``self`` parameters,
    call-argument continuations, ``else:`` headers, the
    ``super().__init__`` call, and apparently a decorator on ``_get_loss``.
    It does not parse as-is. ``Transformer`` is presumably
    ``torch.nn.Transformer`` (it exposes ``.encoder``/``.decoder`` and takes
    ``d_model``); confirm against the imports.
    """
    def __init__(
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        transformer: Dict,
        max_decoding_steps: int,
        target_namespace: str,
        target_embedder: TextFieldEmbedder = None,
        use_bleu: bool = True,
    ) -> None:
        """Build the embedders, transformer, output projection and metrics.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Supplies the start/end/pad indices and the target vocabulary size.
        source_embedder : ``TextFieldEmbedder``
            Embeds source tokens. It is also used for target tokens when
            ``target_embedder`` is ``None``.
        transformer : ``Dict``
            Keyword arguments for ``Transformer``. Must contain
            ``"d_model"``, which is also the model width ``self._ndim``.
        max_decoding_steps : ``int``
            Maximum number of decoding steps. In ``forward``, targets longer
            than this (after dropping their first token) trip an ``assert``.
        target_namespace : ``str``
            Vocabulary namespace of the target tokens.
        target_embedder : ``TextFieldEmbedder``, optional (default=``None``)
            Separate embedder for target tokens.
        use_bleu : ``bool``, optional (default=``True``)
            Whether to track BLEU. Otherwise ``self._bleu`` is ``None``.
        """
        self._target_namespace = target_namespace

        # NOTE(review): the namespace arguments of these three lookups are
        # truncated in this excerpt.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,

        # Pad, end and start indices are excluded from BLEU.
        # NOTE(review): the closing braces of the BLEU call and the ``else:``
        # before ``self._bleu = None`` are missing.
        if use_bleu:
            self._bleu = BLEU(exclude_indices={
                self._pad_index, self._end_index, self._start_index
            self._bleu = None
        self._seq_acc = SequenceAccuracy()

        self._max_decoding_steps = max_decoding_steps

        self._source_embedder = source_embedder

        # Model width; also scales embeddings in ``_encode``.
        # NOTE(review): the remaining PositionalEncoding arguments are truncated.
        self._ndim = transformer["d_model"]
        self.pos_encoder = PositionalEncoding(self._ndim,

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        self._transformer = Transformer(**transformer)

        # Reuse the source embedder for targets unless one is supplied.
        # NOTE(review): the ``else:`` before the second assignment is missing.
        if target_embedder is None:
            self._target_embedder = self._source_embedder
            self._target_embedder = target_embedder

        # Projects decoder outputs of width ``d_model`` onto vocabulary logits.
        self._output_projection_layer = Linear(self._ndim, num_classes)

    def _get_mask(self, meta_data):
        """Build an additive vocabulary mask from per-example metadata.

        For example ``bidx``, every vocabulary entry ``(k, v)`` whose token
        string ``k`` contains ``'position'`` and is not in
        ``md['avail_pos']`` is set to ``-inf`` at ``mask[:, bidx, v]``. All
        other entries stay ``0``. The mask is indexed ``(1, batch, index)``.
        ``forward`` adds it to the logits before the softmax, so disallowed
        tokens get zero probability.

        NOTE(review): the third size argument of ``torch.zeros`` and the
        namespace/``.items()`` part of the inner loop are truncated. They are
        presumably the target vocabulary size and its token-to-index mapping.
        """
        mask = torch.zeros(1, len(meta_data),
        for bidx, md in enumerate(meta_data):
            for k, v in self.vocab._token_to_index[
                if 'position' in k and k not in md['avail_pos']:
                    mask[:, bidx, v] = float('-inf')
        return mask

    def generate_square_subsequent_mask(self, sz):
        """Return an ``(sz, sz)`` causal attention mask.

        ``triu(ones) == 1`` transposed is ``True`` on and below the
        diagonal. Those entries become ``0.0``, so position ``i`` may attend
        to any ``j <= i``.

        NOTE(review): the fill value for the ``False`` (future) entries is on
        a missing line. It is presumably ``float('-inf')``, and the second
        ``.masked_fill(`` call head is also missing.
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == False,
                                            mask == True, float(0.0))
        return mask

    def forward(
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None,
        meta_data: Any = None,
    ) -> Dict[str, torch.Tensor]:
        """Encode the source and decode the target sequence.

        Parameters
        ----------
        source_tokens : ``Dict[str, torch.LongTensor]``
            Output of the source ``TextField``.
        target_tokens : ``Dict[str, torch.LongTensor]``, optional
            Output of the target ``TextField``. Its first token is dropped to
            form the prediction targets, so it presumably begins with the
            start symbol.
        meta_data : ``Any``, optional
            Per-example metadata passed to ``_get_mask``. Each entry must
            provide ``'avail_pos'``.

        Returns
        -------
        ``Dict[str, torch.Tensor]`` with ``"loss"``, batch-first
        ``"predictions"`` and batch-first ``"class_probabilities"``.
        """
        # NOTE(review): the remaining arguments of this call are truncated.
        src, src_key_padding_mask = self._encode(self._source_embedder,
        memory = self._transformer.encoder(
            src, src_key_padding_mask=src_key_padding_mask)

        # Optional additive vocabulary mask built from ``meta_data``.
        # NOTE(review): the second assignment is truncated, and the ``else:``
        # before ``target_vocab_mask = None`` is missing.
        if meta_data is not None:
            target_vocab_mask = self._get_mask(meta_data)
            target_vocab_mask =
            target_vocab_mask = None
        output_dict = {}
        targets = None
        if target_tokens:
            # Drop the first target token: decoder position i is scored
            # against target token i + 1.
            targets = target_tokens["tokens"][:, 1:]
            target_mask = (util.get_text_field_mask({"tokens": targets}) == 1)
            assert targets.size(1) <= self._max_decoding_steps
        # Teacher-forced path. NOTE(review): the first operand of this
        # condition is missing (presumably ``self.training``).
        if and target_tokens:
            # Decoder input is the target without its last token.
            # NOTE(review): the embedder argument and several call
            # continuations below are truncated.
            tgt, tgt_key_padding_mask = self._encode(
                {"tokens": target_tokens["tokens"][:, :-1]})
            tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(
            output = self._transformer.decoder(
            logits = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                # The (1, batch, ...) mask broadcasts over the sequence axis.
                logits += target_vocab_mask
            class_probabilities = F.softmax(logits.detach(), dim=-1)
            _, predictions = torch.max(class_probabilities, -1)
            # Sequence-first -> batch-first for the loss.
            logits = logits.transpose(0, 1)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss
            # Greedy step-by-step path; the ``else:`` header is missing. The
            # loss is a zero placeholder. NOTE(review): this assert's operand
            # is truncated, as are the ``else:`` before
            # ``max_target_len = None`` and the decoder-call arguments.
            assert is False
            output_dict["loss"] = torch.tensor(0.0).to(memory.device)
            if targets is not None:
                max_target_len = targets.size(1)
                max_target_len = None
            predictions, class_probabilities = self._decoder_step_by_step(
        # Both paths yield sequence-first predictions; make them batch-first.
        predictions = predictions.transpose(0, 1)
        output_dict["predictions"] = predictions
        output_dict["class_probabilities"] = class_probabilities.transpose(
            0, 1)

        if target_tokens:
            with torch.no_grad():
                best_predictions = output_dict["predictions"]
                if self._bleu:
                    self._bleu(best_predictions, targets)
                # Predictions and targets may differ in length, so they and
                # the target mask are copied into zero-filled buffers of a
                # common width before ``SequenceAccuracy``. The predictions
                # get a singleton second axis via ``unsqueeze(1)``.
                # NOTE(review): several call continuations here are truncated.
                batch_size = targets.size(0)
                max_sz = max(best_predictions.size(1), targets.size(1),
                best_predictions_ = torch.zeros(batch_size,
                best_predictions_[:, :best_predictions.
                                  size(1)] = best_predictions
                targets_ = torch.zeros(batch_size, max_sz).to(memory.device)
                targets_[:, :targets.size(1)] = targets.cpu()
                target_mask_ = torch.zeros(batch_size,
                target_mask_[:, :target_mask.size(1)] = target_mask
                self._seq_acc(best_predictions_.unsqueeze(1), targets_,
        return output_dict

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert predicted indices into target-namespace token strings.

        Each sequence is cut before the first ``self._end_index``, if
        present. When an example holds several sequences (more than one
        axis), only the first is used. The tokens are stored under
        ``"predicted_tokens"`` and the same dict is returned.

        NOTE(review): the head of the list comprehension
        (``self.vocab.get_token_from_index(``), its closing bracket and the
        ``all_predicted_tokens.append(...)`` line are missing.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            # shape: (batch_size, num_decoding_steps)
            predicted_indices = predicted_indices.detach().cpu().numpy()
            # Disabled sampling-based decoding, kept for reference:
            # class_probabilities = output_dict["class_probabilities"].detach().cpu()
            # sample_predicted_indices = []
            # for cp in class_probabilities:
            #     sample = torch.multinomial(cp, num_samples=1)
            #     sample_predicted_indices.append(sample)
            # # shape: (batch_size, num_decoding_steps, num_samples)
            # sample_predicted_indices = torch.stack(sample_predicted_indices)

        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                    x, namespace=self._target_namespace) for x in indices
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
            self, embedder: TextFieldEmbedder,
            tokens: Dict[str,
                         torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed ``tokens`` and prepare them for the transformer.

        The embeddings are scaled by ``sqrt(d_model)``, their first two axes
        are swapped, and they are passed through ``self.pos_encoder``. The
        swap presumably goes from (batch, seq, dim) to (seq, batch, dim);
        TODO confirm the embedder's layout.

        Returns
        -------
        ``(src, mask)``, where ``mask`` is ``True`` wherever
        ``util.get_text_field_mask`` reports ``0`` (padding positions).
        """
        src = embedder(tokens) * math.sqrt(self._ndim)
        src = src.transpose(0, 1)
        src = self.pos_encoder(src)
        mask = util.get_text_field_mask(tokens)
        mask = (mask == 0)
        return src, mask

    def _decoder_step_by_step(
            memory: torch.Tensor,
            memory_key_padding_mask: torch.Tensor,
            target_vocab_mask: torch.Tensor = None,
            max_target_len: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Greedy autoregressive decoding from ``self._start_index``.

        At every step the decoder is re-run on the whole prefix decoded so
        far, and the argmax at the newest position becomes the next input.
        Decoding stops early once every example emits the end or pad index.
        The number of steps is ``self._max_decoding_steps``. When the
        optional attribute ``target_limit_decode_steps`` is truthy and
        ``max_target_len`` is given, the step count is capped at
        ``max_target_len`` instead.

        Returns
        -------
        The stacked per-step predictions, shape (num_steps, batch_size), and
        the class probabilities computed at the last step.

        NOTE(review): ``self`` is missing from the signature. Also missing
        are the ``else:`` before the second ``num_decoding_steps``
        assignment, several call continuations, the
        ``step_predictions.append(...)`` line and the ``break`` under the
        early-stop test. The ``print`` looks like leftover debugging.
        """
        batch_size = memory.size(1)
        if getattr(self, "target_limit_decode_steps",
                   False) and max_target_len is not None:
            num_decoding_steps = min(self._max_decoding_steps, max_target_len)
            print('decoding steps: ', num_decoding_steps)
            num_decoding_steps = self._max_decoding_steps

        last_predictions = memory.new_full(
            (batch_size, ), fill_value=self._start_index).long()

        step_predictions: List[torch.Tensor] = []
        all_predicts = memory.new_full((batch_size, num_decoding_steps),
        for timestep in range(num_decoding_steps):
            # Append the latest token, then decode the whole prefix.
            all_predicts[:, timestep] = last_predictions
            tgt, tgt_key_padding_mask = self._encode(
                {"tokens": all_predicts[:, :timestep + 1]})
            tgt_mask = self.generate_square_subsequent_mask(timestep + 1).to(
            output = self._transformer.decoder(
            output_projections = self._output_projection_layer(output)
            if target_vocab_mask is not None:
                output_projections += target_vocab_mask

            class_probabilities = F.softmax(output_projections, dim=-1)
            _, predicted_classes = torch.max(class_probabilities, -1)

            # shape (last_predictions): (batch_size,); newest position only.
            last_predictions = predicted_classes[timestep, :]
            # ``+`` on the boolean tensors acts as a logical OR.
            if ((last_predictions == self._end_index) +
                (last_predictions == self._pad_index)).all():

        # shape: (num_decoding_steps, batch_size)
        predictions = torch.stack(step_predictions)
        return predictions, class_probabilities

    def _get_loss(logits: torch.FloatTensor, targets: torch.LongTensor,
                  target_mask: torch.FloatTensor) -> torch.Tensor:
        """Masked sequence cross-entropy between batch-first logits and targets.

        No shift is applied here. ``forward`` already dropped the first
        target token, so ``logits``, ``targets`` and ``target_mask`` are
        aligned.

        NOTE(review): no ``self`` parameter is visible, so this is
        presumably a ``@staticmethod`` whose decorator was lost. The
        trailing arguments of the loss call are truncated.
        """
        logits = logits.contiguous()
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets.contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask.contiguous()

        return util.sequence_cross_entropy_with_logits(logits,

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return tracked metrics.

        ``'seq_acc'`` comes from ``SequenceAccuracy``, and ``reset`` is
        forwarded to it. NOTE(review): the body of the BLEU branch is missing.
        """
        all_metrics: Dict[str, float] = {}
        if self._bleu:
        all_metrics['seq_acc'] = self._seq_acc.get_metric(reset=reset)
        return all_metrics

    def load_state_dict(self, state_dict, strict=True):
        """Load weights, stripping a leading ``'module.'`` prefix from keys.

        The prefix is the one ``torch.nn.DataParallel`` adds, which is
        presumably why it is stripped. Keys without the prefix are kept as-is.

        NOTE(review): the ``else:`` between the two assignments is missing,
        and the parent's return value (missing/unexpected keys) is discarded.
        """
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('module.'):
                new_state_dict[k[len('module.'):]] = v
                new_state_dict[k] = v

        super(MyTransformer, self).load_state_dict(new_state_dict, strict)
class FactParaphraseSeq2Seq(Model):
    # NOTE(review): the lines below are this class's docstring, but the
    # enclosing triple quotes are missing from this excerpt.
    # The model works as follows:
    #   - encodes the source (fact) tokens with ``source_encoder``;
    #   - optionally encodes dialog acts and merges them into the initial
    #     decoder state;
    #   - decodes with an ``LSTMCell`` (optional attention), using
    #     ``BeamSearch`` for predictions.
    Given facts and dialog acts, it generates the paraphrased message.
    TODO: add dialog & dialog acts history

    This implementation is based off the default SimpleSeq2Seq model,
    which takes a sequence, encodes it, and then uses the encoded
    representations to decode another sequence.
    def __init__(
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        source_encoder: Seq2SeqEncoder,
        max_decoding_steps: int,
        dialog_acts_encoder: FeedForward = None,
        attention: Attention = None,
        attention_function: SimilarityFunction = None,
        n_dialog_acts: int = None,
        beam_size: int = None,
        target_namespace: str = "tokens",
        target_embedding_dim: int = None,
        scheduled_sampling_ratio: float = 0.0,
        use_bleu: bool = True,
        use_dialog_acts: bool = True,
        regularizers: Optional[RegularizerApplicator] = None,
    ) -> None:
        """Build the encoders, attention, decoder cell, output layer and metrics.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Supplies the start/end indices (and, with BLEU, the padding
            index) and the target vocabulary size.
        source_embedder : ``TextFieldEmbedder``
            Embeds the source (fact) tokens.
        source_encoder : ``Seq2SeqEncoder``
            Encodes the embedded source. Its output dimension is also the
            decoder's hidden size.
        max_decoding_steps : ``int``
            Number of decoding steps used by ``_forward_loop`` when no
            targets are given. It is presumably also the step limit passed
            to ``BeamSearch`` (that call is truncated here).
        dialog_acts_encoder : ``FeedForward``, optional (default=``None``)
            Encodes the embedded dialog acts. Required when
            ``use_dialog_acts`` is ``True``, because its ``get_input_dim()``
            sets the dialog-act embedding size.
        attention : ``Attention``, optional (default=``None``)
            Attention over encoder outputs at each decoding step.
        attention_function : ``SimilarityFunction``, optional (default=``None``)
            Alternative to ``attention``, wrapped in ``LegacyAttention``.
        n_dialog_acts : ``int``, optional (default=``None``)
            First argument of the dialog-act ``EmbeddingBag``, presumably
            the number of distinct dialog acts.
        beam_size : ``int``, optional (default=``None``)
            Beam width. ``None`` or ``0`` is replaced by ``1``.
        target_namespace : ``str``, optional (default=``"tokens"``)
            Vocabulary namespace of the target tokens.
        target_embedding_dim : ``int``, optional (default=``None``)
            Target embedding size. When falsy, it comes from
            ``source_embedder.get_output_dim()``.
        scheduled_sampling_ratio : ``float``, optional (default=``0.0``)
            Per-step probability of feeding back the previous prediction
            instead of the gold token (see ``_forward_loop``).
        use_bleu : ``bool``, optional (default=``True``)
            Whether to track BLEU. Otherwise ``self._bleu`` is ``None``.
        use_dialog_acts : ``bool``, optional (default=``True``)
            Whether to build the dialog-act embedder, encoder and merge layer.
        regularizers : ``RegularizerApplicator``, optional (default=``None``)
            Passed to ``Model.__init__``.

        Raises
        ------
        ConfigurationError
            If both ``attention`` and ``attention_function`` are given.
        """
        super().__init__(vocab, regularizers)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first
        # timestep of decoding, and end symbol as a way to indicate the end
        # of the decoded sequence.
        # NOTE(review): the namespace arguments of both lookups are truncated.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
        self._end_index = self.vocab.get_token_index(END_SYMBOL,

        # Padding, start and end indices are excluded from BLEU.
        # NOTE(review): the closing braces of the BLEU call and the ``else:``
        # before ``self._bleu = None`` are missing.
        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index
            self._bleu = None

        # At prediction time, we use a beam search to find the most
        # likely sequence of target tokens.
        # NOTE(review): the remaining BeamSearch arguments are truncated.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,

        # Dense embedding of source (Facts) vocab tokens.
        self._source_embedder = source_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._source_encoder = source_encoder

        # NOTE(review): the ``EmbeddingBag`` call is truncated, and the
        # ``else:`` before the two ``None`` assignments is missing.
        if use_dialog_acts:
            # Dense embedding of dialog acts.
            da_embedding_dim = dialog_acts_encoder.get_input_dim()
            self._dialog_acts_embedder = EmbeddingBag(n_dialog_acts,

            # Encodes dialog acts
            self._dialog_acts_encoder = dialog_acts_encoder

            self._dialog_acts_embedder = None
            self._dialog_acts_encoder = None

        num_classes = self.vocab.get_vocab_size(self._target_namespace)

        # Attention mechanism applied to the encoder output for each step.
        # NOTE(review): the final ``else:`` (before ``self._attention = None``)
        # is missing.
        if attention:
            if attention_function:
                raise ConfigurationError(
                    "You can only specify an attention module or an "
                    "attention function, but not both.")
            self._attention = attention
        elif attention_function:
            self._attention = LegacyAttention(attention_function)
            self._attention = None

        # Dense embedding of vocab words in the target space.
        target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim(
        self._target_embedder = Embedding(num_classes, target_embedding_dim)

        # Decoder output dim needs to be the same as the encoder output dim
        # since we initialize the hidden state of the decoder with the final
        # hidden state of the encoder.
        # With dialog acts, ``_merge_encoder`` maps the concatenated
        # [encoder state, dialog-act state] (see ``_init_decoder_state``) to
        # the initial decoder hidden state. Its output size therefore
        # presumably equals ``encoder_output_dim`` (TODO confirm).
        # NOTE(review): its layer definitions are truncated in this excerpt.
        self._encoder_output_dim = self._source_encoder.get_output_dim()
        if use_dialog_acts:
            self._merge_encoder = Sequential(
                    self._source_encoder.get_output_dim() +
        self._decoder_output_dim = self._encoder_output_dim

        # NOTE(review): the ``else:`` before the second assignment is missing.
        if self._attention:
            # If using attention, a weighted average over encoder outputs will
            # be concatenated to the previous target embedding to form the input
            # to the decoder at each time step.
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
            # Otherwise, the input to the decoder is just the previous target embedding.
            self._decoder_input_dim = target_embedding_dim

        # We'll use an LSTM cell as the recurrent cell that produces a hidden state
        # for the decoder at each time step.
        # TODO (pradeep): Do not hardcode decoder cell type.
        # NOTE(review): the remaining arguments of ``LSTMCell`` and ``Linear``
        # below are truncated.
        self._decoder_cell = LSTMCell(self._decoder_input_dim,

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(self._decoder_output_dim,

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # NOTE(review): the next line is docstring text whose triple quotes
        # are missing. This method is the step function handed to
        # ``self._beam_search`` in ``_forward_beam_search``. It returns
        # log-probabilities over the target vocabulary, shape
        # (group_size, num_classes), plus the updated decoder state.
        Take a decoding step. This is called by the beam search class.
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    def forward(
        self,  # type: ignore
        source_tokens: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None,
        dialog_acts: Optional[torch.Tensor] = None,
        sender: Optional[torch.Tensor] = None,
        metadata: Optional[Dict] = None,
    ) -> Dict[str, torch.Tensor]:
        # NOTE(review): the next line is docstring text whose triple quotes
        # are missing.
        # - With targets: runs ``_forward_loop`` over the gold targets,
        #   which yields ``"loss"`` and greedy ``"predictions"``.
        # - On the branch whose condition is truncated (presumably
        #   ``not self.training``): re-initialises the decoder state, runs
        #   beam search and updates BLEU with the best beam.
        # ``sender`` and ``metadata`` are not used in the visible code.
        Make foward pass with decoder logic for producing the entire target sequence.
        # NOTE(review): the second argument (presumably ``dialog_acts``) of
        # this call is truncated.
        source_state, dialog_acts_state = self._encode(source_tokens,

        # NOTE(review): the ``else:`` before ``output_dict = {}`` is missing.
        if target_tokens:
            state = self._init_decoder_state(source_state, dialog_acts_state)
            # The `_forward_loop` decodes the input sequence and
            # computes the loss during training and validation.
            output_dict = self._forward_loop(state, target_tokens)
            output_dict = {}

        # NOTE(review): the condition is truncated. The line merging
        # ``predictions`` into ``output_dict`` also appears to be missing,
        # yet ``output_dict["predictions"]`` is read below as 3-D beam output.
        if not
            state = self._init_decoder_state(source_state, dialog_acts_state)
            predictions = self._forward_beam_search(state)
            if target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, target_tokens["tokens"])

        return output_dict

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        # NOTE(review): the next line is docstring text whose triple quotes
        # are missing. The method maps ``output_dict["predictions"]`` to
        # target-namespace token strings. Each sequence is cut before the
        # first end index, and only the first beam is used. The result is
        # stored under ``"predicted_tokens"``. The comprehension head
        # (``self.vocab.get_token_from_index(``), its closing bracket and
        # the ``all_predicted_tokens.append(...)`` line are missing.
        Finalize predictions.
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence
            # in the batch but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                    x, namespace=self._target_namespace) for x in indices
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
        source_tokens: Dict[str, torch.Tensor],
        dialog_acts: torch.Tensor = None
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """Encode the source tokens and, if configured, the dialog acts.

        Returns
        -------
        ``(source_state, dialog_acts_state)``. ``dialog_acts_state`` is
        ``None`` when there is no dialog-acts encoder.

        NOTE(review): ``self`` is missing from the signature, and the
        ``else:`` before ``dialog_acts_state = None`` is missing.
        """
        # Encode source tokens
        source_state = self._encode_source_tokens(source_tokens)

        # Encode dialog acts
        if self._dialog_acts_encoder:
            dialog_acts_state = self._encode_dialog_acts(dialog_acts)

            dialog_acts_state = None

        return (source_state, dialog_acts_state)

    def _encode_source_tokens(
            source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Embed and encode the source tokens.

        Returns a dict with ``"source_mask"`` and ``"encoder_outputs"``.
        NOTE(review): ``self`` is missing from the signature.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = self._source_encoder(embedded_input, source_mask)
        return {"source_mask": source_mask, "encoder_outputs": encoder_outputs}

    def _encode_dialog_acts(self, dialog_acts: torch.Tensor) -> torch.Tensor:
        """Embed dialog acts with the ``EmbeddingBag``, then apply the encoder."""
        # shape: (batch_size, dialog_acts_embeddings_size)
        embedded_dialog_acts = self._dialog_acts_embedder(dialog_acts)

        # shape: (batch_size, dim_encoder)
        dialog_acts_state = self._dialog_acts_encoder(embedded_dialog_acts)
        return dialog_acts_state

    def _init_decoder_state(
        source_state: Dict[str, torch.Tensor],
        dialog_acts_state: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        """Initialise the LSTM decoder state from the final encoder state.

        With a dialog-acts encoder, the final encoder state is concatenated
        with ``dialog_acts_state`` and passed through ``_merge_encoder``.
        The decoder context starts at zeros. ``source_state`` is updated in
        place and returned.

        NOTE(review): ``self`` is missing from the signature. The
        ``get_final_encoder_states`` arguments and the ``torch.cat(`` head of
        the merge call are truncated.
        """
        batch_size = source_state["source_mask"].size(0)

        # shape: (batch_size, encoder_output_dim)
        final_encoder_output = util.get_final_encoder_states(

        # Condition the source tokens state with dialog acts state
        if self._dialog_acts_encoder:
            final_encoder_output = self._merge_encoder(
      [final_encoder_output, dialog_acts_state], dim=1))

        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        source_state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        source_state["decoder_context"] = source_state[
            "encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
        return source_state

    def _forward_loop(
        state: Dict[str, torch.Tensor],
        target_tokens: Dict[str, torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        # NOTE(review): the three lines below are this method's docstring
        # with the triple quotes missing. The visible code returns
        # ``{"predictions": (batch_size, num_decoding_steps)}``, plus
        # ``"loss"`` when ``target_tokens`` is given. ``self`` is missing from
        # the signature.
        Make forward pass during training or do greedy search during prediction.
        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        # NOTE(review): the ``else:`` before
        # ``num_decoding_steps = self._max_decoding_steps`` is missing.
        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        # NOTE(review): the ``fill_value`` argument is truncated.
        last_predictions = source_mask.new_full((batch_size, ),

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            # NOTE(review): the first operand of this condition is truncated
            # (presumably ``self.training``).
            if and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # Scheduled sampling: feed back the model's previous
                # prediction, so gold tokens are used at a rate of
                # 1 - _scheduled_sampling_ratio during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # No targets: greedy decoding from the previous prediction.
                # shape: (batch_size,)
                input_choices = last_predictions
                # Teacher forcing with the gold token; the ``else:`` header
                # is missing.
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(
                input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)
            # NOTE(review): the ``step_logits.append(...)`` line described
            # by the comment above is missing.

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes
            # NOTE(review): the ``step_predictions.append(...)`` line is
            # missing here.


        # shape: (batch_size, num_decoding_steps)
        # NOTE(review): garbled; presumably ``torch.cat(step_predictions, 1)``.
        predictions =, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            # NOTE(review): garbled; presumably ``torch.cat(step_logits, 1)``.
            logits =, 1)

            # Compute loss. ``_get_loss`` drops the first target token to
            # align targets with the logits.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search.

        Every example starts from ``self._start_index``, and ``take_step``
        is the step function. Returns ``"class_log_probabilities"`` of shape
        (batch_size, beam_size) and ``"predictions"`` of shape
        (batch_size, beam_size, num_decoding_steps).

        NOTE(review): the ``self._beam_search.search(`` call head and the
        closing ``}`` of ``output_dict`` are missing.
        """
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities =
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        return output_dict

    def _prepare_output_projections(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # NOTE(review): the four lines below are this method's docstring with
        # the triple quotes missing. The code runs one ``LSTMCell`` step,
        # updates ``state["decoder_hidden"]`` and ``state["decoder_context"]``
        # in place, and returns logits of shape (group_size, num_classes).
        Decode current state and last prediction to produce produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.
        Inputs are the same as for `take_step()`.
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        embedded_input = self._target_embedder(last_predictions)

        if self._attention:
            # shape: (group_size, encoder_output_dim)
            attended_input = self._prepare_attended_input(
                decoder_hidden, encoder_outputs, source_mask)

            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            # NOTE(review): garbled; presumably
            # ``torch.cat((attended_input, embedded_input), -1)``. The
            # ``else:`` before the next assignment is also missing.
            decoder_input =, embedded_input), -1)
            # shape: (group_size, target_embedding_dim)
            decoder_input = embedded_input

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        return output_projections, state

    def _prepare_attended_input(
        decoder_hidden_state: torch.LongTensor = None,
        encoder_outputs: torch.LongTensor = None,
        encoder_outputs_mask: torch.LongTensor = None,
    ) -> torch.Tensor:
        """Apply attention over encoder outputs and decoder state.

        Returns the attention-weighted sum of ``encoder_outputs``.
        NOTE(review): ``self`` is missing from the signature, and the mask
        argument of the attention call is truncated.
        """
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = self._attention(decoder_hidden_state, encoder_outputs,

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input

    def _get_loss(
        logits: torch.LongTensor,
        targets: torch.LongTensor,
        target_mask: torch.LongTensor,
    ) -> torch.Tensor:
        """Masked cross-entropy between decoder logits and shifted targets.

        ``logits`` cover ``target_length - 1`` steps (see ``_forward_loop``).
        ``targets`` and ``target_mask`` are shifted left by one, so the logit
        at step ``i`` is scored against target token ``i + 1``.

        NOTE(review): no ``self`` parameter is visible, so this is
        presumably a ``@staticmethod`` whose decorator was lost. ``logits``
        is annotated ``LongTensor`` but holds float scores. The trailing
        arguments of the loss call are truncated.
        """

        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits,

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return tracked metrics; the only metric referenced is ``self._bleu``.

        NOTE(review): the condition and the body that reads BLEU are
        truncated in this excerpt.
        """
        all_metrics: Dict[str, float] = {}
        if self._bleu and not
        return all_metrics
class LatentAignmentCTC(Model):
    """Non-autoregressive sequence model trained with a CTC loss.

    Processing steps:
    - source embeddings pass through ``self._upsample``;
    - ``self._net`` encodes the upsampled input;
    - ``self._output_projection`` projects the result to per-position
      scores over the target vocabulary ("alignment logits");
    - at prediction time, the argmax alignment is collapsed into an output
      sequence by ``beta_inverse``.

    NOTE(review): this excerpt is missing lines (``self`` in ``__init__``,
    call continuations, ``else:`` headers, several loop bodies). It does
    not parse as-is.
    """

    def __init__(
        vocab: Vocabulary,
        source_embedder: TextFieldEmbedder,
        upsample: torch.nn.Module = None, 
        net: Seq2SeqEncoder = None,
        target_namespace: str = "target_tokens",
        target_embedding_dim: int = None,
        use_bleu: bool = True,
    ) -> None:
        """Build the embedder, upsampler, encoder network and output projection.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Supplies the pad and blank indices and the target vocabulary size.
        source_embedder : ``TextFieldEmbedder``
            Embeds the source tokens.
        upsample : ``torch.nn.Module``, optional (default=``None``)
            Upsampling module. Defaults to
            ``LinearUpsample(source_embedding_dim, s = 3)``.
        net : ``Seq2SeqEncoder``, optional (default=``None``)
            Encoder over the upsampled input. Defaults to a 4-layer, 4-head
            ``StackedSelfAttentionEncoder`` (hidden 128, projection 128,
            feed-forward 512).
        target_namespace : ``str``, optional (default=``"target_tokens"``)
            Vocabulary namespace of the target tokens.
        target_embedding_dim : ``int``, optional (default=``None``)
            NOTE(review): ignored, because it is overwritten with
            ``self._net.get_output_dim()`` below.
        use_bleu : ``bool``, optional (default=``True``)
            Whether to track BLEU (pad and blank indices excluded).
        """
        super(LatentAignmentCTC, self).__init__(vocab)
        self._target_namespace = target_namespace
        # NOTE(review): the namespace arguments of these lookups are truncated.
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token, 
        self._blank_index = self.vocab.get_token_index(SPECIAL_BLANK_TOKEN, 

        # NOTE(review): the ``else:`` before ``self._bleu = None`` is missing.
        if use_bleu:
            self._bleu = BLEU(exclude_indices={self._pad_index, self._blank_index})
            self._bleu = None

        self._source_embedder = source_embedder
        source_embedding_dim = source_embedder.get_output_dim()

        self._upsample = upsample or LinearUpsample(source_embedding_dim, s = 3)
        self._net = net or StackedSelfAttentionEncoder(input_dim = source_embedding_dim,
                                                       hidden_dim = 128,
                                                       projection_dim = 128,
                                                       feedforward_hidden_dim = 512,
                                                       num_layers = 4,
                                                       num_attention_heads = 4)
        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        # Overrides the ``target_embedding_dim`` argument.
        target_embedding_dim = self._net.get_output_dim()

        self._output_projection = torch.nn.Linear(target_embedding_dim, num_classes)

    def forward(
        self,  # type: ignore
        source_tokens:  Dict[str, torch.LongTensor],
        target_tokens:  Dict[str, torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute alignment logits and, when targets are given, the CTC loss.

        Returns
        -------
        ``Dict[str, torch.Tensor]`` with:
        - ``"source_mask_upsampled"``, ``"net_output"`` and
          ``"alignment_logits"``;
        - ``"loss"``, only when targets are given;
        - ``"predictions"``, only on the branch whose condition is truncated
          below.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_input = self._source_embedder(source_tokens)
        # shape: (batch_size, max_input_sequence_length)
        source_mask = util.get_text_field_mask(source_tokens)

        # NOTE(review): the two shape comments below do not match the default
        # ``_net``, whose ``input_dim`` is ``source_embedding_dim`` rather
        # than ``* s``. Confirm ``LinearUpsample``'s actual output layout
        # (time-axis vs feature-axis upsampling).
        # source_upsampled : shape : (batch_size, max_input_sequence_length, encoder_input_dim * self.s)
        # source_mask_upsampled : shape : (batch_size, max_input_sequence_length)
        source_upsampled, source_mask_upsampled = self._upsample(embedded_input, source_mask)

        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        net_output = self._net(source_upsampled, source_mask_upsampled)
        output_dict = {"source_mask_upsampled": source_mask_upsampled, "net_output": net_output}

        alignment_logits = self._output_projection(net_output)
        output_dict["alignment_logits"] = alignment_logits

        if target_tokens:
            # Compute loss.
            loss = self._get_loss(output_dict, target_tokens)
            output_dict["loss"] = loss

        # Greedy alignment (argmax over the vocabulary axis, on CPU), then
        # collapse. NOTE(review): the condition is truncated (presumably
        # ``not self.training``).
        if not
            alignments = alignment_logits.detach().cpu().argmax(2)
            predictions = self.beta_inverse(alignments)
            output_dict["predictions"] = predictions

            if target_tokens and self._bleu:
                self._bleu(output_dict['predictions'], target_tokens["tokens"])
                #output_dict = self.decode(output_dict)

        return output_dict

    # TODO: this is a pure-Python loop over the batch; it needs a parallel
    # (vectorized) implementation.
    def beta_inverse(self, a:torch.Tensor):
        # NOTE(review): the next line is docstring text whose triple quotes
        # are missing.
        #
        # Collapses per-position alignments ``a`` (batch, sequence) into
        # output sequences. Each row is right-padded with ``self._pad_index``
        # to ``a.size(1)``. The result is a ``LongTensor`` of shape
        # (batch, a.size(1)).
        #
        # The branch bodies of the inner ``if``/``elif`` chain are missing
        # from this excerpt. Presumably:
        #   - blanks are dropped;
        #   - the first non-blank token is kept;
        #   - a token equal to the last *emitted* token (``output[-1]``) is
        #     dropped;
        #   - anything else is appended.
        #
        # NOTE(review): standard CTC collapsing compares with the previous
        # *alignment* token, so a blank between two identical tokens keeps
        # both (``x _ x`` -> ``x x``). Comparing with ``output[-1]`` merges
        # them into a single ``x``. Confirm this is intended.
        a : size (batch, sequence)
        max_length = a.size(1)
        outputs = []

        for sequence in a.tolist():
            output = []
            for token in sequence:
                if token == self._blank_index:
                elif len(output) == 0:
                elif token == output[-1]:
            pad_list = [self._pad_index] * (max_length - len(output))
            outputs.append(output + pad_list)
        return torch.LongTensor(outputs)

    # @staticmethod  (cannot be static: the body reads ``self._blank_index``)
    def _get_loss(self, 
        output_dict: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor]) -> torch.Tensor:
        """CTC loss between the alignment logits and the target tokens.

        Inputs to ``sequence_ctc_loss_with_logits``:
        - ``output_dict["alignment_logits"]``, with
          ``output_dict["source_mask_upsampled"]`` as the input-side mask;
        - the full target token ids (no shift is applied) and their
          text-field mask;
        - ``self._blank_index`` as the CTC blank.
        """

        targets = target_tokens["tokens"]
        target_mask = util.get_text_field_mask(target_tokens)
        # shape: (batch_size, input_length, target_size)
        alignment_logits = output_dict["alignment_logits"]

        # shape: (batch_size, input_length)
        source_mask_upsampled = output_dict["source_mask_upsampled"]

        #return util.sequence_cross_entropy_with_logits(alignment_logits, targets, source_mask_upsampled)
        return sequence_ctc_loss_with_logits(alignment_logits, source_mask_upsampled, targets, target_mask, self._blank_index)

    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # NOTE(review): the next line is docstring text whose triple quotes
        # are missing. The method does the following:
        #   - maps ``output_dict["predictions"]`` to token strings in
        #     ``self._target_namespace``, cutting each row at the first pad
        #     index;
        #   - stores the tokens under ``"predicted_tokens"``;
        #   - deletes the ``"alignment_logits"``,
        #     ``"source_mask_upsampled"`` and ``"net_output"`` entries;
        #   - returns the same dict.
        Finalize predictions.
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []

        for indices in predicted_indices:
            indices = list(indices)

            # remove pad
            if self._pad_index in indices:
                indices = indices[: indices.index(self._pad_index)]

            # lookup
            # NOTE(review): the closing bracket of this list and the
            # ``all_predicted_tokens.append(...)`` line are missing.
            predicted_tokens = [
                self.vocab.get_token_from_index(x, namespace=self._target_namespace)
                for x in indices
        # Expose the decoded tokens and drop the large intermediate tensors.
        output_dict["predicted_tokens"] = all_predicted_tokens
        del output_dict["alignment_logits"], output_dict['source_mask_upsampled'], output_dict['net_output']
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return tracked metrics; the only metric referenced is ``self._bleu``.

        NOTE(review): the condition and the body that reads BLEU are
        truncated in this excerpt.
        """
        all_metrics: Dict[str, float] = {}
        if self._bleu and not
        return all_metrics
class Bart(Model):
    """
    BART model from the paper "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation,
    Translation, and Comprehension" (https://arxiv.org/abs/1910.13461). The Bart model here uses a language
    modeling head and thus can be used for text generation.
    """
    # Construct a BART sequence-to-sequence model from the pre-trained
    # `BartForConditionalGeneration` checkpoint named by `model_name`.
    #
    # NOTE(review): this constructor is not valid Python as written. The `self,`
    # parameter, the `):` closing the signature, the triple quotes around the
    # parameter text below, presumably a `super().__init__(vocab)` call, the end of
    # the `assert` expression and the arguments / closing paren of
    # `_BartEncoderWrapper(` are all missing. Restore them before use.
    def __init__(
        model_name: str,
        vocab: Vocabulary,
        indexer: PretrainedTransformerIndexer = None,
        max_decoding_steps: int = 140,
        beam_size: int = 4,
        encoder: Seq2SeqEncoder = None,
        # Parameters

        model_name : `str`, required
            Name of the pre-trained BART model to use. Available options can be found in
        vocab : `Vocabulary`, required
            Vocabulary containing source and target vocabularies.
        indexer : `PretrainedTransformerIndexer`, optional (default = `None`)
            Indexer to be used for converting decoded sequences of ids to to sequences of tokens.
        max_decoding_steps : `int`, optional (default = `128`)
            Number of decoding steps during beam search.
        beam_size : `int`, optional (default = `5`)
            Number of beams to use in beam search. The default is from the BART paper.
        encoder : `Seq2SeqEncoder`, optional (default = `None`)
            Encoder to used in BART. By default, the original BART encoder is used.
        # NOTE(review): the defaults documented above for `max_decoding_steps` (128) and
        # `beam_size` (5) disagree with the signature (140 and 4).
        self.bart = BartForConditionalGeneration.from_pretrained(model_name)
        # Indexer used to map decoded ids back to tokens; when none is given, build one
        # for the same pre-trained model in the "tokens" namespace.
        self._indexer = indexer or PretrainedTransformerIndexer(
            model_name, namespace="tokens")

        # Special token ids, all read from the pre-trained model's config.
        self._start_id = self.bart.config.bos_token_id  # CLS
        # Falls back to the BOS id when the config's decoder start id is unset (falsy).
        self._decoder_start_id = self.bart.config.decoder_start_token_id or self._start_id
        self._end_id = self.bart.config.eos_token_id  # SEP
        self._pad_id = self.bart.config.pad_token_id  # PAD

        self._max_decoding_steps = max_decoding_steps
        # Beam search terminates on the EOS id; a falsy `beam_size` (None or 0) is
        # replaced by a single beam.
        # NOTE(review): `max_decoding_steps` is not passed to BeamSearch here (no
        # `max_steps=` argument is visible); confirm whether a line was lost.
        self._beam_search = BeamSearch(self._end_id,
                                       beam_size=beam_size or 1)

        # Generation metrics ignore the start, padding and end ids.
        self._rouge = ROUGE(
            exclude_indices={self._start_id, self._pad_id, self._end_id})
        self._bleu = BLEU(
            exclude_indices={self._start_id, self._pad_id, self._end_id})

        # Replace bart encoder with given encoder. We need to extract the two embedding layers so that
        # we can use them in the encoder wrapper
        if encoder is not None:
            # The replacement encoder must keep its input and output widths equal (and,
            # presumably, equal to BART's hidden size — the right-hand side is missing).
            assert (encoder.get_input_dim() == encoder.get_output_dim() ==
            self.bart.model.encoder = _BartEncoderWrapper(

    # Run BART on a batch: teacher-forced loss when computing the loss, beam search plus
    # ROUGE/BLEU updates when predicting.
    #
    # NOTE(review): this method is not valid Python as written. Missing pieces include
    # the `self,` parameter, the triple quotes around the text below, the mask-key
    # continuations after `inputs[` / `targets[`, the `else:` that should precede the
    # shift-right branch, the training/inference branch header(s), the remaining
    # arguments and closing parens of `self.bart(`, `sequence_cross_entropy_with_logits(`
    # and `torch.tensor(`, the closing `}` of `inital_state`, and the beam-search call.
    def forward(
            source_tokens: TextFieldTensors,
            target_tokens: TextFieldTensors = None) -> Dict[str, torch.Tensor]:
        Performs the forward step of Bart.

        # Parameters

        source_tokens : `TextFieldTensors`, required
            The source tokens for the encoder. We assume they are stored under the `tokens` key.
        target_tokens : `TextFieldTensors`, optional (default = `None`)
            The target tokens for the decoder. We assume they are stored under the `tokens` key. If no target
            tokens are given, the source tokens are shifted to the right by 1.

        # Returns

        `Dict[str, torch.Tensor]`
            During training, this dictionary contains the `decoder_logits` of shape `(batch_size,
            max_target_length, target_vocab_size)` and the `loss`. During inference, it contains `predictions`
            of shape `(batch_size, max_decoding_steps)` and `log_probabilities` of shape `(batch_size,)`.

        inputs = source_tokens
        targets = target_tokens
        # NOTE(review): the key that completes `inputs[` (presumably the token mask) is missing.
        input_ids, input_mask = inputs["tokens"]["token_ids"], inputs[

        outputs = {}

        # If no targets are provided, then shift input to right by 1. Bart already does this internally
        # but it does not use them for loss calculation.
        if targets is not None:
            target_ids, target_mask = targets["tokens"]["token_ids"], targets[
            # NOTE(review): per the comment above, the two assignments below form the
            # "no targets" branch; the `else:` header appears to be missing.
            target_ids = input_ids[:, 1:]
            target_mask = input_mask[:, 1:]

            # Teacher forcing: the decoder is fed the target sequence without its last token.
            # NOTE(review): the encoder inputs passed to `self.bart(` and its closing paren
            # are not visible.
            decoder_logits = self.bart(
                decoder_input_ids=target_ids[:, :-1].contiguous(),
                decoder_attention_mask=target_mask[:, :-1].contiguous(),

            outputs["decoder_logits"] = decoder_logits

            # The BART paper mentions label smoothing of 0.1 for sequence generation tasks
            # Each decoder position is scored against the next target token (targets shifted
            # left by one). NOTE(review): the logits argument and any label-smoothing argument
            # are not visible in this call.
            outputs["loss"] = sequence_cross_entropy_with_logits(
                target_ids[:, 1:].contiguous(),
                target_mask[:, 1:].contiguous(),
            # Use decoder start id and start of sentence to start decoder
            # Every batch element starts decoding from [decoder_start_id, start_id].
            initial_decoder_ids = torch.tensor(
                [[self._decoder_start_id, self._start_id]],
            ).repeat(input_ids.shape[0], 1)

            inital_state = {
                "input_ids": input_ids,
                "input_mask": input_mask,
                "encoder_states": None,
            # NOTE(review): the right-hand side is missing; presumably a call to
            # `self._beam_search` with `initial_decoder_ids`, `inital_state` and
            # `self.take_step` — confirm against the upstream source.
            beam_result =,

            predictions = beam_result[0]
            # Pick, for each batch element, the beam with the highest log-probability.
            max_pred_indices = (beam_result[1].argmax(dim=-1).view(
                -1, 1, 1).expand(-1, -1, predictions.shape[-1]))
            predictions = predictions.gather(
                dim=1, index=max_pred_indices).squeeze(dim=1)

            # Accumulate generation metrics against the (possibly shifted) target ids.
            self._rouge(predictions, target_ids)
            self._bleu(predictions, target_ids)

            outputs["predictions"] = predictions
            # Log-probability of the selected beam for each batch element.
            outputs["log_probabilities"] = (beam_result[1].gather(
                dim=-1, index=max_pred_indices[..., 0]).squeeze(dim=-1))


        return outputs

    def _decoder_cache_to_dict(decoder_cache):
        cache_dict = {}
        for layer_index, layer_cache in enumerate(decoder_cache):
            for attention_name, attention_cache in layer_cache.items():
                for tensor_name, cache_value in attention_cache.items():
                    key = (layer_index, attention_name, tensor_name)
                    cache_dict[key] = cache_value
        return cache_dict

    def _dict_to_decoder_cache(cache_dict):
        decoder_cache = []
        for key, cache_value in cache_dict.items():
            # Split key and extract index and dict keys
            layer_idx, attention_name, tensor_name = key
            # Extend decoder_cache to fit layer_idx + 1 layers
            decoder_cache = decoder_cache + [
                {} for _ in range(layer_idx + 1 - len(decoder_cache))
            cache = decoder_cache[layer_idx]
            if attention_name not in cache:
                cache[attention_name] = {}
            assert tensor_name not in cache[attention_name]
            cache[attention_name][tensor_name] = cache_value
        return decoder_cache

    # One beam-search step for BART: run the decoder on the newest ids and return
    # next-token log-probabilities plus the updated state.
    #
    # NOTE(review): this method is not valid Python as written. The triple quotes around
    # the text below, the remaining arguments / closing paren of `torch.full(`, the call
    # head before `[padding, last_predictions]` (presumably a concatenation), the closing
    # `}` of `decoder_cache_dict`, the arguments / closing paren of `self.bart(`, and an
    # `else:` before `idx = ...` all appear to be missing.
    def take_step(self, last_predictions: torch.Tensor,
                  state: Dict[str, torch.Tensor],
                  step: int) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        Take step during beam search.

        # Parameters

        last_predictions : `torch.Tensor`
            The predicted token ids from the previous step. Shape: `(group_size,)`
        state : `Dict[str, torch.Tensor]`
            State required to generate next set of predictions
        step : `int`
            The time step in beam search decoding.

        # Returns

        `Tuple[torch.Tensor, Dict[str, torch.Tensor]]`
            A tuple containing logits for the next tokens of shape `(group_size, target_vocab_size)` and
            an updated state dictionary.
        # NOTE(review): the first returned value comes from `F.log_softmax`, i.e. it holds
        # log-probabilities rather than raw logits as the text above says.
        # Make the ids 2-D: (group_size,) -> (group_size, 1).
        if len(last_predictions.shape) == 1:
            last_predictions = last_predictions.unsqueeze(-1)

        # Only the last predictions are needed for the decoder, but we need to pad the decoder ids
        # to not mess up the positional embeddings in the decoder.
        padding_size = 0
        if step > 0:
            padding_size = step + 1
            padding = torch.full(
                (last_predictions.shape[0], padding_size),
            last_predictions =[padding, last_predictions], dim=-1)

        # Every state entry other than the encoder inputs/states is treated as a flattened
        # decoder-cache entry and rebuilt into the nested per-layer cache.
        decoder_cache = None
        decoder_cache_dict = {
            k: (state[k].contiguous() if state[k] is not None else None)
            for k in state
            if k not in {"input_ids", "input_mask", "encoder_states"}
        if len(decoder_cache_dict) != 0:
            decoder_cache = self._dict_to_decoder_cache(decoder_cache_dict)

        log_probabilities = None
        # Decode every position after the padding, one position per iteration.
        for i in range(padding_size, last_predictions.shape[1]):
            # Reuse encoder states saved by a previous call, if any.
            encoder_outputs = ((state["encoder_states"], ) if
                               state["encoder_states"] is not None else None)
            outputs = self.bart(
                decoder_input_ids=last_predictions[:, :i + 1],

            # NOTE(review): takes the scores at index 0 of the second dimension, which is
            # presumably the only position returned when the decoder cache is in use — confirm.
            decoder_log_probabilities = F.log_softmax(outputs[0][:, 0], dim=-1)

            if log_probabilities is None:
                log_probabilities = decoder_log_probabilities
                # For later positions: add the log-probability, under the previous
                # distribution, of the token actually fed at position i.
                idx = last_predictions[:, i].view(-1, 1)
                log_probabilities = decoder_log_probabilities + log_probabilities.gather(
                    dim=-1, index=idx)

            # Cache returned by the decoder; reused by the next iteration.
            decoder_cache = outputs[1]

            state["encoder_states"] = outputs[2]

        # NOTE(review): `decoder_cache_dict` is rebuilt here but never merged back into
        # `state`, so the cache would not survive to the next step unless a line (e.g. a
        # `state.update(...)`) was lost.
        if decoder_cache is not None:
            decoder_cache_dict = self._decoder_cache_to_dict(decoder_cache)

        return log_probabilities, state

    def make_output_human_readable(
            self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Decode the `predictions` tensor into tokens with `self._indexer`.

        # Parameters

        output_dict : `Dict[str, torch.Tensor]`
            A dictionary containing a batch of predictions with key `predictions`. The tensor should have
            shape `(batch_size, max_sequence_length)`

        # Returns

        `Dict[str, Any]`
            Original `output_dict` with an additional `predicted_tokens` key that maps to a list with one
            decoded token sequence per row of `predictions`.
        """
        predictions = output_dict["predictions"]
        # Decode each row on its own. The previous version decoded `predictions[0]` for every
        # row, so all batch elements got the first sequence's tokens.
        output_dict["predicted_tokens"] = [
            self._indexer.indices_to_tokens({"token_ids": row.tolist()}, self.vocab)
            for row in predictions
        ]
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the ROUGE and BLEU scores when the model is not in training mode.

        Parameters
        ----------
        reset : bool
            Forwarded to each metric's ``get_metric`` to clear accumulated state.
        """
        metrics: Dict[str, float] = {}
        if not self.training:
            metrics.update(self._rouge.get_metric(reset=reset))
            metrics.update(self._bleu.get_metric(reset=reset))
        return metrics
    # Constructor of a copy-mechanism seq2seq model that reuses a BiDAF model's modeling
    # layer and builds an LSTM decoder with generation and copying output layers.
    #
    # NOTE(review): this constructor is not valid Python as written. Presumably a
    # `super().__init__(vocab)` call is missing (the body reads `self.vocab`), as are an
    # `else:` before `self.bidaf_model = bidaf_model`, the `load_state_dict(` call that
    # `torch.load(...)` feeds, and the trailing arguments / closing parens of several
    # calls (`Model.from_params(`, the start/end `get_token_index(` calls,
    # `get_vocab_size(`, `Embedding(`, `Linear(`, `LSTMCell(`, `BeamSearch(`).
    def __init__(self,
                 vocab: Vocabulary,
                 bidaf_model: BidirectionalAttentionFlowModified,
                 source_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 attention: Attention,
                 beam_size: int,
                 max_decoding_steps: int,
                 target_embedding_dim: int = 30,
                 copy_token: str = "@COPY@",
                 source_namespace: str = "source_tokens",
                 target_namespace: str = "target_tokens",
                 tensor_based_metric: Metric = None,
                 token_based_metric: Metric = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 dropout: float = 0.0,
                 pretrained_bidaf: bool = False) -> None:

        # Optionally rebuild the BiDAF model from a hard-coded local baseline directory.
        # NOTE(review): this rebinds the `vocab` parameter to the baseline vocabulary; if the
        # (not visible) `super().__init__(vocab)` runs after this, the model uses that
        # vocabulary instead of the one passed in — confirm intent.
        # NOTE(review): `open("./temp/bidaf_baseline/", 'rb')` opens a directory path, which
        # fails at runtime; the weights file name appears to be missing. `torch.load` unpickles
        # the file, so this path must only ever point at trusted data.
        if pretrained_bidaf:
            params = Params.from_file("./temp/bidaf_baseline/config.json")
            vocab = Vocabulary.from_files("./temp/bidaf_baseline/vocabulary")
            self.bidaf_model = Model.from_params(vocab=vocab,
            map_location = None if torch.cuda.is_available() else 'cpu'
            with open("./temp/bidaf_baseline/", 'rb') as f:
                    torch.load(f, map_location=map_location))
            self.bidaf_model = bidaf_model
        self._source_namespace = source_namespace
        self._target_namespace = target_namespace
        # Special indices in the source and target namespaces.
        self._src_start_index = self.vocab.get_token_index(
            START_SYMBOL, self._source_namespace)
        self._src_end_index = self.vocab.get_token_index(
            END_SYMBOL, self._source_namespace)
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
        self._oov_index = self.vocab.get_token_index(self.vocab._oov_token,
                                                     self._target_namespace)  # pylint: disable=protected-access
        self._pad_index = self.vocab.get_token_index(self.vocab._padding_token,
                                                     self._target_namespace)  # pylint: disable=protected-access
        # Registers the copy token in the target namespace and keeps its index.
        self._copy_index = self.vocab.add_token_to_namespace(
            copy_token, self._target_namespace)

        # Default to BLEU ignoring padding, end and start indices when no metric is given.
        self._tensor_based_metric = tensor_based_metric or \
            BLEU(exclude_indices={self._pad_index, self._end_index, self._start_index})
        self._token_based_metric = token_based_metric

        self._target_vocab_size = self.vocab.get_vocab_size(

        # Encoding modules.
        self._source_embedder = source_embedder
        self._encoder = encoder

        # Decoder output dim needs to be the same as the encoder output dim since we initialize the
        # hidden state of the decoder with the final hidden state of the encoder.
        # We arbitrarily set the decoder's input dimension to be the same as the output dimension.
        self.encoder_output_dim = self._encoder.get_output_dim()
        self.decoder_output_dim = self.encoder_output_dim
        self.decoder_input_dim = self.decoder_output_dim

        # Projects encoder output concatenated with the BiDAF modeling-layer output to the
        # decoder's hidden size (reads a private attribute of the BiDAF model).
        modeling_dim = self.bidaf_model._modeling_layer.get_output_dim()
        self._init_decoder_projection = Linear(
            self.encoder_output_dim + modeling_dim, self.decoder_output_dim)

        target_vocab_size = self.vocab.get_vocab_size(self._target_namespace)

        # The decoder input will be a function of the embedding of the previous predicted token,
        # an attended encoder hidden state called the "attentive read", and another
        # weighted sum of the encoder hidden state called the "selective read".
        # While the weights for the attentive read are calculated by an `Attention` module,
        # the weights for the selective read are simply the predicted probabilities
        # corresponding to each token in the source sentence that matches the target
        # token from the previous timestep.
        self._target_embedder = Embedding(target_vocab_size,
        self._attention = attention
        self._input_projection_layer = Linear(
            target_embedding_dim + self.encoder_output_dim * 2,

        # We then run the projected decoder input through an LSTM cell to produce
        # the next hidden state.
        self._decoder_cell = LSTMCell(self.decoder_input_dim,

        # We create a "generation" score for each token in the target vocab
        # with a linear projection of the decoder hidden state.
        self._output_generation_layer = Linear(self.decoder_output_dim,

        # We create a "copying" score for each source token by applying a non-linearity
        # (tanh) to a linear projection of the encoded hidden state for that token,
        # and then taking the dot product of the result with the decoder hidden state.
        self._output_copying_layer = Linear(self.encoder_output_dim,

        # At prediction time, we'll use a beam search to find the best target sequence.
        self._beam_search = BeamSearch(self._end_index,

        # NOTE(review): `self._dropout` is only assigned when `dropout > 0`; unless a lost
        # `else:` branch sets a fallback, later uses would raise AttributeError.
        if dropout > 0:
            self._dropout = torch.nn.Dropout(p=dropout)
        # NOTE(review): this is the parameter list and body of an auto-regressive decoder
        # constructor whose `def __init__(` / `self,` header (and presumably its
        # `super().__init__()` call) were lost; as written these lines continue the preceding
        # constructor. Several closing parens, `else:` headers, comparison operands and an
        # assignment of `self.target_embedder` are also missing.
        vocab: Vocabulary,
        decoder_net: DecoderNet,
        max_decoding_steps: int,
        target_embedder: Embedding,
        target_namespace: str = "tokens",
        tie_output_embedding: bool = False,
        scheduled_sampling_ratio: float = 0,
        label_smoothing_ratio: Optional[float] = None,
        beam_size: int = 4,
        tensor_based_metric: Metric = None,
        token_based_metric: Metric = None,
        bleu_exclude_tokens: List = [],  # NOTE(review): mutable default; only read below, but `None` would be safer
    ) -> None:

        self._vocab = vocab

        # Decodes the sequence of encoded hidden states into a new sequence of hidden states.
        self._decoder_net = decoder_net
        self._max_decoding_steps = max_decoding_steps
        self._target_namespace = target_namespace
        self._label_smoothing_ratio = label_smoothing_ratio

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self._vocab.get_token_index(
            START_SYMBOL, self._target_namespace
        self._end_index = self._vocab.get_token_index(
            END_SYMBOL, self._target_namespace
        self._beam_search = BeamSearch(
            self._end_index, max_steps=max_decoding_steps, beam_size=beam_size

        target_vocab_size = self._vocab.get_vocab_size(self._target_namespace)

        # NOTE(review): the left-hand operand of this comparison (presumably the target
        # embedder's output dim) and the closing `):` are missing.
        if (
            != self._decoder_net.target_embedding_dim
            raise ConfigurationError(
                "Target Embedder output_dim doesn't match decoder module's input."

        # We project the hidden state from the decoder into the output vocabulary space
        # in order to get log probabilities of each target token, at each time step.
        self._output_projection_layer = Linear(
            self._decoder_net.get_output_dim(), target_vocab_size

        # Optionally share the output projection weight with the target embedding matrix;
        # shapes must match. NOTE(review): the left-hand operand (presumably the projection
        # weight's shape) is missing, and `self.target_embedder` is read here although its
        # assignment from `target_embedder` is not visible.
        if tie_output_embedding:
            if (
                != self.target_embedder.weight.shape
                raise ConfigurationError(
                    "Can't tie embeddings with output linear layer, due to shape mismatch"
            self._output_projection_layer.weight = self.target_embedder.weight

        # These metrics will be updated during training and validation
        # A supplied BLEU metric is rebuilt with the same n-gram weights but with exclude
        # indices = {pad index} plus the indices of `bleu_exclude_tokens`; any exclude indices
        # configured on the supplied metric are discarded. Other metric types are kept as-is
        # (the `else:` before the plain assignment appears to be missing).
        # NOTE(review): the statement that adds each token index to `new_exclude_indices`
        # and the call that consumes the f-string (presumably a log call) are truncated.
        if isinstance(tensor_based_metric, BLEU):
            pad_index = self._vocab.get_token_index(
                self._vocab._padding_token, self._target_namespace
            new_exclude_indices = set([pad_index])
            for token in bleu_exclude_tokens:
                    self._vocab.get_token_index(token, self._target_namespace)
                f"Reconstruct BLEU to exclude indices {' '.join(map(str, new_exclude_indices))}"
            self._tensor_based_metric = BLEU(
                tensor_based_metric._ngram_weights, new_exclude_indices
            self._tensor_based_metric = tensor_based_metric
        self._token_based_metric = token_based_metric

        self._scheduled_sampling_ratio = scheduled_sampling_ratio
    def _forward_beam_search(
        self, state: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        Prepare inputs for the beam search, does beam search and returns beam search results.
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size,), fill_value=self._start_index

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities =
            start_predictions, state, self.take_step

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        return output_dict

    # Teacher-forced (optionally scheduled-sampling) decoding over the gold targets; returns
    # only the loss.
    #
    # NOTE(review): this method is not valid Python as written. Missing pieces include
    # the triple quotes around the text below, the other keyword arguments and closing paren
    # of `self._decoder_net(`, the `else:` that starts the step-by-step branch, the first
    # operand / `):` of the scheduled-sampling `if (`, several `else:` headers, closing
    # brackets, the `step_logits.append(...)` the "list of tensors" comment refers to, the
    # argument of `self.target_embedder(`, and the concatenation call heads before
    # `[steps_embeddings, ...]` and `, 1)`.
    # NOTE(review): the text below mentions greedy search during prediction, but the
    # returned dict only contains "loss".
    def _forward_loss(
        self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor]
    ) -> Dict[str, torch.Tensor]:
        Make forward pass during training or do greedy search during prediction.

        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (batch_size, max_target_sequence_length)
        targets = target_tokens["tokens"]

        # Prepare embeddings for targets. They will be used as gold embeddings during decoder training
        # shape: (batch_size, max_target_sequence_length, embedding_dim)
        target_embedding = self.target_embedder(targets)

        # shape: (batch_size, max_target_batch_sequence_length)
        target_mask = util.get_text_field_mask(target_tokens)

        # Fast path: without scheduled sampling, a decoder that decodes in parallel sees all
        # gold prefixes at once (every target position except the last).
        if self._scheduled_sampling_ratio == 0 and self._decoder_net.decodes_parallel:
            _, decoder_output = self._decoder_net(
                previous_steps_predictions=target_embedding[:, :-1, :],
                previous_steps_mask=target_mask[:, :-1],

            # shape: (group_size, max_target_sequence_length, num_classes)
            logits = self._output_projection_layer(decoder_output)
            batch_size = source_mask.size()[0]
            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1

            # Initialize target predictions with the start index.
            # shape: (batch_size,)
            last_predictions = source_mask.new_full(
                (batch_size,), fill_value=self._start_index

            # History of the model's own predicted embeddings, grown along dim 1 below.
            # shape: (batch_size, steps, target_embedding_dim)
            steps_embeddings = torch.Tensor([])

            step_logits: List[torch.Tensor] = []

            for timestep in range(num_decoding_steps):
                if (
                    and torch.rand(1).item() < self._scheduled_sampling_ratio
                    # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                    # during training.
                    # shape: (batch_size, steps, target_embedding_dim)
                    state["previous_steps_predictions"] = steps_embeddings

                    # shape: (batch_size, )
                    effective_last_prediction = last_predictions
                    # shape: (batch_size, )
                    effective_last_prediction = targets[:, timestep]

                    if timestep == 0:
                        state["previous_steps_predictions"] = torch.Tensor([])
                        # shape: (batch_size, steps, target_embedding_dim)
                        state["previous_steps_predictions"] = target_embedding[
                            :, :timestep

                # shape: (batch_size, num_classes)
                output_projections, state = self._prepare_output_projections(
                    effective_last_prediction, state

                # list of tensors, shape: (batch_size, 1, num_classes)

                # Greedy choice of the next token.
                # shape (predicted_classes): (batch_size,)
                _, predicted_classes = torch.max(output_projections, 1)

                # shape (predicted_classes): (batch_size,)
                last_predictions = predicted_classes

                # shape: (batch_size, 1, target_embedding_dim)
                last_predictions_embeddings = self.target_embedder(

                # This step is required, since we want to keep up two different prediction history: gold and real
                if steps_embeddings.shape[-1] == 0:
                    # There is no previous steps, except for start vectors in ``last_predictions``
                    # shape: (group_size, 1, target_embedding_dim)
                    steps_embeddings = last_predictions_embeddings
                    # shape: (group_size, steps_count, target_embedding_dim)
                    steps_embeddings =
                        [steps_embeddings, last_predictions_embeddings], 1

            # shape: (batch_size, num_decoding_steps, num_classes)
            logits =, 1)

        # Compute loss.
        # NOTE(review): identical to the `target_mask` computed above; the recomputation is redundant.
        target_mask = util.get_text_field_mask(target_tokens)
        loss = self._get_loss(logits, targets, target_mask)

        # TODO: We will be using beam search to get predictions for validation, but if beam size in 1
        # we could consider taking the last_predictions here and building step_predictions
        # and use that instead of running beam search again, if performance in validation is taking a hit
        output_dict = {"loss": loss}

        return output_dict

    # Append the newest prediction's embedding to the prediction history, run the decoder
    # net and project its (last-step) output into target-vocabulary scores.
    #
    # NOTE(review): this method is not valid Python as written. Missing pieces include
    # the triple quotes around the text below, the argument / closing paren of
    # `.unsqueeze(`, the `):` and `else:` of the history `if`, the concatenation call head
    # before `[previous_steps_predictions, ...]`, the arguments of `self._decoder_net(`,
    # and the statement under "Update state with new decoder state" (`decoder_state` is
    # otherwise unused).
    def _prepare_output_projections(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        Decode current state and last prediction to produce produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # History of embedded predictions; absent or empty on the first step.
        # shape: (group_size, steps_count, target_embedding_dim)
        previous_steps_predictions = state.get("previous_steps_predictions")

        # shape: (batch_size, 1, target_embedding_dim)
        last_predictions_embeddings = self.target_embedder(last_predictions).unsqueeze(

        if (
            previous_steps_predictions is None
            or previous_steps_predictions.shape[-1] == 0
            # There is no previous steps, except for start vectors in ``last_predictions``
            # shape: (group_size, 1, target_embedding_dim)
            previous_steps_predictions = last_predictions_embeddings
            # shape: (group_size, steps_count, target_embedding_dim)
            previous_steps_predictions =
                [previous_steps_predictions, last_predictions_embeddings], 1

        decoder_state, decoder_output = self._decoder_net(
        state["previous_steps_predictions"] = previous_steps_predictions

        # Update state with new decoder state, override previous state

        # A parallel decoder returns outputs for every step; only the newest is needed.
        if self._decoder_net.decodes_parallel:
            decoder_output = decoder_output[:, -1, :]

        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_output)

        return output_projections, state

    def _get_loss(
        self,
        logits: torch.Tensor,
        targets: torch.LongTensor,
        target_mask: torch.LongTensor,
    ) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account. Label smoothing is taken from
        ``self._label_smoothing_ratio``.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(
            logits,
            relevant_targets,
            relevant_mask,
            label_smoothing=self._label_smoothing_ratio,
        )

    def get_output_dim(self):
        """Return the output dimension reported by the wrapped decoder network."""
        decoder = self._decoder_net
        return decoder.get_output_dim()

    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state
        )

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Collect the configured tensor- and token-based metrics when not in training mode.

        Parameters
        ----------
        reset : bool
            Forwarded to each metric's ``get_metric`` to clear accumulated state.
        """
        all_metrics: Dict[str, float] = {}
        if not self.training:
            if self._tensor_based_metric is not None:
                all_metrics.update(
                    self._tensor_based_metric.get_metric(reset=reset)  # type: ignore
                )
            if self._token_based_metric is not None:
                all_metrics.update(self._token_based_metric.get_metric(reset=reset))  # type: ignore
        return all_metrics

    # Decode from encoder output: compute the loss when targets are given and, in the
    # (presumably non-training) prediction branch, run beam search and update metrics.
    #
    # NOTE(review): this method is not valid Python as written. Missing pieces include
    # the `self,` parameter, an `else:` before `output_dict = {}` (as indented, the loss
    # result would be discarded), the condition after `if not` (presumably
    # `self.training:`), a statement that puts the beam-search `predictions` into
    # `output_dict` (it is read from `output_dict["predictions"]` below), and the closing
    # parens / first argument of the two metric calls.
    # NOTE(review): `init_decoder_state` is called on the same `state` dict in both
    # branches; whether it mutates its argument is not visible here — confirm.
    def forward(
        encoder_out: Dict[str, torch.LongTensor],
        target_tokens: Dict[str, torch.LongTensor] = None,
    ) -> Dict[str, torch.Tensor]:
        state = encoder_out
        if target_tokens:
            decoder_init_state = self._decoder_net.init_decoder_state(state)
            output_dict = self._forward_loss(decoder_init_state, target_tokens)
            output_dict = {}

        if not
            decoder_init_state = self._decoder_net.init_decoder_state(state)
            predictions = self._forward_beam_search(decoder_init_state)

            if target_tokens:
                if self._tensor_based_metric is not None:
                    # shape: (batch_size, beam_size, max_sequence_length)
                    top_k_predictions = output_dict["predictions"]
                    # Only the top-ranked beam is scored.
                    # shape: (batch_size, max_predicted_sequence_length)
                    best_predictions = top_k_predictions[:, 0, :]

                    self._tensor_based_metric(  # type: ignore
                        best_predictions, target_tokens["tokens"]

                if self._token_based_metric is not None:
                    output_dict = self.post_process(output_dict)
                    predicted_tokens = output_dict["predicted_tokens"]

                    # Gold tokens skip the first position of the target ids.
                    self._token_based_metric(  # type: ignore
                        self.indices_to_tokens(target_tokens["tokens"][:, 1:]),

        return output_dict

    def post_process(
        self, output_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.

        The conversion itself is delegated to ``self.indices_to_tokens``; ``output_dict`` is
        updated in place and returned.
        """
        predicted_indices = output_dict["predictions"]
        all_predicted_tokens = self.indices_to_tokens(predicted_indices)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def indices_to_tokens(self, batch_indeces: numpy.ndarray) -> List[List[str]]:
        """
        Convert a batch of index sequences into lists of token strings.

        Each sequence is cut just before the first ``self._end_index``. When an element of
        the batch is itself multi-dimensional (beam-search output), only its first row
        (the best beam) is used.

        Parameters
        ----------
        batch_indeces : ``numpy.ndarray`` or ``torch.Tensor``
            Index sequences; tensors are detached and moved to CPU first.

        Returns
        -------
        ``List[List[str]]``
            One token list per batch element, looked up in ``self._target_namespace``.
        """
        if not isinstance(batch_indeces, numpy.ndarray):
            batch_indeces = batch_indeces.detach().cpu().numpy()

        all_tokens = []
        for indices in batch_indeces:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[: indices.index(self._end_index)]
            tokens = [
                self._vocab.get_token_from_index(x, namespace=self._target_namespace)
                for x in indices
            ]
            all_tokens.append(tokens)

        return all_tokens
    def __init__(
            self,
            vocab: Vocabulary,
            source_embedder: TextFieldEmbedder,
            target_embedder: TextFieldEmbedder,
            source_encoder: Seq2SeqEncoder,
            target_encoder: Seq2SeqEncoder,
            max_decoding_steps: int,
            attention: Attention = None,
            s2s_attention: Attention = None,
            t2t_attention: Attention = None,
            beam_size: int = None,
            target_namespace: str = "tokens",
            scheduled_sampling_ratio: float = 0.,
            use_bleu: bool = True) -> None:
        """
        Build the chained-attention sequence-to-sequence model.

        Parameters
        ----------
        vocab : ``Vocabulary``
            Vocabulary; ``target_namespace`` supplies the start/end/pad indices and the
            size of the output projection.
        source_embedder, target_embedder : ``TextFieldEmbedder``
            Embedders for source-side and target-side token fields.
        source_encoder, target_encoder : ``Seq2SeqEncoder``
            Encoders for source-side and target-side sequences. The decoder hidden size
            is ``target_encoder.get_output_dim()``.
        max_decoding_steps : ``int``
            Maximum number of decoding steps for greedy decoding and beam search.
        attention, s2s_attention, t2t_attention : ``Attention``, optional
            Attention modules stored for the decoder. ``attention`` is required.
        beam_size : ``int``, optional (``None`` means 1)
        target_namespace : ``str``, optional (default = "tokens")
        scheduled_sampling_ratio : ``float``, optional (default = 0.)
            During training, the probability of feeding back the model's own previous
            prediction instead of the gold token.
        use_bleu : ``bool``, optional (default = True)
            Whether to track BLEU, ignoring pad/start/end indices.

        Raises
        ------
        NotImplementedError
            If ``attention`` is not given.
        """
        super(AssociativeSeq2SeqChainedAttention, self).__init__(vocab)
        self._target_namespace = target_namespace
        self._scheduled_sampling_ratio = scheduled_sampling_ratio

        # We need the start symbol to provide as the input at the first timestep of decoding, and
        # end symbol as a way to indicate the end of the decoded sequence.
        self._start_index = self.vocab.get_token_index(START_SYMBOL,
                                                       self._target_namespace)
        self._end_index = self.vocab.get_token_index(END_SYMBOL,
                                                     self._target_namespace)

        if use_bleu:
            pad_index = self.vocab.get_token_index(self.vocab._padding_token,  # pylint: disable=protected-access
                                                   self._target_namespace)
            self._bleu = BLEU(exclude_indices={
                pad_index, self._end_index, self._start_index
            })
        else:
            self._bleu = None

        # At prediction time, we use a beam search to find the most likely sequence of target tokens.
        beam_size = beam_size or 1
        self._max_decoding_steps = max_decoding_steps
        self._beam_search = BeamSearch(self._end_index,
                                       max_steps=max_decoding_steps,
                                       beam_size=beam_size)

        # Dense embedding of source vocab tokens.
        self._source_embedder = source_embedder

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        self._source_encoder = source_encoder
        self._target_encoder = target_encoder

        self._encoder_output_dim = self._target_encoder.get_output_dim()
        self._decoder_output_dim = self._encoder_output_dim
        target_embedding_dim = target_embedder.get_output_dim()

        if attention:
            self._attention = attention
            self._s2s_attention = s2s_attention
            self._t2t_attention = t2t_attention
            # The decoder input is [attended context; embedded previous token].
            self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
        else:
            raise NotImplementedError("An `attention` module is required.")

        num_classes = self.vocab.get_vocab_size(self._target_namespace)
        self._target_embedder = target_embedder

        self._decoder_cell = LSTMCell(self._decoder_input_dim,
                                      self._decoder_output_dim)
        self._output_projection_layer = Linear(self._decoder_output_dim,
                                               num_classes)
    def take_step(
        self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Take a decoding step. This is called by the beam search class.

        Parameters
        ----------
        last_predictions : ``torch.Tensor``
            A tensor of shape ``(group_size,)``, which gives the indices of the predictions
            during the last time step.
        state : ``Dict[str, torch.Tensor]``
            A dictionary of tensors that contain the current state information
            needed to predict the next step, which includes the encoder outputs,
            the source mask, and the decoder hidden state and context. Each of these
            tensors has shape ``(group_size, *)``, where ``*`` can be any other number
            of dimensions.

        Returns
        -------
        Tuple[torch.Tensor, Dict[str, torch.Tensor]]
            A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
            is a tensor of shape ``(group_size, num_classes)`` containing the predicted
            log probability of each class for the next step, for each item in the group,
            while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
            source mask, and updated decoder hidden state and context.

        Notes
        -----
            We treat the inputs as a batch, even though ``group_size`` is not necessarily
            equal to ``batch_size``, since the group may contain multiple states
            for each source sentence in the batch.
        """
        # shape: (group_size, num_classes)
        output_projections, state = self._prepare_output_projections(
            last_predictions, state)

        # shape: (group_size, num_classes)
        class_log_probabilities = F.log_softmax(output_projections, dim=-1)

        return class_log_probabilities, state

    def forward(
        self,  # type: ignore
        ref_source_tokens: Dict[str, torch.LongTensor],
        instance_source_tokens: Dict[str, torch.LongTensor],
        ref_target_tokens: Dict[str, torch.LongTensor],
        instance_target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Encode the reference pair and the instance source, then decode the instance target.

        When ``instance_target_tokens`` are given, the teacher-forced loss is computed by
        ``_forward_loop``. When the model is not training, beam search is run as well and
        BLEU (if enabled) is updated against ``instance_target_tokens["tokens"]``.

        Returns
        -------
        ``Dict[str, torch.Tensor]``
            ``"loss"`` (when targets are given) and, outside training, the beam-search
            outputs (``"predictions"``, ``"class_log_probabilities"``).
        """
        state = self._encode(ref_target_tokens, ref_source_tokens,
                             instance_source_tokens)
        if instance_target_tokens:
            state = self._init_decoder_state(state)
            # The `_forward_loop` decodes the input sequence and computes the loss during training
            # and validation.
            output_dict = self._forward_loop(state, instance_target_tokens)
        else:
            output_dict = {}

        if not self.training:
            # Re-initialise: `_forward_loop` advanced the decoder state in place.
            state = self._init_decoder_state(state)
            predictions = self._forward_beam_search(state)
            output_dict.update(predictions)
            if instance_target_tokens and self._bleu:
                # shape: (batch_size, beam_size, max_sequence_length)
                top_k_predictions = output_dict["predictions"]
                # shape: (batch_size, max_predicted_sequence_length)
                best_predictions = top_k_predictions[:, 0, :]
                self._bleu(best_predictions, instance_target_tokens["tokens"])

        return output_dict

    def decode(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Finalize predictions.

        This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
        time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
        within the ``forward`` method.

        This method trims the output predictions to the first end symbol, replaces indices with
        corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
        """
        predicted_indices = output_dict["predictions"]
        if not isinstance(predicted_indices, numpy.ndarray):
            predicted_indices = predicted_indices.detach().cpu().numpy()
        all_predicted_tokens = []
        for indices in predicted_indices:
            # Beam search gives us the top k results for each source sentence in the batch
            # but we just want the single best.
            if len(indices.shape) > 1:
                indices = indices[0]
            indices = list(indices)
            # Collect indices till the first end_symbol
            if self._end_index in indices:
                indices = indices[:indices.index(self._end_index)]
            predicted_tokens = [
                self.vocab.get_token_from_index(
                    x, namespace=self._target_namespace) for x in indices
            ]
            all_predicted_tokens.append(predicted_tokens)
        output_dict["predicted_tokens"] = all_predicted_tokens
        return output_dict

    def _encode(
        self, ref_target_tokens: Dict[str, torch.Tensor],
        ref_source_tokens: Dict[str, torch.Tensor],
        instance_source_tokens: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Embed and encode the reference target, reference source and instance source.

        Returns
        -------
        ``Dict[str, torch.Tensor]`` with
            ``"source_mask"`` / ``"encoder_outputs"``: mask and target-encoder output of the
            reference target (the decoder's primary memory),
            ``"ref_source_mask"`` / ``"ref_source_encoded"``: mask and source-encoder output
            of the reference source,
            ``"instance_source_mask"`` / ``"instance_source_encoded"``: mask and
            source-encoder output of the instance source.
        """
        # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
        embedded_ref_target = self._target_embedder(ref_target_tokens)
        ref_target_mask = util.get_text_field_mask(ref_target_tokens)
        encoded_ref_target = self._target_encoder(embedded_ref_target,
                                                  ref_target_mask)

        embedded_ref_source = self._source_embedder(ref_source_tokens)
        ref_source_mask = util.get_text_field_mask(ref_source_tokens)
        encoded_ref_source = self._source_encoder(embedded_ref_source,
                                                  ref_source_mask)

        embedded_instance_source = self._source_embedder(
            instance_source_tokens)
        instance_source_mask = util.get_text_field_mask(instance_source_tokens)
        encoded_instance_source = self._source_encoder(
            embedded_instance_source, instance_source_mask)

        # The "*_encoded" entries must hold encoder outputs, not raw embeddings: the attended
        # instance source is concatenated with the target embedding and fed to an LSTMCell
        # sized ``decoder_output_dim + target_embedding_dim`` (the encoder output size).
        return {
            "source_mask": ref_target_mask,
            "encoder_outputs": encoded_ref_target,
            "ref_source_mask": ref_source_mask,
            "ref_source_encoded": encoded_ref_source,
            "instance_source_mask": instance_source_mask,
            "instance_source_encoded": encoded_instance_source
        }
    def _init_decoder_state(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Initialise the LSTM decoder state inside ``state`` (mutated in place and returned).

        ``"decoder_hidden"`` is set to the final state of ``state["encoder_outputs"]`` under
        ``state["source_mask"]``; ``"decoder_context"`` is zero-filled.
        """
        batch_size = state["source_mask"].size(0)
        # shape: (batch_size, encoder_output_dim)
        # ``encoder_outputs`` comes from the target encoder, so its directionality decides
        # how the final state is extracted.
        final_encoder_output = util.get_final_encoder_states(
            state["encoder_outputs"], state["source_mask"],
            self._target_encoder.is_bidirectional())
        # Initialize the decoder hidden state with the final output of the encoder.
        # shape: (batch_size, decoder_output_dim)
        state["decoder_hidden"] = final_encoder_output
        # shape: (batch_size, decoder_output_dim)
        state["decoder_context"] = state["encoder_outputs"].new_zeros(
            batch_size, self._decoder_output_dim)
        return state

    def _forward_loop(
        state: Dict[str, torch.Tensor],
        target_tokens: Dict[str, torch.LongTensor] = None
    ) -> Dict[str, torch.Tensor]:
        Make forward pass during training or do greedy search during prediction.

        We really only use the predictions from the method to test that beam search
        with a beam size of 1 gives the same results.
        # shape: (batch_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        batch_size = source_mask.size()[0]

        if target_tokens:
            # shape: (batch_size, max_target_sequence_length)
            targets = target_tokens["tokens"]

            _, target_sequence_length = targets.size()

            # The last input from the target is either padding or the end symbol.
            # Either way, we don't have to process it.
            num_decoding_steps = target_sequence_length - 1
            num_decoding_steps = self._max_decoding_steps

        # Initialize target predictions with the start index.
        # shape: (batch_size,)
        last_predictions = source_mask.new_full((batch_size, ),

        step_logits: List[torch.Tensor] = []
        step_predictions: List[torch.Tensor] = []
        for timestep in range(num_decoding_steps):
            if and torch.rand(
                    1).item() < self._scheduled_sampling_ratio:
                # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
                # during training.
                # shape: (batch_size,)
                input_choices = last_predictions
            elif not target_tokens:
                # shape: (batch_size,)
                input_choices = last_predictions
                # shape: (batch_size,)
                input_choices = targets[:, timestep]

            # shape: (batch_size, num_classes)
            output_projections, state = self._prepare_output_projections(
                input_choices, state)

            # list of tensors, shape: (batch_size, 1, num_classes)

            # shape: (batch_size, num_classes)
            class_probabilities = F.softmax(output_projections, dim=-1)

            # shape (predicted_classes): (batch_size,)
            _, predicted_classes = torch.max(class_probabilities, 1)

            # shape (predicted_classes): (batch_size,)
            last_predictions = predicted_classes


        # shape: (batch_size, num_decoding_steps)
        predictions =, 1)

        output_dict = {"predictions": predictions}

        if target_tokens:
            # shape: (batch_size, num_decoding_steps, num_classes)
            logits =, 1)

            # Compute loss.
            target_mask = util.get_text_field_mask(target_tokens)
            loss = self._get_loss(logits, targets, target_mask)
            output_dict["loss"] = loss

        return output_dict

    def _forward_beam_search(
            self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Make forward pass during prediction using a beam search."""
        batch_size = state["source_mask"].size()[0]
        start_predictions = state["source_mask"].new_full(
            (batch_size, ), fill_value=self._start_index)

        # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
        # shape (log_probabilities): (batch_size, beam_size)
        all_top_k_predictions, log_probabilities =
            start_predictions, state, self.take_step)

        output_dict = {
            "class_log_probabilities": log_probabilities,
            "predictions": all_top_k_predictions,
        return output_dict

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # pylint: disable=line-too-long
        # Run one LSTM decoder step with chained attention and project the new hidden state
        # onto the target vocabulary. Reads state["decoder_hidden"] / state["decoder_context"],
        # writes the updated values back into ``state``, and returns
        # ``(output_projections, state)``.
        # NOTE(review): the lines below were meant to be the docstring but have lost their
        # triple quotes, which makes this method a syntax error — restore them.
        Decode current state and last prediction to produce produce projections
        into the target space, which can then be used to get probabilities of
        each target token for the next step.

        Inputs are the same as for `take_step()`.
        # shape: (group_size, max_input_sequence_length, encoder_output_dim)
        # NOTE(review): ``encoder_outputs`` and ``source_mask`` are not used below (the code
        # indexes ``state`` directly).
        encoder_outputs = state["encoder_outputs"]

        # shape: (group_size, max_input_sequence_length)
        source_mask = state["source_mask"]

        # shape: (group_size, decoder_output_dim)
        decoder_hidden = state["decoder_hidden"]

        # shape: (group_size, decoder_output_dim)
        decoder_context = state["decoder_context"]

        # shape: (group_size, target_embedding_dim)
        # NOTE(review): the argument line of this call is missing (presumably
        # ``last_predictions)``); as shown the call is unterminated.
        embedded_input = self._target_embedder._token_embedders['tokens'](

        if self._attention:
            # Chained attention: the decoder hidden state attends over the reference-target
            # encoding; that context attends over the reference source; that context in turn
            # attends over the instance source.
            # shape: (group_size, encoder_output_dim)
            # NOTE(review): the attention-module argument and closing paren of this first call
            # are missing. Whether it was ``self._t2t_attention`` (otherwise unused in the
            # visible code) or ``self._attention`` cannot be determined here — confirm.
            attended_input_source = self._prepare_attended_input(
                decoder_hidden, state['encoder_outputs'], state['source_mask'],
            attended_input_ref_source = self._prepare_attended_input(
                attended_input_source, state['ref_source_encoded'],
                state['ref_source_mask'], self._attention)
            attended_input_instance_source = self._prepare_attended_input(
                attended_input_ref_source, state['instance_source_encoded'],
                state['instance_source_mask'], self._s2s_attention)

            # shape: (group_size, decoder_output_dim + target_embedding_dim)
            # NOTE(review): the callee of this assignment (presumably ``torch.cat(``) is missing.
            decoder_input =
                (attended_input_instance_source, embedded_input), -1)
            # NOTE(review): an ``else:`` appears to be missing before this raise; as laid out it
            # would fire unconditionally after ``decoder_input`` is built.
            raise NotImplementedError

        # shape (decoder_hidden): (batch_size, decoder_output_dim)
        # shape (decoder_context): (batch_size, decoder_output_dim)
        decoder_hidden, decoder_context = self._decoder_cell(
            decoder_input, (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context
        # shape: (group_size, num_classes)
        output_projections = self._output_projection_layer(decoder_hidden)

        return output_projections, state

    def _prepare_attended_input(self,
                                decoder_hidden_state: torch.LongTensor = None,
                                encoder_outputs: torch.LongTensor = None,
                                encoder_outputs_mask: torch.LongTensor = None,
                                attention=None) -> torch.Tensor:
        """
        Apply attention over encoder outputs and decoder state.

        ``attention`` is called with ``decoder_hidden_state`` as the query over
        ``encoder_outputs``; the result is the attention-weighted sum of ``encoder_outputs``.
        Despite the ``None`` defaults, all four arguments are required.
        """
        # Ensure mask is also a FloatTensor. Or else the multiplication within
        # attention will complain.
        # shape: (batch_size, max_input_sequence_length)
        encoder_outputs_mask = encoder_outputs_mask.float()

        # shape: (batch_size, max_input_sequence_length)
        input_weights = attention(decoder_hidden_state, encoder_outputs,
                                  encoder_outputs_mask)

        # shape: (batch_size, encoder_output_dim)
        attended_input = util.weighted_sum(encoder_outputs, input_weights)

        return attended_input

    @staticmethod
    def _get_loss(logits: torch.LongTensor, targets: torch.LongTensor,
                  target_mask: torch.LongTensor) -> torch.Tensor:
        """
        Compute loss.

        Takes logits (unnormalized outputs from the decoder) of size (batch_size,
        num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
        and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
        entropy loss while taking the mask into account.

        The length of ``targets`` is expected to be greater than that of ``logits`` because the
        decoder does not need to compute the output corresponding to the last timestep of
        ``targets``. This method aligns the inputs appropriately to compute the loss.

        During training, we want the logit corresponding to timestep i to be similar to the target
        token from timestep i + 1. That is, the targets should be shifted by one timestep for
        appropriate comparison.  Consider a single example where the target has 3 words, and
        padding is to 7 tokens.
           The complete sequence would correspond to <S> w1  w2  w3  <E> <P> <P>
           and the mask would be                     1   1   1   1   1   0   0
           and let the logits be                     l1  l2  l3  l4  l5  l6
        We actually need to compare:
           the sequence           w1  w2  w3  <E> <P> <P>
           with masks             1   1   1   1   0   0
           against                l1  l2  l3  l4  l5  l6
           (where the input was)  <S> w1  w2  w3  <E> <P>
        """
        # shape: (batch_size, num_decoding_steps)
        relevant_targets = targets[:, 1:].contiguous()

        # shape: (batch_size, num_decoding_steps)
        relevant_mask = target_mask[:, 1:].contiguous()

        return util.sequence_cross_entropy_with_logits(logits,
                                                       relevant_targets,
                                                       relevant_mask)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Return the BLEU metric when it is enabled and the model is not training.

        ``reset`` is forwarded to ``self._bleu.get_metric``. An empty dict is returned
        otherwise.
        """
        all_metrics: Dict[str, float] = {}
        if self._bleu and not self.training:
            all_metrics.update(self._bleu.get_metric(reset=reset))  # type: ignore
        return all_metrics