def _build(self, batch_size):
    """Create the fixtures shared by the decoder tests.

    Builds an embedding layer, a TransformerDecoder whose parameters are
    re-initialised uniformly in [-0.5, 0.5], a random encoder output and an
    all-True source mask.

    :param batch_size: first dimension of the encoder output and the mask
    :return: tuple ``(src_mask, emb, decoder, encoder_output,
        encoder_hidden)``; ``encoder_hidden`` is always ``None``
    """
    src_time_dim = 4
    vocab_size = 7

    # Keep the construction order (embeddings, decoder, encoder output,
    # parameter re-init): torch.rand and uniform_ draw from torch's RNG, and
    # the two constructors may do so as well.
    emb = Embeddings(
        embedding_dim=self.emb_size,
        vocab_size=vocab_size,
        padding_idx=self.pad_index,
    )
    decoder = TransformerDecoder(
        num_layers=self.num_layers,
        num_heads=self.num_heads,
        hidden_size=self.hidden_size,
        ff_size=self.ff_size,
        dropout=self.dropout,
        emb_dropout=self.dropout,
        vocab_size=vocab_size,
    )

    encoder_shape = (batch_size, src_time_dim, self.hidden_size)
    encoder_output = torch.rand(size=encoder_shape)

    # Overwrite every decoder parameter with values from U(-0.5, 0.5).
    for param in decoder.parameters():
        torch.nn.init.uniform_(param, -0.5, 0.5)

    # Comparing a tensor of ones with 1 yields a mask that is true everywhere.
    mask_shape = (batch_size, 1, src_time_dim)
    src_mask = torch.ones(size=mask_shape) == 1

    encoder_hidden = None  # returned only as a placeholder

    return src_mask, emb, decoder, encoder_output, encoder_hidden
