class RecurrentDecoder(Decoder):
    """A conditional RNN decoder with attention."""

    def __init__(self,
                 type: str = "gru",
                 emb_size: int = 0,
                 hidden_size: int = 0,
                 encoder: Encoder = None,
                 attention: str = "bahdanau",
                 num_layers: int = 0,
                 vocab_size: int = 0,
                 dropout: float = 0.,
                 hidden_dropout: float = 0.,
                 bridge: bool = False,
                 input_feeding: bool = True,
                 freeze: bool = False,
                 **kwargs):
        """
        Create a recurrent decoder.

        If `bridge` is True, the decoder hidden states are initialized from a
        projection of the encoder states, else they are initialized with zeros.

        :param type: RNN cell type; "gru" selects `nn.GRU`, any other value
            selects `nn.LSTM`
        :param emb_size: size of the target embeddings fed to the RNN
        :param hidden_size: size of the RNN hidden state and of the
            attention vector
        :param encoder: encoder connected to this decoder; only its
            `output_size` attribute is read here
        :param attention: attention mechanism, "bahdanau" or "luong"
        :param num_layers: number of recurrent layers
        :param vocab_size: size of the output layer
        :param dropout: applied to the RNN input and, if `num_layers` > 1,
            between the RNN layers
        :param hidden_dropout: applied to the concatenated RNN state and
            context before the attention vector layer
        :param bridge: project the final encoder state to obtain the initial
            decoder state
        :param input_feeding: concatenate the previous attention vector to
            the embedded previous token (Luong-style input feeding)
        :param freeze: freeze the parameters of the decoder during training
        :param kwargs: accepted for interface compatibility, ignored
        :raises ValueError: if `attention` is not a known mechanism
        """
        # zero-argument super() does not look up the class by name, so it
        # keeps working even if the module-level name gets rebound
        super().__init__()

        self.rnn_input_dropout = torch.nn.Dropout(p=dropout, inplace=False)
        self.type = type
        self.hidden_dropout = torch.nn.Dropout(p=hidden_dropout, inplace=False)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        rnn = nn.GRU if type == "gru" else nn.LSTM

        self.input_feeding = input_feeding
        if self.input_feeding:  # Luong-style
            # combine embedded prev word + attention vector before feeding
            # to the rnn
            self.rnn_input_size = emb_size + hidden_size
        else:
            # just feed prev word embedding
            self.rnn_input_size = emb_size

        # the decoder RNN (nn.GRU / nn.LSTM only accept inter-layer dropout
        # without a warning when there is more than one layer)
        self.rnn = rnn(self.rnn_input_size, hidden_size, num_layers,
                       batch_first=True,
                       dropout=dropout if num_layers > 1 else 0.)

        # combine output with context vector before output layer (Luong-style)
        self.att_vector_layer = nn.Linear(
            hidden_size + encoder.output_size, hidden_size, bias=True)

        self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
        self.output_size = vocab_size

        if attention == "bahdanau":
            self.attention = BahdanauAttention(hidden_size=hidden_size,
                                               key_size=encoder.output_size,
                                               query_size=hidden_size)
        elif attention == "luong":
            self.attention = LuongAttention(hidden_size=hidden_size,
                                            key_size=encoder.output_size)
        else:
            raise ValueError("Unknown attention mechanism: %s" % attention)

        # to initialize from the final encoder state of last layer
        self.bridge = bridge
        if self.bridge:
            self.bridge_layer = nn.Linear(
                encoder.output_size, hidden_size, bias=True)

        if freeze:
            freeze_params(self)

    def _forward_step(self,
                      prev_embed: Tensor = None,
                      prev_att_vector: Tensor = None,  # context or att vector
                      encoder_output: Tensor = None,
                      src_mask: Tensor = None,
                      hidden: Tensor = None):
        """
        Perform a single decoder step (1 word).

        1. rnn input = concat(prev_embed, prev_att_vector) with input
           feeding, else just prev_embed
        2. update the RNN with the rnn input
        3. compute attention and the context / attention vector

        :param prev_embed: embedded previous token, 3-D with the feature
            axis last (concatenated along dim 2)
        :param prev_att_vector: previous attention vector, same rank as
            `prev_embed`; only read when input feeding is enabled
        :param encoder_output: encoder states passed to the attention
            mechanism as `values`
        :param src_mask: mask passed to the attention mechanism
        :param hidden: previous RNN state (a tuple for LSTM)
        :return: attention vector, new RNN state, attention probabilities
        """
        # update rnn hidden state
        if self.input_feeding:
            rnn_input = torch.cat([prev_embed, prev_att_vector], dim=2)
        else:
            rnn_input = prev_embed

        rnn_input = self.rnn_input_dropout(rnn_input)

        # rnn_input: batch x 1 x rnn_input_size
        _, hidden = self.rnn(rnn_input, hidden)

        # use new (top) decoder layer as attention query;
        # for an LSTM the state is (h, c) and only h is used
        top_state = hidden[0] if isinstance(hidden, tuple) else hidden
        query = top_state[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]

        # compute context vector using attention mechanism
        # only use last layer for attention mechanism
        # key projections are pre-computed
        context, att_probs = self.attention(
            query=query, values=encoder_output, mask=src_mask)

        # return attention vector (Luong)
        # combine context with decoder hidden state before prediction
        att_vector_input = torch.cat([query, context], dim=2)
        att_vector_input = self.hidden_dropout(att_vector_input)

        # att_vector_layer maps (hidden_size + encoder.output_size)
        # -> hidden_size
        att_vector = torch.tanh(self.att_vector_layer(att_vector_input))

        # output: batch x 1 x hidden_size
        return att_vector, hidden, att_probs

    def forward(self, trg_embed, encoder_output, encoder_hidden,
                src_mask, unrol_steps, hidden=None, prev_att_vector=None):
        """
        Unroll the decoder one step at a time for `unrol_steps` steps.

        :param trg_embed: embedded target inputs, indexed as
            `trg_embed[:, i]` for step `i`
        :param encoder_output: encoder states used as attention context;
            its first dimension is taken as the batch size
        :param encoder_hidden: final encoder state used to initialize the
            decoder state when `hidden` is not given; if it is None as well,
            the decoder state starts at zeros
        :param src_mask: mask passed on to the attention mechanism
        :param unrol_steps: number of steps to unroll the decoder RNN
        :param hidden: previous decoder state (tuple for LSTM), optional
        :param prev_att_vector: previous attention vector, zeros if None
        :return: outputs (output layer applied to the attention vectors),
            last RNN state, attention probabilities and attention vectors,
            the last two concatenated over the step dimension (dim 1)
        """
        # initialize decoder hidden state from final encoder hidden state
        if hidden is None:
            if encoder_hidden is None:
                # nothing to condition on: start from zeros, taking batch
                # size, dtype and device from the encoder outputs
                hidden = self._zero_state(encoder_output)
            else:
                hidden = self.init_hidden(encoder_hidden)

        # pre-compute projected encoder outputs
        # (the "keys" for the attention mechanism)
        # this is only done for efficiency
        if hasattr(self.attention, "compute_proj_keys"):
            self.attention.compute_proj_keys(encoder_output)

        # here we store all intermediate attention vectors (used for
        # prediction)
        att_vectors = []
        att_probs = []

        batch_size = encoder_output.size(0)

        if prev_att_vector is None:
            with torch.no_grad():
                prev_att_vector = encoder_output.new_zeros(
                    [batch_size, 1, self.hidden_size])

        # unroll the decoder RNN for `unrol_steps` steps
        for i in range(unrol_steps):
            prev_embed = trg_embed[:, i].unsqueeze(1)  # batch, 1, emb
            prev_att_vector, hidden, att_prob = self._forward_step(
                prev_embed=prev_embed,
                prev_att_vector=prev_att_vector,
                encoder_output=encoder_output,
                src_mask=src_mask,
                hidden=hidden)
            att_vectors.append(prev_att_vector)
            att_probs.append(att_prob)

        att_vectors = torch.cat(att_vectors, dim=1)
        att_probs = torch.cat(att_probs, dim=1)
        # att_probs: batch, unrol_steps, src_length
        outputs = self.output_layer(att_vectors)
        # outputs: batch, unrol_steps, vocab_size
        return outputs, hidden, att_probs, att_vectors

    def _zero_state(self, reference):
        """
        Build an all-zero RNN state.

        :param reference: tensor whose first dimension gives the batch size
            and whose dtype / device are inherited by the new state
        :return: zeros of shape (num_layers, batch_size, hidden_size); for an
            LSTM a (hidden, cell) tuple holding that tensor twice
        """
        with torch.no_grad():
            zeros = reference.new_zeros(self.num_layers, reference.size(0),
                                        self.hidden_size)
        return (zeros, zeros) if isinstance(self.rnn, nn.LSTM) else zeros

    def init_hidden(self, encoder_final):
        """
        Returns the initial decoder state, conditioned on the final encoder
        state of the last encoder layer.

        With `self.bridge` the state is `tanh(bridge_layer(encoder_final))`,
        otherwise zeros; either way all layers share the same values, and for
        an LSTM hidden state and memory cell are identical.

        :param encoder_final: final encoder state; its first dimension is
            taken as the batch size
        :return: state of shape (num_layers, batch_size, hidden_size),
            as a (hidden, cell) tuple for LSTM
        :raises ValueError: if `encoder_final` is None
        """
        if encoder_final is None:
            # previously this failed with an opaque AttributeError on
            # `None.size`; `forward` handles the missing encoder state itself
            raise ValueError(
                "encoder_final must not be None; call forward() with "
                "encoder_hidden=None to start from a zero state")

        if not self.bridge:
            # initialize with zeros
            return self._zero_state(encoder_final)

        # for multiple layers: is the same for all layers
        projected = torch.tanh(self.bridge_layer(encoder_final))
        # num_layers x batch_size x hidden_size
        h = projected.unsqueeze(0).repeat(self.num_layers, 1, 1)
        return (h, h) if isinstance(self.rnn, nn.LSTM) else h

    def __repr__(self):
        return "RecurrentDecoder(rnn=%r, attention=%r)" % (
            self.rnn, self.attention)
