Example 1
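Both RecurrentDecoder examples below are excerpts from a larger module, so the surrounding imports are not shown. A hedged sketch of what they rely on follows; the joeynmt module paths are assumptions and may need adjusting, and `Decoder` is the abstract base class defined alongside the decoders (not shown here).

from typing import Optional

import torch
from torch import nn, Tensor

# joeynmt-specific names; module paths assumed, adjust to the actual package layout
from joeynmt.attention import BahdanauAttention, LuongAttention
from joeynmt.encoders import Encoder
from joeynmt.helpers import ConfigurationError, freeze_params
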
class RecurrentDecoder(Decoder):
    """A conditional RNN decoder with attention."""

    def __init__(self,
                 type: str = "gru",
                 emb_size: int = 0,
                 hidden_size: int = 0,
                 encoder: Encoder = None,
                 attention: str = "bahdanau",
                 num_layers: int = 0,
                 vocab_size: int = 0,
                 dropout: float = 0.,
                 hidden_dropout: float = 0.,
                 bridge: bool = False,
                 input_feeding: bool = True,
                 freeze: bool = False,
                 **kwargs):
        """
        Create a recurrent decoder.
        If `bridge` is True, the decoder hidden states are initialized from a
        projection of the encoder states, else they are initialized with zeros.

        :param type: rnn type, valid options: "gru", "lstm"
        :param emb_size: target embedding size
        :param hidden_size: size of the decoder RNN
        :param encoder: encoder connected to this decoder
        :param attention: type of attention, valid options: "bahdanau", "luong"
        :param num_layers: number of recurrent layers
        :param vocab_size: target vocabulary size
        :param dropout: applied to the RNN input and between RNN layers
        :param hidden_dropout: applied to the input of the attentional layer
        :param bridge: initialize the decoder hidden state from a projection
            of the last encoder state (see above)
        :param input_feeding: use Luong-style input feeding
        :param freeze: freeze the parameters of the decoder during training
        :param kwargs:
        """

        super(RecurrentDecoder, self).__init__()

        self.rnn_input_dropout = torch.nn.Dropout(p=dropout, inplace=False)
        self.type = type
        self.hidden_dropout = torch.nn.Dropout(p=hidden_dropout, inplace=False)
        self.hidden_size = hidden_size

        rnn = nn.GRU if type == "gru" else nn.LSTM

        self.input_feeding = input_feeding
        if self.input_feeding:  # Luong-style
            # combine embedded prev word + attention vector before feeding to rnn
            self.rnn_input_size = emb_size + hidden_size
        else:
            # just feed prev word embedding
            self.rnn_input_size = emb_size

        # the decoder RNN
        self.rnn = rnn(self.rnn_input_size, hidden_size, num_layers,
                       batch_first=True,
                       dropout=dropout if num_layers > 1 else 0.)

        # combine output with context vector before output layer (Luong-style)
        self.att_vector_layer = nn.Linear(
            hidden_size + encoder.output_size, hidden_size, bias=True)

        self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
        self.output_size = vocab_size

        if attention == "bahdanau":
            self.attention = BahdanauAttention(hidden_size=hidden_size,
                                               key_size=encoder.output_size,
                                               query_size=hidden_size)
        elif attention == "luong":
            self.attention = LuongAttention(hidden_size=hidden_size,
                                            key_size=encoder.output_size)
        else:
            raise ValueError("Unknown attention mechanism: %s" % attention)

        self.num_layers = num_layers
        self.hidden_size = hidden_size

        # to initialize from the final encoder state of last layer
        self.bridge = bridge
        if self.bridge:
            self.bridge_layer = nn.Linear(
                encoder.output_size, hidden_size, bias=True)

        if freeze:
            freeze_params(self)

    def _forward_step(self,
                      prev_embed: Tensor = None,
                      prev_att_vector: Tensor = None,  # context or att vector
                      encoder_output: Tensor = None,
                      src_mask: Tensor = None,
                      hidden: Tensor = None):
        """
        Perform a single decoder step (1 word)

        :param prev_embed: embedded previous token
            (batch_size, 1, emb_size)
        :param prev_att_vector: previous attention vector
            (batch_size, 1, hidden_size)
        :param encoder_output: encoder states for attention context
            (batch_size, src_length, encoder.output_size)
        :param src_mask: source mask (batch_size, 1, src_length)
        :param hidden: previous decoder hidden state
            (num_layers, batch_size, hidden_size)
        :return: att_vector, hidden, att_probs
        """

        # loop:
        # 1. rnn input = concat(prev_embed, prev_att_vector [possibly empty])
        # 2. update RNN with rnn_input
        # 3. calculate attention and context/attention vector
        # 4. repeat

        # update rnn hidden state
        if self.input_feeding:
            rnn_input = torch.cat([prev_embed, prev_att_vector], dim=2)
        else:
            rnn_input = prev_embed

        rnn_input = self.rnn_input_dropout(rnn_input)

        # rnn_input: batch x 1 x rnn_input_size
        _, hidden = self.rnn(rnn_input, hidden)

        # use new (top) decoder layer as attention query
        if isinstance(hidden, tuple):
            query = hidden[0][-1].unsqueeze(1)
        else:
            query = hidden[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]

        # compute context vector using attention mechanism
        # only use last layer for attention mechanism
        # key projections are pre-computed
        context, att_probs = self.attention(
            query=query, values=encoder_output, mask=src_mask)

        # return attention vector (Luong)
        # combine context with decoder hidden state before prediction
        att_vector_input = torch.cat([query, context], dim=2)
        att_vector_input = self.hidden_dropout(att_vector_input)

        # batch x 1 x 2*enc_size+hidden_size
        att_vector = torch.tanh(self.att_vector_layer(att_vector_input))

        # output: batch x 1 x dec_size
        return att_vector, hidden, att_probs

    def forward(self, trg_embed, encoder_output, encoder_hidden,
                src_mask, unroll_steps, hidden=None, prev_att_vector=None):
        """
        Unroll the decoder one step at a time for `unroll_steps` steps.

        :param trg_embed: embedded target inputs
            (batch_size, trg_length, emb_size)
        :param encoder_output: encoder states
            (batch_size, src_length, encoder.output_size)
        :param encoder_hidden: last encoder state
            (batch_size, encoder.output_size)
        :param src_mask: source mask (batch_size, 1, src_length)
        :param unroll_steps: number of steps to unroll the decoder RNN
        :param hidden: previous decoder hidden state,
            initialized by `init_hidden` if not given
        :param prev_att_vector: previous attentional vector,
            initialized with zeros if not given
        :return: outputs, hidden, att_probs, att_vectors
        """

        # initialize decoder hidden state from final encoder hidden state
        if hidden is None:
            hidden = self.init_hidden(encoder_hidden)

        # pre-compute projected encoder outputs
        # (the "keys" for the attention mechanism)
        # this is only done for efficiency
        if hasattr(self.attention, "compute_proj_keys"):
            self.attention.compute_proj_keys(encoder_output)

        # here we store all intermediate attention vectors (used for prediction)
        att_vectors = []
        att_probs = []

        batch_size = encoder_output.size(0)

        if prev_att_vector is None:
            with torch.no_grad():
                prev_att_vector = encoder_output.new_zeros(
                    [batch_size, 1, self.hidden_size])

        # unroll the decoder RNN for `unroll_steps` steps
        for i in range(unroll_steps):
            prev_embed = trg_embed[:, i].unsqueeze(1)  # batch, 1, emb
            prev_att_vector, hidden, att_prob = self._forward_step(
                prev_embed=prev_embed,
                prev_att_vector=prev_att_vector,
                encoder_output=encoder_output,
                src_mask=src_mask,
                hidden=hidden)
            att_vectors.append(prev_att_vector)
            att_probs.append(att_prob)

        att_vectors = torch.cat(att_vectors, dim=1)
        att_probs = torch.cat(att_probs, dim=1)
        # att_probs: batch, max_len, src_length
        outputs = self.output_layer(att_vectors)
        # outputs: batch, max_len, vocab_size
        return outputs, hidden, att_probs, att_vectors

    def init_hidden(self, encoder_final):
        """
        Returns the initial decoder state,
        conditioned on the final encoder state of the last encoder layer.

        :param encoder_final: final encoder state
            (batch_size, encoder.output_size)
        :return: initial hidden state (hidden state and memory cell for LSTM),
            shape (num_layers, batch_size, hidden_size)
        """
        batch_size = encoder_final.size(0)

        # for multiple layers: the initial state is the same for all layers
        if self.bridge and encoder_final is not None:
            h = torch.tanh(
                self.bridge_layer(encoder_final)).unsqueeze(0).repeat(
                self.num_layers, 1, 1)  # num_layers x batch_size x hidden_size

        else:  # initialize with zeros
            with torch.no_grad():
                h = encoder_final.new_zeros(self.num_layers, batch_size,
                                            self.hidden_size)

        return (h, h) if isinstance(self.rnn, nn.LSTM) else h

    def __repr__(self):
        return "RecurrentDecoder(rnn=%r, attention=%r)" % (
            self.rnn, self.attention)
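A short, hedged usage sketch for the decoder above: it builds the decoder with a stub object standing in for a real Encoder (only its `output_size` attribute is read in the constructor) and unrolls it over a full gold target sequence, as done during training. The stub class, the random tensors and all sizes are illustrative assumptions, not part of joeynmt.

import torch

class EncoderStub:  # hypothetical stand-in for an Encoder; only output_size is used here
    output_size = 12

decoder = RecurrentDecoder(type="gru", emb_size=8, hidden_size=10,
                           encoder=EncoderStub(), attention="luong",
                           num_layers=2, vocab_size=50)

batch_size, src_len, trg_len = 3, 7, 5
encoder_output = torch.rand(batch_size, src_len, EncoderStub.output_size)
encoder_hidden = torch.rand(batch_size, EncoderStub.output_size)
src_mask = torch.ones(batch_size, 1, src_len, dtype=torch.bool)
trg_embed = torch.rand(batch_size, trg_len, 8)  # embedded gold target prefix

outputs, hidden, att_probs, att_vectors = decoder(
    trg_embed, encoder_output, encoder_hidden, src_mask, unroll_steps=trg_len)
# outputs: (3, 5, 50), att_probs: (3, 5, 7), att_vectors: (3, 5, 10)
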
Example 2
class RecurrentDecoder(Decoder):
    """A conditional RNN decoder with attention."""
    def __init__(self,
                 rnn_type: str = "gru",
                 emb_size: int = 0,
                 hidden_size: int = 0,
                 encoder: Encoder = None,
                 attention: str = "bahdanau",
                 num_layers: int = 1,
                 vocab_size: int = 0,
                 dropout: float = 0.,
                 emb_dropout: float = 0.,
                 hidden_dropout: float = 0.,
                 init_hidden: str = "bridge",
                 input_feeding: bool = True,
                 freeze: bool = False,
                 **kwargs) -> None:
        """
        Create a recurrent decoder with attention.

        :param rnn_type: rnn type, valid options: "lstm", "gru"
        :param emb_size: target embedding size
        :param hidden_size: size of the RNN
        :param encoder: encoder connected to this decoder
        :param attention: type of attention, valid options: "bahdanau", "luong"
        :param num_layers: number of recurrent layers
        :param vocab_size: target vocabulary size
        :param dropout: applied between RNN layers
        :param emb_dropout: applied to the RNN input (word embeddings)
        :param hidden_dropout: applied to the input of the attentional layer
        :param init_hidden: If "bridge" (default), the decoder hidden states are
            initialized from a projection of the last encoder state,
            if "zeros" they are initialized with zeros,
            if "last" they are identical to the last encoder state
            (only if they have the same size)
        :param input_feeding: Use Luong's input feeding.
        :param freeze: Freeze the parameters of the decoder during training.
        :param kwargs:
        """

        super().__init__()

        self.emb_dropout = torch.nn.Dropout(p=emb_dropout, inplace=False)
        self.type = rnn_type
        self.hidden_dropout = torch.nn.Dropout(p=hidden_dropout, inplace=False)
        self.hidden_size = hidden_size
        self.emb_size = emb_size

        rnn = nn.GRU if rnn_type == "gru" else nn.LSTM

        self.input_feeding = input_feeding
        if self.input_feeding:  # Luong-style
            # combine embedded prev word + attention vector before feeding to rnn
            self.rnn_input_size = emb_size + hidden_size
        else:
            # just feed prev word embedding
            self.rnn_input_size = emb_size

        # the decoder RNN
        self.rnn = rnn(self.rnn_input_size,
                       hidden_size,
                       num_layers,
                       batch_first=True,
                       dropout=dropout if num_layers > 1 else 0.)

        # combine output with context vector before output layer (Luong-style)
        self.att_vector_layer = nn.Linear(hidden_size + encoder.output_size,
                                          hidden_size,
                                          bias=True)

        self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
        self._output_size = vocab_size

        if attention == "bahdanau":
            self.attention = BahdanauAttention(hidden_size=hidden_size,
                                               key_size=encoder.output_size,
                                               query_size=hidden_size)
        elif attention == "luong":
            self.attention = LuongAttention(hidden_size=hidden_size,
                                            key_size=encoder.output_size)
        else:
            raise ConfigurationError("Unknown attention mechanism: %s. "
                                     "Valid options: 'bahdanau', 'luong'." %
                                     attention)

        self.num_layers = num_layers
        self.hidden_size = hidden_size

        # to initialize from the final encoder state of last layer
        self.init_hidden_option = init_hidden
        if self.init_hidden_option == "bridge":
            self.bridge_layer = nn.Linear(encoder.output_size,
                                          hidden_size,
                                          bias=True)
        elif self.init_hidden_option == "last":
            if encoder.output_size != self.hidden_size:
                if encoder.output_size != 2 * self.hidden_size:  # bidirectional
                    raise ConfigurationError(
                        "For initializing the decoder state with the "
                        "last encoder state, their sizes have to match "
                        "(encoder: {} vs. decoder:  {})".format(
                            encoder.output_size, self.hidden_size))
        if freeze:
            freeze_params(self)

    def _check_shapes_input_forward_step(self, prev_embed: Tensor,
                                         prev_att_vector: Tensor,
                                         encoder_output: Tensor,
                                         src_mask: Tensor,
                                         hidden: Tensor) -> None:
        """
        Make sure the input shapes to `self._forward_step` are correct.
        Same inputs as `self._forward_step`.

        :param prev_embed:
        :param prev_att_vector:
        :param encoder_output:
        :param src_mask:
        :param hidden:
        """
        assert prev_embed.shape[1:] == torch.Size([1, self.emb_size])
        assert prev_att_vector.shape[1:] == torch.Size([1, self.hidden_size])
        assert prev_att_vector.shape[0] == prev_embed.shape[0]
        assert encoder_output.shape[0] == prev_embed.shape[0]
        assert len(encoder_output.shape) == 3
        assert src_mask.shape[0] == prev_embed.shape[0]
        assert src_mask.shape[1] == 1
        assert src_mask.shape[2] == encoder_output.shape[1]
        if isinstance(hidden, tuple):  # for lstm
            hidden = hidden[0]
        assert hidden.shape[0] == self.num_layers
        assert hidden.shape[1] == prev_embed.shape[0]
        assert hidden.shape[2] == self.hidden_size

    def _check_shapes_input_forward(self,
                                    trg_embed: Tensor,
                                    encoder_output: Tensor,
                                    encoder_hidden: Tensor,
                                    src_mask: Tensor,
                                    hidden: Tensor = None,
                                    prev_att_vector: Tensor = None) -> None:
        """
        Make sure that inputs to `self.forward` are of correct shape.
        Same input semantics as for `self.forward`.

        :param trg_embed:
        :param encoder_output:
        :param encoder_hidden:
        :param src_mask:
        :param hidden:
        :param prev_att_vector:
        """
        assert len(encoder_output.shape) == 3
        if encoder_hidden is not None:
            assert len(encoder_hidden.shape) == 2
            assert encoder_hidden.shape[-1] == encoder_output.shape[-1]
        assert src_mask.shape[1] == 1
        assert src_mask.shape[0] == encoder_output.shape[0]
        assert src_mask.shape[2] == encoder_output.shape[1]
        assert trg_embed.shape[0] == encoder_output.shape[0]
        assert trg_embed.shape[2] == self.emb_size
        if hidden is not None:
            if isinstance(hidden, tuple):  # for lstm
                hidden = hidden[0]
            assert hidden.shape[1] == encoder_output.shape[0]
            assert hidden.shape[2] == self.hidden_size
        if prev_att_vector is not None:
            assert prev_att_vector.shape[0] == encoder_output.shape[0]
            assert prev_att_vector.shape[2] == self.hidden_size
            assert prev_att_vector.shape[1] == 1

    def _forward_step(
            self,
            prev_embed: Tensor,
            prev_att_vector: Tensor,  # context or att vector
            encoder_output: Tensor,
            src_mask: Tensor,
            hidden: Tensor) -> (Tensor, Tensor, Tensor):
        """
        Perform a single decoder step (1 token).

        1. `rnn_input`: concat(prev_embed, prev_att_vector [possibly empty])
        2. update RNN with `rnn_input`
        3. calculate attention and context/attention vector

        :param prev_embed: embedded previous token,
            shape (batch_size, 1, embed_size)
        :param prev_att_vector: previous attention vector,
            shape (batch_size, 1, hidden_size)
        :param encoder_output: encoder hidden states for attention context,
            shape (batch_size, src_length, encoder.output_size)
        :param src_mask: src mask, 1s for area before <eos>, 0s elsewhere
            shape (batch_size, 1, src_length)
        :param hidden: previous hidden state,
            shape (num_layers, batch_size, hidden_size)
        :return:
            - att_vector: new attention vector (batch_size, 1, hidden_size),
            - hidden: new hidden state,
                shape (num_layers, batch_size, hidden_size),
            - att_probs: attention probabilities (batch_size, 1, src_len)
        """

        # shape checks
        self._check_shapes_input_forward_step(prev_embed=prev_embed,
                                              prev_att_vector=prev_att_vector,
                                              encoder_output=encoder_output,
                                              src_mask=src_mask,
                                              hidden=hidden)

        if self.input_feeding:
            # concatenate the input with the previous attention vector
            rnn_input = torch.cat([prev_embed, prev_att_vector], dim=2)
        else:
            rnn_input = prev_embed

        rnn_input = self.emb_dropout(rnn_input)

        # rnn_input: batch x 1 x rnn_input_size
        _, hidden = self.rnn(rnn_input, hidden)

        # use new (top) decoder layer as attention query
        if isinstance(hidden, tuple):
            query = hidden[0][-1].unsqueeze(1)
        else:
            query = hidden[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]

        # compute context vector using attention mechanism
        # only use last layer for attention mechanism
        # key projections are pre-computed
        context, att_probs = self.attention(query=query,
                                            values=encoder_output,
                                            mask=src_mask)

        # return attention vector (Luong)
        # combine context with decoder hidden state before prediction
        att_vector_input = torch.cat([query, context], dim=2)
        # batch x 1 x 2*enc_size+hidden_size
        att_vector_input = self.hidden_dropout(att_vector_input)

        att_vector = torch.tanh(self.att_vector_layer(att_vector_input))

        # output: batch x 1 x hidden_size
        return att_vector, hidden, att_probs

    def forward(self,
                trg_embed: Tensor,
                encoder_output: Tensor,
                encoder_hidden: Tensor,
                src_mask: Tensor,
                unroll_steps: int,
                hidden: Tensor = None,
                prev_att_vector: Tensor = None,
                **kwargs) \
            -> (Tensor, Tensor, Tensor, Tensor):
        """
         Unroll the decoder one step at a time for `unroll_steps` steps.
         For every step, the `_forward_step` function is called internally.

         During training, the target inputs (`trg_embed`) are already known for
         the full sequence, so the full unrolling is done.
         In this case, `hidden` and `prev_att_vector` are None.

         For inference, this function is called with one step at a time since
         embedded targets are the predictions from the previous time step.
         In this case, `hidden` and `prev_att_vector` are fed from the output
         of the previous call of this function (from the 2nd step on).

         `src_mask` is needed to mask out the areas of the encoder states that
         should not receive any attention,
         which is everything after the first <eos>.

         The `encoder_output` are the hidden states from the encoder and are
         used as context for the attention.

         The `encoder_hidden` is the last encoder hidden state that is used to
         initialize the first hidden decoder state
         (when `self.init_hidden_option` is "bridge" or "last").

        :param trg_embed: embedded target inputs,
            shape (batch_size, trg_length, embed_size)
        :param encoder_output: hidden states from the encoder,
            shape (batch_size, src_length, encoder.output_size)
        :param encoder_hidden: last state from the encoder,
            shape (batch_size, encoder.output_size)
        :param src_mask: mask for src states: 0s for padded areas,
            1s for the rest, shape (batch_size, 1, src_length)
        :param unroll_steps: number of steps to unroll the decoder RNN
        :param hidden: previous decoder hidden state,
            if not given it's initialized as in `self.init_hidden`,
            shape (batch_size, num_layers, hidden_size)
        :param prev_att_vector: previous attentional vector,
            if not given it's initialized with zeros,
            shape (batch_size, 1, hidden_size)
        :return:
            - outputs: shape (batch_size, unroll_steps, vocab_size),
            - hidden: last hidden state (batch_size, num_layers, hidden_size),
            - att_probs: attention probabilities
                with shape (batch_size, unroll_steps, src_length),
            - att_vectors: attentional vectors
                with shape (batch_size, unroll_steps, hidden_size)
        """
        # initialize decoder hidden state from final encoder hidden state
        if hidden is None and encoder_hidden is not None:
            hidden = self._init_hidden(encoder_hidden)
        else:
            # DataParallel splits batch along the 0th dim.
            # Place back the batch_size to the 1st dim here.
            if isinstance(hidden, tuple):
                h, c = hidden
                hidden = (h.permute(1, 0, 2).contiguous(),
                          c.permute(1, 0, 2).contiguous())
            else:
                hidden = hidden.permute(1, 0, 2).contiguous()
            # shape (num_layers, batch_size, hidden_size)

        # shape checks
        self._check_shapes_input_forward(trg_embed=trg_embed,
                                         encoder_output=encoder_output,
                                         encoder_hidden=encoder_hidden,
                                         src_mask=src_mask,
                                         hidden=hidden,
                                         prev_att_vector=prev_att_vector)

        # pre-compute projected encoder outputs
        # (the "keys" for the attention mechanism)
        # this is only done for efficiency
        if hasattr(self.attention, "compute_proj_keys"):
            self.attention.compute_proj_keys(keys=encoder_output)

        # here we store all intermediate attention vectors (used for prediction)
        att_vectors = []
        att_probs = []

        batch_size = encoder_output.size(0)

        if prev_att_vector is None:
            with torch.no_grad():
                prev_att_vector = encoder_output.new_zeros(
                    [batch_size, 1, self.hidden_size])

        # unroll the decoder RNN for `unroll_steps` steps
        for i in range(unroll_steps):
            prev_embed = trg_embed[:, i].unsqueeze(1)  # batch, 1, emb
            prev_att_vector, hidden, att_prob = self._forward_step(
                prev_embed=prev_embed,
                prev_att_vector=prev_att_vector,
                encoder_output=encoder_output,
                src_mask=src_mask,
                hidden=hidden)
            att_vectors.append(prev_att_vector)
            att_probs.append(att_prob)

        att_vectors = torch.cat(att_vectors, dim=1)
        # att_vectors: batch, unroll_steps, hidden_size
        att_probs = torch.cat(att_probs, dim=1)
        # att_probs: batch, unroll_steps, src_length
        outputs = self.output_layer(att_vectors)
        # outputs: batch, unroll_steps, vocab_size

        # DataParallel gathers batches along the 0th dim.
        # Put batch_size dim to the 0th position.
        if isinstance(hidden, tuple):
            h, c = hidden
            hidden = (h.permute(1, 0, 2).contiguous(),
                      c.permute(1, 0, 2).contiguous())
            assert hidden[0].size(0) == batch_size
        else:
            hidden = hidden.permute(1, 0, 2).contiguous()
            assert hidden.size(0) == batch_size
        # shape (batch_size, num_layers, hidden_size)
        return outputs, hidden, att_probs, att_vectors

    def _init_hidden(self, encoder_final: Tensor = None) \
            -> (Tensor, Optional[Tensor]):
        """
        Returns the initial decoder state,
        conditioned on the final encoder state of the last encoder layer.

        In case of `self.init_hidden_option == "bridge"`
        and a given `encoder_final`, this is a projection of the encoder state.

        In case of `self.init_hidden_option == "last"`
        and a size-matching `encoder_final`, this is set to the encoder state.
        If the encoder is twice as large as the decoder state (e.g. when
        bi-directional), just use the forward hidden state.

        In case of `self.init_hidden_option == "zeros"`, it is initialized with
        zeros.

        For LSTMs we initialize both the hidden state and the memory cell
        with the same projection/copy of the encoder hidden state.

        All decoder layers are initialized with the same initial values.

        :param encoder_final: final state from the last layer of the encoder,
            shape (batch_size, encoder_hidden_size)
        :return: hidden state if GRU, (hidden state, memory cell) if LSTM,
            shape (num_layers, batch_size, hidden_size)
        """
        batch_size = encoder_final.size(0)

        # for multiple layers: the initial state is the same for all layers
        if self.init_hidden_option == "bridge" and encoder_final is not None:
            # num_layers x batch_size x hidden_size
            hidden = torch.tanh(
                self.bridge_layer(encoder_final)).unsqueeze(0).repeat(
                    self.num_layers, 1, 1)
        elif self.init_hidden_option == "last" and encoder_final is not None:
            # special case: encoder is bidirectional: use only forward state
            if encoder_final.shape[1] == 2 * self.hidden_size:  # bidirectional
                encoder_final = encoder_final[:, :self.hidden_size]
            hidden = encoder_final.unsqueeze(0).repeat(self.num_layers, 1, 1)
        else:  # initialize with zeros
            with torch.no_grad():
                hidden = encoder_final.new_zeros(self.num_layers, batch_size,
                                                 self.hidden_size)

        return (hidden, hidden) if isinstance(self.rnn, nn.LSTM) else hidden

    def __repr__(self):
        return "RecurrentDecoder(rnn=%r, attention=%r)" % (self.rnn,
                                                           self.attention)
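A hedged sketch of step-by-step (inference-style) decoding with this second variant: the decoder is called with `unroll_steps=1`, and `hidden` / `prev_att_vector` are fed back from the previous call, as described in the forward() docstring. The embedding table, the <bos> index, the greedy argmax, the stub encoder and all sizes are illustrative assumptions, not joeynmt's actual search code.

import torch
import torch.nn as nn

class EncoderStub:  # hypothetical stand-in for an Encoder; only output_size is read
    output_size = 12

decoder = RecurrentDecoder(rnn_type="gru", emb_size=8, hidden_size=10,
                           encoder=EncoderStub(), attention="luong",
                           num_layers=2, vocab_size=50, init_hidden="bridge")
trg_embedding = nn.Embedding(50, 8)  # hypothetical target embedding table

batch_size, src_len, max_len = 3, 7, 10
encoder_output = torch.rand(batch_size, src_len, EncoderStub.output_size)
encoder_hidden = torch.rand(batch_size, EncoderStub.output_size)
src_mask = torch.ones(batch_size, 1, src_len, dtype=torch.bool)

ys = torch.full((batch_size, 1), 2, dtype=torch.long)  # assume index 2 is <bos>
hidden, prev_att_vector = None, None
for _ in range(max_len):
    trg_embed = trg_embedding(ys[:, -1:])  # embed only the most recent token
    outputs, hidden, att_probs, prev_att_vector = decoder(
        trg_embed=trg_embed,
        encoder_output=encoder_output,
        encoder_hidden=encoder_hidden,
        src_mask=src_mask,
        unroll_steps=1,
        hidden=hidden,
        prev_att_vector=prev_att_vector)
    next_token = outputs[:, -1].argmax(-1, keepdim=True)  # greedy choice
    ys = torch.cat([ys, next_token], dim=1)
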
class TestLuongAttention(TensorTestCase):

    def setUp(self):
        def tensor_equality(x, y, msg):
            # raise (not just construct) the failure exception so that unequal
            # tensors actually fail assertEqual
            if not torch.equal(x, y):
                raise self.failureException(msg)
        self.addTypeEqualityFunc(torch.Tensor, tensor_equality)
        self.key_size = 3
        self.query_size = 5
        self.hidden_size = self.query_size
        seed = 42
        torch.manual_seed(seed)
        self.luong_att = LuongAttention(hidden_size=self.hidden_size,
                                        key_size=self.key_size)

    def test_luong_attention_size(self):
        self.assertIsNone(self.luong_att.key_layer.bias)  # no bias
        self.assertEqual(self.luong_att.key_layer.weight.shape,
                         torch.Size([self.hidden_size, self.key_size]))

    def test_luong_attention_forward(self):
        src_length = 5
        trg_length = 4
        batch_size = 6
        queries = torch.rand(size=(batch_size, trg_length, self.query_size))
        keys = torch.rand(size=(batch_size, src_length, self.key_size))
        mask = torch.ones(size=(batch_size, 1, src_length)).byte()
        # introduce artificial padding areas
        mask[0, 0, -3:] = 0
        mask[1, 0, -2:] = 0
        mask[4, 0, -1:] = 0
        for t in range(trg_length):
            c, att = None, None
            try:
                # should raise an AssertionError (missing pre-computation)
                query = queries[:, t, :].unsqueeze(1)
                c, att = self.luong_att(query=query, mask=mask, values=keys)
            except AssertionError:
                pass
            self.assertIsNone(c)
            self.assertIsNone(att)

        # now with pre-computation
        self.luong_att.compute_proj_keys(keys=keys)
        self.assertIsNotNone(self.luong_att.proj_keys)
        self.assertEqual(self.luong_att.proj_keys.shape,
                         torch.Size([batch_size, src_length, self.hidden_size]))
        contexts = []
        attention_probs = []
        for t in range(trg_length):
            c, att = None, None
            try:
                # should not raise an AssertionError
                query = queries[:, t, :].unsqueeze(1)
                c, att = self.luong_att(query=query, mask=mask, values=keys)
            except AssertionError:
                self.fail()
            self.assertIsNotNone(c)
            self.assertIsNotNone(att)
            contexts.append(c)
            attention_probs.append(att)
        self.assertEqual(len(attention_probs), trg_length)
        self.assertEqual(len(contexts), trg_length)
        contexts = torch.cat(contexts, dim=1)
        attention_probs = torch.cat(attention_probs, dim=1)
        self.assertEqual(contexts.shape,
                         torch.Size([batch_size, trg_length, self.key_size]))
        self.assertEqual(attention_probs.shape,
                         torch.Size([batch_size, trg_length, src_length]))
        context_targets = torch.Tensor(
            [[[0.5347, 0.2918, 0.4707],
              [0.5062, 0.2657, 0.4117],
              [0.4969, 0.2572, 0.3926],
              [0.5320, 0.2893, 0.4651]],

             [[0.5210, 0.6707, 0.4343],
              [0.5111, 0.6809, 0.4274],
              [0.5156, 0.6622, 0.4274],
              [0.5046, 0.6634, 0.4175]],

             [[0.4998, 0.5570, 0.3388],
              [0.4949, 0.5357, 0.3609],
              [0.4982, 0.5208, 0.3468],
              [0.5013, 0.5474, 0.3503]],

             [[0.5911, 0.6944, 0.5319],
              [0.5964, 0.6899, 0.5257],
              [0.6161, 0.6771, 0.5042],
              [0.5937, 0.7011, 0.5330]],

             [[0.4439, 0.5916, 0.3691],
              [0.4409, 0.5970, 0.3762],
              [0.4446, 0.5845, 0.3659],
              [0.4417, 0.6157, 0.3796]],

             [[0.4581, 0.4343, 0.5151],
              [0.4493, 0.4297, 0.5348],
              [0.4399, 0.4265, 0.5419],
              [0.4833, 0.4570, 0.4855]]])
        self.assertTensorAlmostEqual(context_targets, contexts)
        attention_probs_targets = torch.Tensor(
            [[[0.3238, 0.6762, 0.0000, 0.0000, 0.0000],
              [0.4090, 0.5910, 0.0000, 0.0000, 0.0000],
              [0.4367, 0.5633, 0.0000, 0.0000, 0.0000],
              [0.3319, 0.6681, 0.0000, 0.0000, 0.0000]],

             [[0.2483, 0.3291, 0.4226, 0.0000, 0.0000],
              [0.2353, 0.3474, 0.4174, 0.0000, 0.0000],
              [0.2725, 0.3322, 0.3953, 0.0000, 0.0000],
              [0.2803, 0.3476, 0.3721, 0.0000, 0.0000]],

             [[0.1955, 0.1516, 0.2518, 0.1466, 0.2546],
              [0.2220, 0.1613, 0.2402, 0.1462, 0.2303],
              [0.2074, 0.1953, 0.2142, 0.1536, 0.2296],
              [0.2100, 0.1615, 0.2434, 0.1376, 0.2475]],

             [[0.2227, 0.2483, 0.1512, 0.1486, 0.2291],
              [0.2210, 0.2331, 0.1599, 0.1542, 0.2318],
              [0.2123, 0.1808, 0.1885, 0.1702, 0.2482],
              [0.2233, 0.2479, 0.1435, 0.1433, 0.2421]],

             [[0.2475, 0.2482, 0.2865, 0.2178, 0.0000],
              [0.2494, 0.2410, 0.2976, 0.2120, 0.0000],
              [0.2498, 0.2449, 0.2778, 0.2275, 0.0000],
              [0.2359, 0.2603, 0.3174, 0.1864, 0.0000]],

             [[0.2362, 0.1929, 0.2128, 0.1859, 0.1723],
              [0.2230, 0.2118, 0.2116, 0.1890, 0.1646],
              [0.2118, 0.2251, 0.2039, 0.1891, 0.1700],
              [0.2859, 0.1874, 0.2083, 0.1583, 0.1601]]])
        self.assertTensorAlmostEqual(attention_probs_targets, attention_probs)

    def test_luong_precompute_None(self):
        self.assertIsNone(self.luong_att.proj_keys)

    def test_luong_precompute(self):
        src_length = 5
        batch_size = 6
        keys = torch.rand(size=(batch_size, src_length, self.key_size))
        self.luong_att.compute_proj_keys(keys=keys)
        proj_keys_targets = torch.Tensor(
            [[[0.5362, 0.1826, 0.4716, 0.3245, 0.4122],
              [0.3819, 0.0934, 0.2750, 0.2311, 0.2378],
              [0.2246, 0.2934, 0.3999, 0.0519, 0.4430],
              [0.1271, 0.0636, 0.2444, 0.1294, 0.1659],
              [0.3494, 0.0372, 0.1326, 0.1908, 0.1295]],

             [[0.3363, 0.5984, 0.2090, -0.2695, 0.6584],
              [0.3098, 0.3608, 0.3623, 0.0098, 0.5004],
              [0.6133, 0.2568, 0.4264, 0.2688, 0.4716],
              [0.4058, 0.1438, 0.3043, 0.2127, 0.2971],
              [0.6604, 0.3490, 0.5228, 0.2593, 0.5967]],

             [[0.4224, 0.1182, 0.4883, 0.3403, 0.3458],
              [0.4257, 0.3757, -0.1431, -0.2208, 0.3383],
              [0.0681, 0.2540, 0.4165, 0.0269, 0.3934],
              [0.5341, 0.3288, 0.3937, 0.1532, 0.5132],
              [0.6244, 0.1647, 0.2378, 0.2548, 0.3196]],

             [[0.2222, 0.3380, 0.2374, -0.0748, 0.4212],
              [0.4042, 0.1373, 0.3308, 0.2317, 0.3011],
              [0.4740, 0.4829, -0.0853, -0.2634, 0.4623],
              [0.4540, 0.0645, 0.6046, 0.4632, 0.3459],
              [0.4744, 0.5098, -0.2441, -0.3713, 0.4265]],

             [[0.0314, 0.1189, 0.3825, 0.1119, 0.2548],
              [0.7057, 0.2725, 0.2426, 0.1979, 0.4285],
              [0.3967, 0.0223, 0.3664, 0.3488, 0.2107],
              [0.4311, 0.4695, 0.3035, -0.0640, 0.5914],
              [0.0797, 0.1038, 0.3847, 0.1476, 0.2486]],

             [[0.3379, 0.3671, 0.3622, 0.0166, 0.5097],
              [0.4051, 0.4552, -0.0709, -0.2616, 0.4339],
              [0.5379, 0.5037, 0.0074, -0.2046, 0.5243],
              [0.0250, 0.0544, 0.3859, 0.1679, 0.1976],
              [0.1880, 0.2725, 0.1849, -0.0598, 0.3383]]]
        )
        self.assertTensorAlmostEqual(proj_keys_targets, self.luong_att.proj_keys)
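For reference, a minimal sketch of the multiplicative ("general") Luong scoring that this test exercises: keys are projected once (the pre-computation checked above), scores are a dot product between the query and the projected keys, padded positions are masked out before the softmax, and the context is the probability-weighted sum of the values. This is a simplified illustration under those assumptions; joeynmt's actual LuongAttention may differ in details.

import torch
import torch.nn as nn

def luong_attention_step(query, keys, values, mask, key_layer):
    """query: (B, 1, hidden), keys/values: (B, src_len, key_size),
    mask: (B, 1, src_len), key_layer: nn.Linear(key_size, hidden, bias=False)."""
    proj_keys = key_layer(keys)                           # (B, src_len, hidden)
    scores = query @ proj_keys.transpose(1, 2)            # (B, 1, src_len)
    scores = scores.masked_fill(mask == 0, float("-inf"))  # ignore padded positions
    att_probs = torch.softmax(scores, dim=-1)             # rows sum to 1
    context = att_probs @ values                          # (B, 1, key_size)
    return context, att_probs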