def test_transformer_decoder_forward(self):
    """Run one decoder forward pass and compare it with stored references.

    A TransformerDecoder with parameters re-initialised from U(-0.5, 0.5)
    is fed random target embeddings and a random encoder output. The
    returned output and states are checked for shape and approximate
    equality against hard-coded tensors, and the argmax over the last
    output dimension must match that of the reference.

    NOTE(review): SOURCE shows further definitions with this same name
    after this one; if they share a class body, the later definition
    replaces this one -- confirm.
    """
    batch_size, src_time_dim, trg_time_dim = 2, 4, 5
    vocab_size = 7

    # The order of the next four steps is kept as-is: torch.rand and
    # uniform_ draw from torch's RNG, so reordering would change the inputs.
    trg_embed = torch.rand(size=(batch_size, trg_time_dim, self.emb_size))
    decoder = TransformerDecoder(
        num_layers=self.num_layers,
        num_heads=self.num_heads,
        hidden_size=self.hidden_size,
        ff_size=self.ff_size,
        dropout=self.dropout,
        emb_dropout=self.dropout,
        vocab_size=vocab_size,
    )
    encoder_output = torch.rand(
        size=(batch_size, src_time_dim, self.hidden_size)
    )
    for param in decoder.parameters():
        torch.nn.init.uniform_(param, -0.5, 0.5)

    # Masks that are true at every position.
    src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
    trg_mask = torch.ones(size=(batch_size, trg_time_dim, 1)) == 1

    # Placeholders passed positionally to the decoder call below.
    encoder_hidden = None
    decoder_hidden = None
    unroll_steps = None

    dec_output, dec_states, _, _ = decoder(
        trg_embed,
        encoder_output,
        encoder_hidden,
        src_mask,
        unroll_steps,
        decoder_hidden,
        trg_mask,
    )

    expected_output = torch.Tensor([
        [
            [0.1946, 0.6144, -0.1925, -0.6967, 0.4466, -0.1085, 0.3400],
            [0.1857, 0.5558, -0.1314, -0.7783, 0.3980, -0.1736, 0.2347],
            [-0.0216, 0.3663, -0.2251, -0.5800, 0.2996, 0.0918, 0.2833],
            [0.0389, 0.4843, -0.1914, -0.6326, 0.3674, -0.0903, 0.2524],
            [0.0373, 0.3276, -0.2835, -0.6210, 0.2297, -0.0367, 0.1962],
        ],
        [
            [0.0241, 0.4255, -0.2074, -0.6517, 0.3380, -0.0312, 0.2392],
            [0.1577, 0.4292, -0.1792, -0.7406, 0.2696, -0.1610, 0.2233],
            [0.0122, 0.4203, -0.2302, -0.6640, 0.2843, -0.0710, 0.2984],
            [0.0115, 0.3416, -0.2007, -0.6255, 0.2708, -0.0251, 0.2113],
            [0.0094, 0.4787, -0.1730, -0.6124, 0.4650, -0.0382, 0.1910],
        ],
    ])
    self.assertEqual(expected_output.shape, dec_output.shape)
    self.assertTensorAlmostEqual(expected_output, dec_output)

    # The argmax over the last dimension must agree with the reference.
    actual_argmax = dec_output.argmax(-1)
    expected_argmax = expected_output.argmax(-1)
    self.assertTensorEqual(expected_argmax, actual_argmax)

    expected_states = torch.Tensor([
        [
            [0.0491, 0.5322, 0.0327, -0.9208, -0.5646, -0.1138,
             0.3416, -0.3235, 0.0350, -0.4339, 0.5837, 0.1022],
            [0.1838, 0.4832, -0.0498, -0.7803, -0.5348, -0.1162,
             0.3667, -0.3076, -0.0842, -0.4287, 0.6334, 0.1872],
            [0.0910, 0.3801, 0.0451, -0.7478, -0.4655, -0.1040,
             0.6660, -0.2871, 0.0544, -0.4561, 0.5823, 0.1653],
            [0.1064, 0.3970, -0.0691, -0.5924, -0.4410, -0.0984,
             0.2759, -0.3108, -0.0127, -0.4857, 0.6074, 0.0979],
            [0.0424, 0.3607, -0.0287, -0.5379, -0.4454, -0.0892,
             0.4730, -0.3021, -0.1303, -0.4889, 0.5257, 0.1394],
        ],
        [
            [0.1459, 0.4663, 0.0316, -0.7014, -0.4267, -0.0985,
             0.5141, -0.2743, -0.0897, -0.4771, 0.5795, 0.1014],
            [0.2450, 0.4507, 0.0958, -0.6684, -0.4726, -0.0926,
             0.4593, -0.2969, -0.1612, -0.4224, 0.6054, 0.1698],
            [0.2137, 0.4132, 0.0327, -0.5304, -0.4519, -0.0934,
             0.3898, -0.2846, -0.0077, -0.4928, 0.6087, 0.1249],
            [0.1752, 0.3687, 0.0479, -0.5960, -0.4000, -0.0952,
             0.5159, -0.2926, -0.0668, -0.4628, 0.6031, 0.1711],
            [0.0396, 0.4577, -0.0789, -0.7109, -0.4049, -0.0989,
             0.3596, -0.2966, 0.0044, -0.4571, 0.6315, 0.1103],
        ],
    ])
    self.assertEqual(expected_states.shape, dec_states.shape)
    self.assertTensorAlmostEqual(expected_states, dec_states)