class RecurrentDecoder(Decoder):
    """A conditional RNN decoder with attention."""

    def __init__(self,
                 rnn_type: str = "gru",
                 emb_size: int = 0,
                 hidden_size: int = 0,
                 encoder: Encoder = None,
                 attention: str = "bahdanau",
                 num_layers: int = 1,
                 vocab_size: int = 0,
                 dropout: float = 0.,
                 emb_dropout: float = 0.,
                 hidden_dropout: float = 0.,
                 init_hidden: str = "bridge",
                 input_feeding: bool = True,
                 freeze: bool = False,
                 **kwargs) -> None:
        """
        Create a recurrent decoder with attention.

        :param rnn_type: rnn type, valid options: "lstm", "gru"
            ("gru" selects `nn.GRU`, any other value selects `nn.LSTM`)
        :param emb_size: target embedding size
        :param hidden_size: size of the RNN
        :param encoder: encoder connected to this decoder; only its
            `output_size` attribute is read here
        :param attention: type of attention, valid options: "bahdanau",
            "luong"
        :param num_layers: number of recurrent layers
        :param vocab_size: target vocabulary size
        :param hidden_dropout: Is applied to the input to the attentional
            layer.
        :param dropout: Is applied between RNN layers.
        :param emb_dropout: Is applied to the RNN input (word embeddings).
        :param init_hidden: If "bridge" (default), the decoder hidden states
            are initialized from a projection of the last encoder state,
            if "last" they are identical to the last encoder state (only if
            they have the same size, or the encoder is twice as large);
            any other value, e.g. "zeros", initializes them with zeros
        :param input_feeding: Use Luong's input feeding.
        :param freeze: Freeze the parameters of the decoder during training.
        :param kwargs: accepted for interface compatibility, ignored
        :raises ConfigurationError: for an unknown attention mechanism, or if
            `init_hidden` is "last" and the encoder size fits neither once
            nor twice the decoder size
        """
        super().__init__()

        self.emb_dropout = torch.nn.Dropout(p=emb_dropout, inplace=False)
        self.type = rnn_type
        self.hidden_dropout = torch.nn.Dropout(p=hidden_dropout, inplace=False)
        self.hidden_size = hidden_size
        self.emb_size = emb_size
        self.num_layers = num_layers

        rnn = nn.GRU if rnn_type == "gru" else nn.LSTM

        self.input_feeding = input_feeding
        if self.input_feeding:  # Luong-style
            # combine embedded prev word + attention vector before feeding
            # to the rnn
            self.rnn_input_size = emb_size + hidden_size
        else:
            # just feed prev word embedding
            self.rnn_input_size = emb_size

        # the decoder RNN
        self.rnn = rnn(self.rnn_input_size, hidden_size, num_layers,
                       batch_first=True,
                       dropout=dropout if num_layers > 1 else 0.)

        # combine output with context vector before output layer (Luong-style)
        self.att_vector_layer = nn.Linear(hidden_size + encoder.output_size,
                                          hidden_size, bias=True)

        self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
        self._output_size = vocab_size

        if attention == "bahdanau":
            self.attention = BahdanauAttention(hidden_size=hidden_size,
                                               key_size=encoder.output_size,
                                               query_size=hidden_size)
        elif attention == "luong":
            self.attention = LuongAttention(hidden_size=hidden_size,
                                            key_size=encoder.output_size)
        else:
            raise ConfigurationError("Unknown attention mechanism: %s. "
                                     "Valid options: 'bahdanau', 'luong'."
                                     % attention)

        # to initialize from the final encoder state of last layer
        self.init_hidden_option = init_hidden
        if self.init_hidden_option == "bridge":
            self.bridge_layer = nn.Linear(encoder.output_size, hidden_size,
                                          bias=True)
        elif self.init_hidden_option == "last":
            # copying is possible for equal sizes, or for an encoder twice
            # as large (bidirectional), where one half is used
            if encoder.output_size not in (self.hidden_size,
                                           2 * self.hidden_size):
                raise ConfigurationError(
                    "For initializing the decoder state with the "
                    "last encoder state, their sizes have to match "
                    "(encoder: {} vs. decoder: {})".format(
                        encoder.output_size, self.hidden_size))

        if freeze:
            freeze_params(self)

    def _check_shapes_input_forward_step(self,
                                         prev_embed: Tensor,
                                         prev_att_vector: Tensor,
                                         encoder_output: Tensor,
                                         src_mask: Tensor,
                                         hidden: Tensor) -> None:
        """
        Make sure the input shapes to `self._forward_step` are correct.
        Same inputs as `self._forward_step`.

        :param prev_embed: expected (batch_size, 1, emb_size)
        :param prev_att_vector: expected (batch_size, 1, hidden_size)
        :param encoder_output: expected 3-D with batch_size first
        :param src_mask: expected (batch_size, 1, encoder_output.shape[1])
        :param hidden: expected (num_layers, batch_size, hidden_size);
            for a tuple (LSTM) only the first element is checked
        :raises AssertionError: if any shape does not match
        """
        batch_size = prev_embed.shape[0]
        assert prev_embed.shape[1:] == torch.Size([1, self.emb_size])
        assert prev_att_vector.shape[1:] == torch.Size([1, self.hidden_size])
        assert prev_att_vector.shape[0] == batch_size
        assert encoder_output.shape[0] == batch_size
        assert len(encoder_output.shape) == 3
        assert src_mask.shape[0] == batch_size
        assert src_mask.shape[1] == 1
        assert src_mask.shape[2] == encoder_output.shape[1]
        if isinstance(hidden, tuple):  # for lstm
            hidden = hidden[0]
        assert hidden.shape[0] == self.num_layers
        assert hidden.shape[1] == batch_size
        assert hidden.shape[2] == self.hidden_size

    def _check_shapes_input_forward(self,
                                    trg_embed: Tensor,
                                    encoder_output: Tensor,
                                    encoder_hidden: Tensor,
                                    src_mask: Tensor,
                                    hidden: Tensor = None,
                                    prev_att_vector: Tensor = None) -> None:
        """
        Make sure that inputs to `self.forward` are of correct shape.
        Same input semantics as for `self.forward`, except that `hidden` is
        checked in its (num_layers, batch_size, hidden_size) layout.

        :param trg_embed: expected (batch_size, *, emb_size)
        :param encoder_output: expected 3-D with batch_size first
        :param encoder_hidden: if given, expected 2-D with the same last
            dimension as `encoder_output`
        :param src_mask: expected (batch_size, 1, encoder_output.shape[1])
        :param hidden: if given, expected (*, batch_size, hidden_size);
            for a tuple (LSTM) only the first element is checked
        :param prev_att_vector: if given, expected
            (batch_size, 1, hidden_size)
        :raises AssertionError: if any shape does not match
        """
        batch_size = encoder_output.shape[0]
        assert len(encoder_output.shape) == 3
        if encoder_hidden is not None:
            assert len(encoder_hidden.shape) == 2
            assert encoder_hidden.shape[-1] == encoder_output.shape[-1]
        assert src_mask.shape[1] == 1
        assert src_mask.shape[0] == batch_size
        assert src_mask.shape[2] == encoder_output.shape[1]
        assert trg_embed.shape[0] == batch_size
        assert trg_embed.shape[2] == self.emb_size
        if hidden is not None:
            if isinstance(hidden, tuple):  # for lstm
                hidden = hidden[0]
            assert hidden.shape[1] == batch_size
            assert hidden.shape[2] == self.hidden_size
        if prev_att_vector is not None:
            assert prev_att_vector.shape[0] == batch_size
            assert prev_att_vector.shape[2] == self.hidden_size
            assert prev_att_vector.shape[1] == 1

    def _forward_step(
            self,
            prev_embed: Tensor,
            prev_att_vector: Tensor,  # context or att vector
            encoder_output: Tensor,
            src_mask: Tensor,
            hidden: Tensor) -> (Tensor, Tensor, Tensor):
        """
        Perform a single decoder step (1 token).

        1. `rnn_input`: concat(prev_embed, prev_att_vector [possibly empty])
        2. update RNN with `rnn_input`
        3. calculate attention and context/attention vector

        :param prev_embed: embedded previous token,
            shape (batch_size, 1, embed_size)
        :param prev_att_vector: previous attention vector,
            shape (batch_size, 1, hidden_size)
        :param encoder_output: encoder hidden states for attention context,
            shape (batch_size, src_length, encoder.output_size)
        :param src_mask: src mask, 1s for area before <eos>, 0s elsewhere
            shape (batch_size, 1, src_length)
        :param hidden: previous hidden state,
            shape (num_layers, batch_size, hidden_size)
        :return:
            - att_vector: new attention vector (batch_size, 1, hidden_size),
            - hidden: new hidden state with shape
              (num_layers, batch_size, hidden_size), a tuple for LSTM,
            - att_probs: attention probabilities (batch_size, 1, src_len)
        """
        # shape checks
        self._check_shapes_input_forward_step(prev_embed=prev_embed,
                                              prev_att_vector=prev_att_vector,
                                              encoder_output=encoder_output,
                                              src_mask=src_mask,
                                              hidden=hidden)

        if self.input_feeding:
            # concatenate the input with the previous attention vector
            rnn_input = torch.cat([prev_embed, prev_att_vector], dim=2)
        else:
            rnn_input = prev_embed

        rnn_input = self.emb_dropout(rnn_input)

        # rnn_input: batch x 1 x rnn_input_size
        _, hidden = self.rnn(rnn_input, hidden)

        # use new (top) decoder layer as attention query;
        # for an LSTM the state is (h, c) and only h is used
        top_state = hidden[0] if isinstance(hidden, tuple) else hidden
        query = top_state[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]

        # compute context vector using attention mechanism
        # only use last layer for attention mechanism
        # key projections are pre-computed
        context, att_probs = self.attention(query=query,
                                            values=encoder_output,
                                            mask=src_mask)

        # return attention vector (Luong)
        # combine context with decoder hidden state before prediction
        att_vector_input = torch.cat([query, context], dim=2)
        # batch x 1 x (hidden_size + encoder.output_size)
        att_vector_input = self.hidden_dropout(att_vector_input)
        att_vector = torch.tanh(self.att_vector_layer(att_vector_input))

        # output: batch x 1 x hidden_size
        return att_vector, hidden, att_probs

    def forward(self,
                trg_embed: Tensor,
                encoder_output: Tensor,
                encoder_hidden: Tensor,
                src_mask: Tensor,
                unroll_steps: int,
                hidden: Tensor = None,
                prev_att_vector: Tensor = None,
                **kwargs) \
            -> (Tensor, Tensor, Tensor, Tensor):
        """
        Unroll the decoder one step at a time for `unroll_steps` steps.
        For every step, the `_forward_step` function is called internally.

        During training, the target inputs (`trg_embed`) are already known
        for the full sequence, so the full unroll is done.
        In this case, `hidden` and `prev_att_vector` are None.

        For inference, this function is called with one step at a time since
        embedded targets are the predictions from the previous time step.
        In this case, `hidden` and `prev_att_vector` are fed from the output
        of the previous call of this function (from the 2nd step on).

        `src_mask` is needed to mask out the areas of the encoder states that
        should not receive any attention,
        which is everything after the first <eos>.

        The `encoder_output` are the hidden states from the encoder and are
        used as context for the attention.

        The `encoder_hidden` is the last encoder hidden state that is used to
        initialize the first hidden decoder state
        (when `self.init_hidden_option` is "bridge" or "last").
        If neither `hidden` nor `encoder_hidden` is given, the decoder state
        starts at zeros.

        :param trg_embed: embedded target inputs,
            shape (batch_size, trg_length, embed_size)
        :param encoder_output: hidden states from the encoder,
            shape (batch_size, src_length, encoder.output_size)
        :param encoder_hidden: last state from the encoder,
            shape (batch_size, encoder.output_size)
        :param src_mask: mask for src states: 0s for padded areas,
            1s for the rest, shape (batch_size, 1, src_length)
        :param unroll_steps: number of steps to unroll the decoder RNN
        :param hidden: previous decoder hidden state,
            if not given it's initialized as in `self._init_hidden`,
            shape (batch_size, num_layers, hidden_size)
        :param prev_att_vector: previous attentional vector,
            if not given it's initialized with zeros,
            shape (batch_size, 1, hidden_size)
        :param kwargs: accepted for interface compatibility, ignored
        :return:
            - outputs: shape (batch_size, unroll_steps, vocab_size),
            - hidden: last hidden state
              (batch_size, num_layers, hidden_size), a tuple for LSTM,
            - att_probs: attention probabilities
              with shape (batch_size, unroll_steps, src_length),
            - att_vectors: attentional vectors
              with shape (batch_size, unroll_steps, hidden_size)
        """
        if hidden is not None:
            # DataParallel splits batch along the 0th dim.
            # Place back the batch_size to the 1st dim here.
            hidden = self._swap_batch_and_layer_dims(hidden)
        elif encoder_hidden is not None:
            # initialize decoder hidden state from final encoder hidden state
            hidden = self._init_hidden(encoder_hidden)
        else:
            # neither a previous state nor an encoder state: start at zeros
            # (this case used to fail on `None.permute`)
            zeros = self._zero_hidden(encoder_output)
            hidden = (zeros, zeros) if isinstance(self.rnn, nn.LSTM) \
                else zeros
        # hidden: shape (num_layers, batch_size, hidden_size)

        # shape checks
        self._check_shapes_input_forward(trg_embed=trg_embed,
                                         encoder_output=encoder_output,
                                         encoder_hidden=encoder_hidden,
                                         src_mask=src_mask,
                                         hidden=hidden,
                                         prev_att_vector=prev_att_vector)

        # pre-compute projected encoder outputs
        # (the "keys" for the attention mechanism)
        # this is only done for efficiency
        if hasattr(self.attention, "compute_proj_keys"):
            self.attention.compute_proj_keys(keys=encoder_output)

        # here we store all intermediate attention vectors (used for
        # prediction)
        att_vectors = []
        att_probs = []

        batch_size = encoder_output.size(0)

        if prev_att_vector is None:
            with torch.no_grad():
                prev_att_vector = encoder_output.new_zeros(
                    [batch_size, 1, self.hidden_size])

        # unroll the decoder RNN for `unroll_steps` steps
        for i in range(unroll_steps):
            prev_embed = trg_embed[:, i].unsqueeze(1)  # batch, 1, emb
            prev_att_vector, hidden, att_prob = self._forward_step(
                prev_embed=prev_embed,
                prev_att_vector=prev_att_vector,
                encoder_output=encoder_output,
                src_mask=src_mask,
                hidden=hidden)
            att_vectors.append(prev_att_vector)
            att_probs.append(att_prob)

        att_vectors = torch.cat(att_vectors, dim=1)
        # att_vectors: batch, unroll_steps, hidden_size
        att_probs = torch.cat(att_probs, dim=1)
        # att_probs: batch, unroll_steps, src_length
        outputs = self.output_layer(att_vectors)
        # outputs: batch, unroll_steps, vocab_size

        # DataParallel gathers batches along the 0th dim.
        # Put batch_size dim to the 0th position.
        hidden = self._swap_batch_and_layer_dims(hidden)
        if isinstance(hidden, tuple):
            assert hidden[0].size(0) == batch_size
        else:
            assert hidden.size(0) == batch_size
        # hidden: shape (batch_size, num_layers, hidden_size)
        return outputs, hidden, att_probs, att_vectors

    @staticmethod
    def _swap_batch_and_layer_dims(hidden):
        """
        Exchange the first two dimensions of an RNN state.

        :param hidden: 3-D state tensor, or an (h, c) tuple of them for LSTM
        :return: contiguous state with dims 0 and 1 swapped, same structure
        """
        if isinstance(hidden, tuple):
            h, c = hidden
            return (h.permute(1, 0, 2).contiguous(),
                    c.permute(1, 0, 2).contiguous())
        return hidden.permute(1, 0, 2).contiguous()

    def _zero_hidden(self, reference: Tensor) -> Tensor:
        """
        Build an all-zero hidden state.

        :param reference: tensor whose first dimension gives the batch size
            and whose dtype / device are inherited by the new state
        :return: zeros of shape (num_layers, batch_size, hidden_size)
        """
        with torch.no_grad():
            return reference.new_zeros(self.num_layers, reference.size(0),
                                       self.hidden_size)

    def _init_hidden(self, encoder_final: Tensor = None) \
            -> (Tensor, Optional[Tensor]):
        """
        Returns the initial decoder state,
        conditioned on the final encoder state of the last encoder layer.

        In case of `self.init_hidden_option == "bridge"`, this is a
        projection of the encoder state.

        In case of `self.init_hidden_option == "last"`, this is set to the
        encoder state. If the encoder is twice as large as the decoder state
        (e.g. when bi-directional), just use the forward hidden state.

        For any other option (e.g. "zeros"), it is initialized with zeros.

        For LSTMs we initialize both the hidden state and the memory cell
        with the same projection/copy of the encoder hidden state.

        All decoder layers are initialized with the same initial values.

        :param encoder_final: final state from the last layer of the encoder,
            shape (batch_size, encoder_hidden_size)
        :return: hidden state if GRU, (hidden state, memory cell) if LSTM,
            shape (num_layers, batch_size, hidden_size)
        :raises ValueError: if `encoder_final` is None
        """
        if encoder_final is None:
            # previously this failed with an opaque AttributeError on
            # `None.size`; `forward` handles the missing encoder state itself
            raise ValueError(
                "encoder_final must not be None; call forward() with "
                "encoder_hidden=None to start from a zero state")

        # for multiple layers: is the same for all layers
        if self.init_hidden_option == "bridge":
            # num_layers x batch_size x hidden_size
            hidden = torch.tanh(
                self.bridge_layer(encoder_final)).unsqueeze(0).repeat(
                    self.num_layers, 1, 1)
        elif self.init_hidden_option == "last":
            # special case: encoder is bidirectional: use only forward state
            if encoder_final.shape[1] == 2 * self.hidden_size:
                encoder_final = encoder_final[:, :self.hidden_size]
            hidden = encoder_final.unsqueeze(0).repeat(self.num_layers, 1, 1)
        else:
            # initialize with zeros
            hidden = self._zero_hidden(encoder_final)

        return (hidden, hidden) if isinstance(self.rnn, nn.LSTM) else hidden

    def __repr__(self):
        return "RecurrentDecoder(rnn=%r, attention=%r)" % (self.rnn,
                                                           self.attention)
