def create_attention(config):
    key_size = 2 * config["hidden_size"]
    query_size = config["hidden_size"]
    if config["attention"] == "bahdanau":
        return BahdanauAttention(query_size, key_size, config["hidden_size"])
    elif config["attention"] == "luong":
        return LuongAttention(query_size, key_size)
    else:
        raise ValueError("Unknown attention: {}".format(config["attention"]))
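# A minimal usage sketch of the factory above (added for illustration, not part
# of the original code): only the "hidden_size" and "attention" keys of the
# config are read, and the key size is assumed to be twice the hidden size,
# as for a bidirectional encoder.
example_config = {"hidden_size": 4, "attention": "luong"}
example_attention = create_attention(example_config)  # -> LuongAttention(4, 8)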
def __init__(self,
             rnn_type: str = "gru",
             emb_size: int = 0,
             hidden_size: int = 0,
             encoder: Encoder = None,
             attention: str = "bahdanau",
             num_layers: int = 1,
             vocab_size: int = 0,
             dropout: float = 0.,
             emb_dropout: float = 0.,
             hidden_dropout: float = 0.,
             init_hidden: str = "bridge",
             input_feeding: bool = True,
             freeze: bool = False,
             **kwargs) -> None:
    """
    Create a recurrent decoder with attention.

    :param rnn_type: rnn type, valid options: "lstm", "gru"
    :param emb_size: target embedding size
    :param hidden_size: size of the RNN
    :param encoder: encoder connected to this decoder
    :param attention: type of attention, valid options: "bahdanau", "luong"
    :param num_layers: number of recurrent layers
    :param vocab_size: target vocabulary size
    :param hidden_dropout: Is applied to the input to the attentional layer.
    :param dropout: Is applied between RNN layers.
    :param emb_dropout: Is applied to the RNN input (word embeddings).
    :param init_hidden: If "bridge" (default), the decoder hidden states are
        initialized from a projection of the last encoder state,
        if "zeros" they are initialized with zeros,
        if "last" they are identical to the last encoder state
        (only if they have the same size)
    :param input_feeding: Use Luong's input feeding.
    :param freeze: Freeze the parameters of the decoder during training.
    :param kwargs:
    """
    super().__init__()

    self.emb_dropout = torch.nn.Dropout(p=emb_dropout, inplace=False)
    self.type = rnn_type
    self.hidden_dropout = torch.nn.Dropout(p=hidden_dropout, inplace=False)
    self.hidden_size = hidden_size
    self.emb_size = emb_size

    rnn = nn.GRU if rnn_type == "gru" else nn.LSTM

    self.input_feeding = input_feeding
    if self.input_feeding:  # Luong-style
        # combine embedded prev word + attention vector before feeding to rnn
        self.rnn_input_size = emb_size + hidden_size
    else:
        # just feed prev word embedding
        self.rnn_input_size = emb_size

    # the decoder RNN
    self.rnn = rnn(self.rnn_input_size, hidden_size, num_layers,
                   batch_first=True,
                   dropout=dropout if num_layers > 1 else 0.)

    # combine output with context vector before output layer (Luong-style)
    self.att_vector_layer = nn.Linear(
        hidden_size + encoder.output_size, hidden_size, bias=True)

    self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
    self._output_size = vocab_size

    if attention == "bahdanau":
        self.attention = BahdanauAttention(hidden_size=hidden_size,
                                           key_size=encoder.output_size,
                                           query_size=hidden_size)
    elif attention == "luong":
        self.attention = LuongAttention(hidden_size=hidden_size,
                                        key_size=encoder.output_size)
    else:
        raise ConfigurationError("Unknown attention mechanism: %s. "
                                 "Valid options: 'bahdanau', 'luong'."
                                 % attention)

    self.num_layers = num_layers
    self.hidden_size = hidden_size

    # to initialize from the final encoder state of last layer
    self.init_hidden_option = init_hidden
    if self.init_hidden_option == "bridge":
        self.bridge_layer = nn.Linear(
            encoder.output_size, hidden_size, bias=True)
    elif self.init_hidden_option == "last":
        if encoder.output_size != self.hidden_size:
            if encoder.output_size != 2 * self.hidden_size:  # bidirectional
                raise ConfigurationError(
                    "For initializing the decoder state with the "
                    "last encoder state, their sizes have to match "
                    "(encoder: {} vs. decoder: {})".format(
                        encoder.output_size, self.hidden_size))

    if freeze:
        freeze_params(self)
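# A minimal construction sketch for the decoder above (illustration only,
# assuming the full RecurrentDecoder class these snippets come from):
# "ToyEncoder" is a hypothetical stand-in; the constructor only reads its
# `output_size` attribute. Sizes are arbitrary; output_size = 12 mimics a
# bidirectional encoder with hidden size 6.
class ToyEncoder:
    output_size = 12

toy_decoder = RecurrentDecoder(rnn_type="gru", emb_size=8, hidden_size=6,
                               encoder=ToyEncoder(), attention="luong",
                               num_layers=1, vocab_size=50,
                               init_hidden="bridge", input_feeding=True)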
def __init__(self,
             type: str = "gru",
             emb_size: int = 0,
             hidden_size: int = 0,
             encoder: Encoder = None,
             attention: str = "bahdanau",
             num_layers: int = 0,
             vocab_size: int = 0,
             dropout: float = 0.,
             hidden_dropout: float = 0.,
             bridge: bool = False,
             input_feeding: bool = True,
             freeze: bool = False,
             **kwargs):
    """
    Create a recurrent decoder.
    If `bridge` is True, the decoder hidden states are initialized from a
    projection of the encoder states, else they are initialized with zeros.

    :param type:
    :param emb_size:
    :param hidden_size:
    :param encoder:
    :param attention:
    :param num_layers:
    :param vocab_size:
    :param dropout:
    :param hidden_dropout:
    :param bridge:
    :param input_feeding:
    :param freeze: freeze the parameters of the decoder during training
    :param kwargs:
    """
    super(RecurrentDecoder, self).__init__()

    self.rnn_input_dropout = torch.nn.Dropout(p=dropout, inplace=False)
    self.type = type
    self.hidden_dropout = torch.nn.Dropout(p=hidden_dropout, inplace=False)
    self.hidden_size = hidden_size

    rnn = nn.GRU if type == "gru" else nn.LSTM

    self.input_feeding = input_feeding
    if self.input_feeding:  # Luong-style
        # combine embedded prev word + attention vector before feeding to rnn
        self.rnn_input_size = emb_size + hidden_size
    else:
        # just feed prev word embedding
        self.rnn_input_size = emb_size

    # the decoder RNN
    self.rnn = rnn(self.rnn_input_size, hidden_size, num_layers,
                   batch_first=True,
                   dropout=dropout if num_layers > 1 else 0.)

    # combine output with context vector before output layer (Luong-style)
    self.att_vector_layer = nn.Linear(
        hidden_size + encoder.output_size, hidden_size, bias=True)

    self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
    self.output_size = vocab_size

    if attention == "bahdanau":
        self.attention = BahdanauAttention(hidden_size=hidden_size,
                                           key_size=encoder.output_size,
                                           query_size=hidden_size)
    elif attention == "luong":
        self.attention = LuongAttention(hidden_size=hidden_size,
                                        key_size=encoder.output_size)
    else:
        raise ValueError("Unknown attention mechanism: %s" % attention)

    self.num_layers = num_layers
    self.hidden_size = hidden_size

    # to initialize from the final encoder state of last layer
    self.bridge = bridge
    if self.bridge:
        self.bridge_layer = nn.Linear(
            encoder.output_size, hidden_size, bias=True)

    if freeze:
        freeze_params(self)
class TestBahdanauAttention(TensorTestCase):

    def setUp(self):
        self.key_size = 3
        self.query_size = 5
        self.hidden_size = 7
        seed = 42
        torch.manual_seed(seed)
        self.bahdanau_att = BahdanauAttention(hidden_size=self.hidden_size,
                                              key_size=self.key_size,
                                              query_size=self.query_size)

    def test_bahdanau_attention_size(self):
        self.assertIsNone(self.bahdanau_att.key_layer.bias)  # no bias
        self.assertIsNone(self.bahdanau_att.query_layer.bias)  # no bias
        self.assertEqual(self.bahdanau_att.key_layer.weight.shape,
                         torch.Size([self.hidden_size, self.key_size]))
        self.assertEqual(self.bahdanau_att.query_layer.weight.shape,
                         torch.Size([self.hidden_size, self.query_size]))
        self.assertEqual(self.bahdanau_att.energy_layer.weight.shape,
                         torch.Size([1, self.hidden_size]))
        self.assertIsNone(self.bahdanau_att.energy_layer.bias)

    def test_bahdanau_forward(self):
        src_length = 5
        trg_length = 4
        batch_size = 6
        queries = torch.rand(size=(batch_size, trg_length, self.query_size))
        keys = torch.rand(size=(batch_size, src_length, self.key_size))
        mask = torch.ones(size=(batch_size, 1, src_length)).byte()
        # introduce artificial padding areas
        mask[0, 0, -3:] = 0
        mask[1, 0, -2:] = 0
        mask[4, 0, -1:] = 0

        for t in range(trg_length):
            c, att = None, None
            try:
                # should raise an AssertionError (missing pre-computation)
                query = queries[:, t, :].unsqueeze(1)
                c, att = self.bahdanau_att(query=query, mask=mask, values=keys)
            except AssertionError:
                pass
            self.assertIsNone(c)
            self.assertIsNone(att)

        # now with pre-computation
        self.bahdanau_att.compute_proj_keys(keys=keys)
        self.assertIsNotNone(self.bahdanau_att.proj_keys)
        self.assertEqual(self.bahdanau_att.proj_keys.shape,
                         torch.Size([batch_size, src_length,
                                     self.hidden_size]))

        contexts = []
        attention_probs = []
        for t in range(trg_length):
            c, att = None, None
            try:
                # should not raise an AssertionError
                query = queries[:, t, :].unsqueeze(1)
                c, att = self.bahdanau_att(query=query, mask=mask, values=keys)
            except AssertionError:
                self.fail()
            self.assertIsNotNone(c)
            self.assertIsNotNone(att)
            self.assertEqual(self.bahdanau_att.proj_query.shape,
                             torch.Size([batch_size, 1, self.hidden_size]))
            contexts.append(c)
            attention_probs.append(att)

        self.assertEqual(len(attention_probs), trg_length)
        self.assertEqual(len(contexts), trg_length)

        contexts = torch.cat(contexts, dim=1)
        attention_probs = torch.cat(attention_probs, dim=1)

        self.assertEqual(contexts.shape,
                         torch.Size([batch_size, trg_length, self.key_size]))
        self.assertEqual(attention_probs.shape,
                         torch.Size([batch_size, trg_length, src_length]))

        contexts_target = torch.Tensor(
            [[[0.5080, 0.5832, 0.5614],
              [0.5096, 0.5816, 0.5596],
              [0.5092, 0.5820, 0.5601],
              [0.5079, 0.5833, 0.5615]],
             [[0.4709, 0.5817, 0.3091],
              [0.4720, 0.5793, 0.3063],
              [0.4704, 0.5825, 0.3102],
              [0.4709, 0.5814, 0.3090]],
             [[0.4394, 0.4482, 0.6526],
              [0.4390, 0.4475, 0.6522],
              [0.4391, 0.4479, 0.6538],
              [0.4391, 0.4479, 0.6533]],
             [[0.5283, 0.3441, 0.3938],
              [0.5297, 0.3457, 0.3956],
              [0.5306, 0.3466, 0.3966],
              [0.5274, 0.3431, 0.3926]],
             [[0.4079, 0.4145, 0.2439],
              [0.4064, 0.4156, 0.2445],
              [0.4077, 0.4147, 0.2439],
              [0.4067, 0.4153, 0.2444]],
             [[0.5649, 0.5749, 0.4960],
              [0.5660, 0.5763, 0.4988],
              [0.5658, 0.5754, 0.4984],
              [0.5662, 0.5766, 0.4991]]])
        self.assertTensorAlmostEqual(contexts_target, contexts)

        attention_probs_targets = torch.Tensor(
            [[[0.4904, 0.5096, 0.0000, 0.0000, 0.0000],
              [0.4859, 0.5141, 0.0000, 0.0000, 0.0000],
              [0.4871, 0.5129, 0.0000, 0.0000, 0.0000],
              [0.4906, 0.5094, 0.0000, 0.0000, 0.0000]],
             [[0.3314, 0.3278, 0.3408, 0.0000, 0.0000],
              [0.3337, 0.3230, 0.3433, 0.0000, 0.0000],
              [0.3301, 0.3297, 0.3402, 0.0000, 0.0000],
              [0.3312, 0.3275, 0.3413, 0.0000, 0.0000]],
             [[0.1977, 0.2047, 0.2040, 0.1936, 0.1999],
              [0.1973, 0.2052, 0.2045, 0.1941, 0.1988],
              [0.1987, 0.2046, 0.2046, 0.1924, 0.1996],
              [0.1984, 0.2047, 0.2044, 0.1930, 0.1995]],
             [[0.1963, 0.2041, 0.2006, 0.1942, 0.2047],
              [0.1954, 0.2065, 0.2011, 0.1934, 0.2036],
              [0.1947, 0.2074, 0.2014, 0.1928, 0.2038],
              [0.1968, 0.2028, 0.2006, 0.1949, 0.2049]],
             [[0.2455, 0.2414, 0.2588, 0.2543, 0.0000],
              [0.2450, 0.2447, 0.2566, 0.2538, 0.0000],
              [0.2458, 0.2417, 0.2586, 0.2540, 0.0000],
              [0.2452, 0.2438, 0.2568, 0.2542, 0.0000]],
             [[0.1999, 0.1888, 0.1951, 0.2009, 0.2153],
              [0.2035, 0.1885, 0.1956, 0.1972, 0.2152],
              [0.2025, 0.1885, 0.1950, 0.1980, 0.2159],
              [0.2044, 0.1884, 0.1955, 0.1970, 0.2148]]])
        self.assertTensorAlmostEqual(attention_probs_targets, attention_probs)

    def test_bahdanau_precompute_None(self):
        self.assertIsNone(self.bahdanau_att.proj_keys)
        self.assertIsNone(self.bahdanau_att.proj_query)

    def test_bahdanau_precompute(self):
        src_length = 5
        batch_size = 6
        keys = torch.rand(size=(batch_size, src_length, self.key_size))
        self.bahdanau_att.compute_proj_keys(keys=keys)
        proj_keys_targets = torch.Tensor(
            [[[0.4042, 0.1373, 0.3308, 0.2317, 0.3011, 0.2978, -0.0975],
              [0.4740, 0.4829, -0.0853, -0.2634, 0.4623, 0.0333, -0.2702],
              [0.4540, 0.0645, 0.6046, 0.4632, 0.3459, 0.4631, -0.0919],
              [0.4744, 0.5098, -0.2441, -0.3713, 0.4265, -0.0407, -0.2527],
              [0.0314, 0.1189, 0.3825, 0.1119, 0.2548, 0.1239, -0.1921]],
             [[0.7057, 0.2725, 0.2426, 0.1979, 0.4285, 0.3727, -0.1126],
              [0.3967, 0.0223, 0.3664, 0.3488, 0.2107, 0.3531, -0.0095],
              [0.4311, 0.4695, 0.3035, -0.0640, 0.5914, 0.1713, -0.3695],
              [0.0797, 0.1038, 0.3847, 0.1476, 0.2486, 0.1568, -0.1672],
              [0.3379, 0.3671, 0.3622, 0.0166, 0.5097, 0.1845, -0.3207]],
             [[0.4051, 0.4552, -0.0709, -0.2616, 0.4339, 0.0126, -0.2682],
              [0.5379, 0.5037, 0.0074, -0.2046, 0.5243, 0.0969, -0.2953],
              [0.0250, 0.0544, 0.3859, 0.1679, 0.1976, 0.1471, -0.1392],
              [0.1880, 0.2725, 0.1849, -0.0598, 0.3383, 0.0693, -0.2329],
              [0.0759, 0.1006, 0.0955, -0.0048, 0.1361, 0.0400, -0.0913]],
             [[-0.0207, 0.1266, 0.5529, 0.1728, 0.3192, 0.1611, -0.2560],
              [0.5713, 0.2364, 0.0718, 0.0801, 0.3141, 0.2455, -0.0729],
              [0.1574, 0.1162, 0.3591, 0.1572, 0.2602, 0.1838, -0.1510],
              [0.1357, 0.0192, 0.1817, 0.1391, 0.1037, 0.1389, -0.0277],
              [0.3088, 0.2804, 0.2024, -0.0045, 0.3680, 0.1386, -0.2127]],
             [[0.1181, 0.0899, 0.1139, 0.0329, 0.1390, 0.0744, -0.0758],
              [0.0713, 0.2682, 0.4111, 0.0129, 0.4044, 0.0985, -0.3177],
              [0.5340, 0.1713, 0.5365, 0.3679, 0.4262, 0.4373, -0.1456],
              [0.3902, -0.0242, 0.4498, 0.4313, 0.1997, 0.4012, 0.0075],
              [0.1764, 0.1531, -0.0564, -0.0876, 0.1390, 0.0129, -0.0714]],
             [[0.3772, 0.3725, 0.3053, -0.0012, 0.4982, 0.1808, -0.3006],
              [0.4391, -0.0472, 0.3379, 0.4136, 0.1434, 0.3918, 0.0687],
              [0.3697, 0.2313, 0.4745, 0.2100, 0.4348, 0.3000, -0.2242],
              [0.8427, 0.3705, 0.1227, 0.1079, 0.4890, 0.3604, -0.1305],
              [0.3526, 0.3477, 0.1473, -0.0740, 0.4132, 0.1138, -0.2452]]])
        self.assertTensorAlmostEqual(proj_keys_targets,
                                     self.bahdanau_att.proj_keys)
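# A standalone usage sketch distilled from the tests above (illustration only;
# the shapes and keyword-only call signature are taken from the test code, the
# sizes themselves are arbitrary).
batch_size, src_length, key_size, query_size, hidden_size = 2, 4, 3, 5, 7
att = BahdanauAttention(hidden_size=hidden_size, key_size=key_size,
                        query_size=query_size)
keys = torch.rand(batch_size, src_length, key_size)   # encoder states
query = torch.rand(batch_size, 1, query_size)         # one decoder step
mask = torch.ones(batch_size, 1, src_length).byte()   # 1 = valid position

att.compute_proj_keys(keys=keys)  # required before calling the attention
context, att_probs = att(query=query, mask=mask, values=keys)
# context: (batch_size, 1, key_size), att_probs: (batch_size, 1, src_length)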