def test_transformer_decoder_forward(self):
    """Check a decoder forward pass against hard-coded reference tensors.

    Random target embeddings and a random encoder output are passed through
    a TransformerDecoder whose parameters were re-initialised from
    U(-0.5, 0.5). Output and states must match the stored references in
    shape and (approximately) in value; the argmax over the last output
    dimension must match exactly.

    NOTE(review): SOURCE shows another definition with this same name
    after this one; if both share a class body, the later definition
    replaces this one -- confirm.
    """
    batch_size, src_time_dim, trg_time_dim = 2, 4, 5
    vocab_size = 7

    # Do not reorder the next four steps: torch.rand and uniform_ draw from
    # torch's RNG, so the sampled inputs depend on the call order.
    trg_embed = torch.rand(size=(batch_size, trg_time_dim, self.emb_size))
    decoder = TransformerDecoder(
        num_layers=self.num_layers,
        num_heads=self.num_heads,
        hidden_size=self.hidden_size,
        ff_size=self.ff_size,
        dropout=self.dropout,
        emb_dropout=self.dropout,
        vocab_size=vocab_size,
    )
    encoder_output = torch.rand(
        size=(batch_size, src_time_dim, self.hidden_size)
    )
    for param in decoder.parameters():
        torch.nn.init.uniform_(param, -0.5, 0.5)

    # Masks that are true at every position.
    src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
    trg_mask = torch.ones(size=(batch_size, trg_time_dim, 1)) == 1

    # Placeholders passed positionally to the decoder call below.
    encoder_hidden = None
    decoder_hidden = None
    unroll_steps = None

    dec_output, dec_states, _, _ = decoder(
        trg_embed,
        encoder_output,
        encoder_hidden,
        src_mask,
        unroll_steps,
        decoder_hidden,
        trg_mask,
    )

    expected_output = torch.Tensor([
        [
            [0.1718, 0.5595, -0.1996, -0.6924, 0.4351, -0.0850, 0.2805],
            [0.0666, 0.4923, -0.1724, -0.6804, 0.3983, -0.1111, 0.2194],
            [-0.0315, 0.3673, -0.2320, -0.6100, 0.3019, 0.0422, 0.2514],
            [-0.0026, 0.3807, -0.2195, -0.6010, 0.3081, -0.0101, 0.2099],
            [-0.0172, 0.3384, -0.2853, -0.5799, 0.2470, 0.0312, 0.2518],
        ],
        [
            [0.0284, 0.3918, -0.2010, -0.6472, 0.3646, -0.0296, 0.1791],
            [0.1017, 0.4387, -0.2031, -0.7084, 0.3051, -0.1354, 0.2511],
            [0.0155, 0.4274, -0.2061, -0.6702, 0.3085, -0.0617, 0.2830],
            [0.0227, 0.4067, -0.1697, -0.6463, 0.3277, -0.0423, 0.2333],
            [0.0133, 0.4409, -0.1186, -0.5694, 0.4450, 0.0290, 0.1643],
        ],
    ])
    self.assertEqual(expected_output.shape, dec_output.shape)
    self.assertTensorAlmostEqual(expected_output, dec_output)

    # The argmax over the last dimension must agree with the reference.
    actual_argmax = dec_output.argmax(-1)
    expected_argmax = expected_output.argmax(-1)
    self.assertTensorEqual(expected_argmax, actual_argmax)

    expected_states = torch.Tensor([
        [
            [3.7535e-02, 5.3508e-01, 4.9478e-02, -9.1961e-01,
             -5.3966e-01, -1.0065e-01, 4.3053e-01, -3.0671e-01,
             -1.2724e-02, -4.1879e-01, 5.9625e-01, 1.1887e-01],
            [1.3837e-01, 4.6963e-01, -3.7059e-02, -6.8479e-01,
             -4.6042e-01, -1.0072e-01, 3.9374e-01, -3.0429e-01,
             -5.4203e-02, -4.3680e-01, 6.4257e-01, 1.1424e-01],
            [1.0263e-01, 3.8331e-01, -2.5586e-02, -6.4478e-01,
             -4.5860e-01, -1.0590e-01, 5.8806e-01, -2.8856e-01,
             1.1084e-02, -4.7479e-01, 5.9094e-01, 1.6089e-01],
            [7.3408e-02, 3.7701e-01, -5.8783e-02, -6.2368e-01,
             -4.4201e-01, -1.0237e-01, 5.2556e-01, -3.0821e-01,
             -5.3345e-02, -4.5606e-01, 5.8259e-01, 1.2531e-01],
            [4.1206e-02, 3.6129e-01, -1.2955e-02, -5.8638e-01,
             -4.6023e-01, -9.4267e-02, 5.5464e-01, -3.0029e-01,
             -3.3974e-02, -4.8347e-01, 5.4088e-01, 1.2015e-01],
        ],
        [
            [1.1017e-01, 4.7179e-01, 2.6402e-02, -7.2170e-01,
             -3.9778e-01, -1.0226e-01, 5.3498e-01, -2.8369e-01,
             -1.1081e-01, -4.6096e-01, 5.9517e-01, 1.3531e-01],
            [2.1947e-01, 4.6407e-01, 8.4276e-02, -6.3263e-01,
             -4.4953e-01, -9.7334e-02, 4.0321e-01, -2.9893e-01,
             -1.0368e-01, -4.5760e-01, 6.1378e-01, 1.3509e-01],
            [2.1437e-01, 4.1372e-01, 1.9859e-02, -5.7415e-01,
             -4.5025e-01, -9.8621e-02, 4.1182e-01, -2.8410e-01,
             -1.2729e-03, -4.8586e-01, 6.2318e-01, 1.4731e-01],
            [1.9153e-01, 3.8401e-01, 2.6096e-02, -6.2339e-01,
             -4.0685e-01, -9.7387e-02, 4.1836e-01, -2.8648e-01,
             -1.7857e-02, -4.7678e-01, 6.2907e-01, 1.7617e-01],
            [3.1713e-02, 3.7548e-01, -6.3005e-02, -7.9804e-01,
             -3.6541e-01, -1.0398e-01, 4.2991e-01, -2.9607e-01,
             2.1376e-04, -4.5897e-01, 6.1062e-01, 1.6142e-01],
        ],
    ])
    self.assertEqual(expected_states.shape, dec_states.shape)
    self.assertTensorAlmostEqual(expected_states, dec_states)