class TestLuongAttention(TensorTestCase):
    """Tests for `LuongAttention` against fixed values (seed 42)."""

    def setUp(self):
        """Register exact tensor equality and build a seeded attention."""

        def assert_tensors_equal(first, second, msg=None):
            # a type-equality function has to *raise* the failure exception;
            # returning an exception instance would let every comparison pass
            if not torch.equal(first, second):
                raise self.failureException(msg)

        self.addTypeEqualityFunc(torch.Tensor, assert_tensors_equal)
        self.key_size = 3
        self.query_size = 5
        self.hidden_size = self.query_size
        seed = 42
        # seed before constructing the layer: the targets below depend on
        # the weights and random inputs drawn in this order
        torch.manual_seed(seed)
        self.luong_att = LuongAttention(hidden_size=self.hidden_size,
                                        key_size=self.key_size)

    def test_luong_attention_size(self):
        """The key layer has no bias and maps key_size -> hidden_size."""
        self.assertIsNone(self.luong_att.key_layer.bias)  # no bias
        self.assertEqual(self.luong_att.key_layer.weight.shape,
                         torch.Size([self.hidden_size, self.key_size]))

    def test_luong_attention_forward(self):
        """Contexts and attention probabilities match the fixed targets."""
        src_length = 5
        trg_length = 4
        batch_size = 6
        queries = torch.rand(size=(batch_size, trg_length, self.query_size))
        keys = torch.rand(size=(batch_size, src_length, self.key_size))
        mask = torch.ones(size=(batch_size, 1, src_length)).byte()
        # introduce artificial padding areas
        mask[0, 0, -3:] = 0
        mask[1, 0, -2:] = 0
        mask[4, 0, -1:] = 0

        for t in range(trg_length):
            query = queries[:, t, :].unsqueeze(1)
            # should raise an AssertionError (missing pre-computation)
            with self.assertRaises(AssertionError):
                self.luong_att(query=query, mask=mask, values=keys)

        # now with pre-computation
        self.luong_att.compute_proj_keys(keys=keys)
        self.assertIsNotNone(self.luong_att.proj_keys)
        self.assertEqual(self.luong_att.proj_keys.shape,
                         torch.Size([batch_size, src_length,
                                     self.hidden_size]))
        contexts = []
        attention_probs = []
        for t in range(trg_length):
            c, att = None, None
            try:
                # should not raise an AssertionError
                query = queries[:, t, :].unsqueeze(1)
                c, att = self.luong_att(query=query, mask=mask, values=keys)
            except AssertionError:
                self.fail("attention raised AssertionError although the "
                          "keys were pre-computed")
            self.assertIsNotNone(c)
            self.assertIsNotNone(att)
            contexts.append(c)
            attention_probs.append(att)
        self.assertEqual(len(attention_probs), trg_length)
        self.assertEqual(len(contexts), trg_length)
        contexts = torch.cat(contexts, dim=1)
        attention_probs = torch.cat(attention_probs, dim=1)
        self.assertEqual(contexts.shape,
                         torch.Size([batch_size, trg_length, self.key_size]))
        self.assertEqual(attention_probs.shape,
                         torch.Size([batch_size, trg_length, src_length]))
        context_targets = torch.Tensor(
            [[[0.5347, 0.2918, 0.4707],
              [0.5062, 0.2657, 0.4117],
              [0.4969, 0.2572, 0.3926],
              [0.5320, 0.2893, 0.4651]],

             [[0.5210, 0.6707, 0.4343],
              [0.5111, 0.6809, 0.4274],
              [0.5156, 0.6622, 0.4274],
              [0.5046, 0.6634, 0.4175]],

             [[0.4998, 0.5570, 0.3388],
              [0.4949, 0.5357, 0.3609],
              [0.4982, 0.5208, 0.3468],
              [0.5013, 0.5474, 0.3503]],

             [[0.5911, 0.6944, 0.5319],
              [0.5964, 0.6899, 0.5257],
              [0.6161, 0.6771, 0.5042],
              [0.5937, 0.7011, 0.5330]],

             [[0.4439, 0.5916, 0.3691],
              [0.4409, 0.5970, 0.3762],
              [0.4446, 0.5845, 0.3659],
              [0.4417, 0.6157, 0.3796]],

             [[0.4581, 0.4343, 0.5151],
              [0.4493, 0.4297, 0.5348],
              [0.4399, 0.4265, 0.5419],
              [0.4833, 0.4570, 0.4855]]])
        self.assertTensorAlmostEqual(context_targets, contexts)

        attention_probs_targets = torch.Tensor(
            [[[0.3238, 0.6762, 0.0000, 0.0000, 0.0000],
              [0.4090, 0.5910, 0.0000, 0.0000, 0.0000],
              [0.4367, 0.5633, 0.0000, 0.0000, 0.0000],
              [0.3319, 0.6681, 0.0000, 0.0000, 0.0000]],

             [[0.2483, 0.3291, 0.4226, 0.0000, 0.0000],
              [0.2353, 0.3474, 0.4174, 0.0000, 0.0000],
              [0.2725, 0.3322, 0.3953, 0.0000, 0.0000],
              [0.2803, 0.3476, 0.3721, 0.0000, 0.0000]],

             [[0.1955, 0.1516, 0.2518, 0.1466, 0.2546],
              [0.2220, 0.1613, 0.2402, 0.1462, 0.2303],
              [0.2074, 0.1953, 0.2142, 0.1536, 0.2296],
              [0.2100, 0.1615, 0.2434, 0.1376, 0.2475]],

             [[0.2227, 0.2483, 0.1512, 0.1486, 0.2291],
              [0.2210, 0.2331, 0.1599, 0.1542, 0.2318],
              [0.2123, 0.1808, 0.1885, 0.1702, 0.2482],
              [0.2233, 0.2479, 0.1435, 0.1433, 0.2421]],

             [[0.2475, 0.2482, 0.2865, 0.2178, 0.0000],
              [0.2494, 0.2410, 0.2976, 0.2120, 0.0000],
              [0.2498, 0.2449, 0.2778, 0.2275, 0.0000],
              [0.2359, 0.2603, 0.3174, 0.1864, 0.0000]],

             [[0.2362, 0.1929, 0.2128, 0.1859, 0.1723],
              [0.2230, 0.2118, 0.2116, 0.1890, 0.1646],
              [0.2118, 0.2251, 0.2039, 0.1891, 0.1700],
              [0.2859, 0.1874, 0.2083, 0.1583, 0.1601]]])
        self.assertTensorAlmostEqual(attention_probs_targets, attention_probs)

    def test_luong_precompute_None(self):
        """Before `compute_proj_keys` is called there are no keys."""
        self.assertIsNone(self.luong_att.proj_keys)

    def test_luong_precompute(self):
        """Pre-computed key projections match the fixed targets."""
        src_length = 5
        batch_size = 6
        keys = torch.rand(size=(batch_size, src_length, self.key_size))
        self.luong_att.compute_proj_keys(keys=keys)
        proj_keys_targets = torch.Tensor(
            [[[0.5362, 0.1826, 0.4716, 0.3245, 0.4122],
              [0.3819, 0.0934, 0.2750, 0.2311, 0.2378],
              [0.2246, 0.2934, 0.3999, 0.0519, 0.4430],
              [0.1271, 0.0636, 0.2444, 0.1294, 0.1659],
              [0.3494, 0.0372, 0.1326, 0.1908, 0.1295]],

             [[0.3363, 0.5984, 0.2090, -0.2695, 0.6584],
              [0.3098, 0.3608, 0.3623, 0.0098, 0.5004],
              [0.6133, 0.2568, 0.4264, 0.2688, 0.4716],
              [0.4058, 0.1438, 0.3043, 0.2127, 0.2971],
              [0.6604, 0.3490, 0.5228, 0.2593, 0.5967]],

             [[0.4224, 0.1182, 0.4883, 0.3403, 0.3458],
              [0.4257, 0.3757, -0.1431, -0.2208, 0.3383],
              [0.0681, 0.2540, 0.4165, 0.0269, 0.3934],
              [0.5341, 0.3288, 0.3937, 0.1532, 0.5132],
              [0.6244, 0.1647, 0.2378, 0.2548, 0.3196]],

             [[0.2222, 0.3380, 0.2374, -0.0748, 0.4212],
              [0.4042, 0.1373, 0.3308, 0.2317, 0.3011],
              [0.4740, 0.4829, -0.0853, -0.2634, 0.4623],
              [0.4540, 0.0645, 0.6046, 0.4632, 0.3459],
              [0.4744, 0.5098, -0.2441, -0.3713, 0.4265]],

             [[0.0314, 0.1189, 0.3825, 0.1119, 0.2548],
              [0.7057, 0.2725, 0.2426, 0.1979, 0.4285],
              [0.3967, 0.0223, 0.3664, 0.3488, 0.2107],
              [0.4311, 0.4695, 0.3035, -0.0640, 0.5914],
              [0.0797, 0.1038, 0.3847, 0.1476, 0.2486]],

             [[0.3379, 0.3671, 0.3622, 0.0166, 0.5097],
              [0.4051, 0.4552, -0.0709, -0.2616, 0.4339],
              [0.5379, 0.5037, 0.0074, -0.2046, 0.5243],
              [0.0250, 0.0544, 0.3859, 0.1679, 0.1976],
              [0.1880, 0.2725, 0.1849, -0.0598, 0.3383]]]
        )
        self.assertTensorAlmostEqual(proj_keys_targets,
                                     self.luong_att.proj_keys)