def test_transformer_decoder_forward(self):
    """Check a seeded decoder forward pass against stored reference tensors.

    After seeding torch with ``self.seed``, random target embeddings and a
    random encoder output are passed through a TransformerDecoder whose
    parameters were re-initialised from U(-0.5, 0.5). Output and states must
    match the hard-coded references in shape and (approximately) in value;
    the argmax over the last output dimension must match exactly.
    """
    torch.manual_seed(self.seed)

    batch_size, src_time_dim, trg_time_dim = 2, 4, 5
    vocab_size = 7

    # Do not reorder the next four steps: torch.rand and uniform_ draw from
    # the RNG seeded above, so the sampled inputs depend on the call order.
    trg_embed = torch.rand(size=(batch_size, trg_time_dim, self.emb_size))
    decoder = TransformerDecoder(
        num_layers=self.num_layers,
        num_heads=self.num_heads,
        hidden_size=self.hidden_size,
        ff_size=self.ff_size,
        dropout=self.dropout,
        emb_dropout=self.dropout,
        vocab_size=vocab_size,
    )
    encoder_output = torch.rand(
        size=(batch_size, src_time_dim, self.hidden_size)
    )
    for param in decoder.parameters():
        torch.nn.init.uniform_(param, -0.5, 0.5)

    # Masks that are true at every position.
    src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
    trg_mask = torch.ones(size=(batch_size, trg_time_dim, 1)) == 1

    dec_output, dec_states, _, _ = decoder(
        trg_embed, encoder_output, src_mask, trg_mask
    )

    expected_output = torch.Tensor([
        [
            [0.1765, 0.4578, 0.2345, -0.5303, 0.3862, 0.0964, 0.6882],
            [0.3363, 0.3907, 0.2210, -0.5414, 0.3770, 0.0748, 0.7344],
            [0.3275, 0.3729, 0.2797, -0.3519, 0.3341, 0.1605, 0.5403],
            [0.3081, 0.4513, 0.1900, -0.3443, 0.3072, 0.0570, 0.6652],
            [0.3253, 0.4315, 0.1227, -0.3371, 0.3339, 0.1129, 0.6331],
        ],
        [
            [0.3235, 0.4836, 0.2337, -0.4019, 0.2831, -0.0260, 0.7013],
            [0.2800, 0.5662, 0.0469, -0.4156, 0.4246, -0.1121, 0.8110],
            [0.2968, 0.4777, 0.0652, -0.2706, 0.3146, 0.0732, 0.5362],
            [0.3108, 0.4910, 0.0774, -0.2341, 0.2873, 0.0404, 0.5909],
            [0.2338, 0.4371, 0.1350, -0.1292, 0.0673, 0.1034, 0.5356],
        ],
    ])
    self.assertEqual(expected_output.shape, dec_output.shape)
    self.assertTensorAlmostEqual(expected_output, dec_output)

    # The argmax over the last dimension must agree with the reference.
    actual_argmax = dec_output.argmax(-1)
    expected_argmax = expected_output.argmax(-1)
    self.assertTensorEqual(expected_argmax, actual_argmax)

    expected_states = torch.Tensor([
        [
            [8.3742e-01, -1.3161e-01, 2.1876e-01, -1.3920e-01,
             -9.1572e-01, 2.3006e-01, 3.8328e-01, -1.6271e-01,
             3.7370e-01, -1.2110e-01, -4.7549e-01, -4.0622e-01],
            [8.3609e-01, -2.9161e-02, 2.0583e-01, -1.3571e-01,
             -8.0510e-01, 2.7630e-01, 4.8219e-01, -1.8863e-01,
             1.1977e-01, -2.0179e-01, -4.4314e-01, -4.1228e-01],
            [8.5478e-01, 1.1368e-01, 2.0400e-01, -1.3059e-01,
             -8.1042e-01, 1.6369e-01, 5.4244e-01, -2.9103e-01,
             3.9919e-01, -3.3826e-01, -4.5423e-01, -4.2516e-01],
            [9.0388e-01, 1.1853e-01, 1.9927e-01, -1.1675e-01,
             -7.7208e-01, 2.0686e-01, 4.6024e-01, -9.1610e-02,
             3.9778e-01, -2.6214e-01, -4.7688e-01, -4.0807e-01],
            [8.9476e-01, 1.3646e-01, 2.0298e-01, -1.0910e-01,
             -8.2137e-01, 2.8025e-01, 4.2538e-01, -1.1852e-01,
             4.1497e-01, -3.7422e-01, -4.9212e-01, -3.9790e-01],
        ],
        [
            [8.8745e-01, -2.5798e-02, 2.1483e-01, -1.8219e-01,
             -6.4821e-01, 2.6279e-01, 3.9598e-01, -1.0423e-01,
             3.0726e-01, -1.1315e-01, -4.7201e-01, -3.6979e-01],
            [7.5528e-01, 6.8919e-02, 2.2486e-01, -1.6395e-01,
             -7.9692e-01, 3.7830e-01, 4.9367e-01, 2.4355e-02,
             2.6674e-01, -1.1740e-01, -4.4945e-01, -3.6367e-01],
            [8.3467e-01, 1.7779e-01, 1.9504e-01, -1.6034e-01,
             -8.2783e-01, 3.2627e-01, 5.0045e-01, -1.0181e-01,
             4.4797e-01, -4.8046e-01, -3.7264e-01, -3.7392e-01],
            [8.4359e-01, 2.2699e-01, 1.9721e-01, -1.5768e-01,
             -7.5897e-01, 3.3738e-01, 4.5559e-01, -1.0258e-01,
             4.5782e-01, -3.8058e-01, -3.9275e-01, -3.8412e-01],
            [9.6349e-01, 1.6264e-01, 1.8207e-01, -1.6910e-01,
             -5.9304e-01, 1.4468e-01, 2.4968e-01, 6.4794e-04,
             5.4930e-01, -3.8420e-01, -4.2137e-01, -3.8016e-01],
        ],
    ])
    self.assertEqual(expected_states.shape, dec_states.shape)
    self.assertTensorAlmostEqual(expected_states, dec_states)