示例#1
0
    def _build(self, batch_size):
        """Create the objects shared by decoder tests.

        Args:
            batch_size: number of sequences in the fake encoder batch.

        Returns:
            A 5-tuple ``(src_mask, emb, decoder, encoder_output,
            encoder_hidden)``: an all-True boolean source mask of shape
            ``(batch_size, 1, 4)``, an ``Embeddings`` module, a
            ``TransformerDecoder`` whose parameters were re-sampled from
            U(-0.5, 0.5), a random encoder output of shape
            ``(batch_size, 4, self.hidden_size)``, and ``None`` in the
            encoder-hidden slot (not used by the transformer decoder).
        """
        vocab_size = 7
        src_len = 4

        # NOTE: keep the statement order below unchanged. torch.rand and
        # uniform_ draw from the global RNG, and module construction
        # presumably does too (parameter init) -- TODO confirm. Any golden
        # values derived from these fixtures rely on this exact sequence.
        embedding = Embeddings(
            embedding_dim=self.emb_size,
            vocab_size=vocab_size,
            padding_idx=self.pad_index,
        )

        decoder_kwargs = dict(
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            hidden_size=self.hidden_size,
            ff_size=self.ff_size,
            dropout=self.dropout,
            emb_dropout=self.dropout,
            vocab_size=vocab_size,
        )
        decoder = TransformerDecoder(**decoder_kwargs)

        encoder_output = torch.rand(size=(batch_size, src_len,
                                          self.hidden_size))

        # Overwrite every decoder weight with uniform samples.
        for param in decoder.parameters():
            torch.nn.init.uniform_(param, -0.5, 0.5)

        # Boolean mask with every entry True (float ones compared to 1).
        src_mask = torch.ones(size=(batch_size, 1, src_len)).eq(1)

        return src_mask, embedding, decoder, encoder_output, None
示例#2
0
    def test_transformer_decoder_forward(self):
        """Golden-value regression test for ``TransformerDecoder`` forward.

        Runs the decoder on random target embeddings (batch 2, target length
        5) and random encoder outputs (source length 4), with every decoder
        parameter re-sampled from U(-0.5, 0.5). It then compares three things
        against hard-coded expected values:

        * the returned ``output`` tensor;
        * its greedy (argmax) predictions;
        * the returned ``states`` tensor.

        The expected values are only reproducible for an identical sequence
        of random draws. Do not reorder, add or remove the ``torch.rand``
        calls, the decoder construction or the parameter initialisation.
        This method does not seed the RNG itself. Presumably the fixture
        (e.g. ``setUp``) does; TODO confirm, otherwise the result depends on
        test execution order.
        """
        batch_size = 2
        src_time_dim = 4
        trg_time_dim = 5
        vocab_size = 7

        # First RNG draw: random target-side embeddings.
        trg_embed = torch.rand(size=(batch_size, trg_time_dim, self.emb_size))

        # Construction presumably consumes RNG too (default parameter init);
        # keep it between the two torch.rand calls. TODO confirm.
        decoder = TransformerDecoder(num_layers=self.num_layers,
                                     num_heads=self.num_heads,
                                     hidden_size=self.hidden_size,
                                     ff_size=self.ff_size,
                                     dropout=self.dropout,
                                     emb_dropout=self.dropout,
                                     vocab_size=vocab_size)

        encoder_output = torch.rand(size=(batch_size, src_time_dim,
                                          self.hidden_size))

        # Re-sample all decoder parameters uniformly. Presumably this is so
        # the expected values do not hinge on the default init scheme.
        for p in decoder.parameters():
            torch.nn.init.uniform_(p, -0.5, 0.5)

        # All-True boolean masks (``== 1`` turns the float ones into bools).
        # src_mask is (batch, 1, src_len); trg_mask is (batch, trg_len, 1).
        src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
        trg_mask = torch.ones(size=(batch_size, trg_time_dim, 1)) == 1

        # Placeholders passed positionally only to fill the decoder's call
        # signature in this version of the API.
        encoder_hidden = None  # unused
        decoder_hidden = None  # unused
        unrol_steps = None  # unused

        output, states, _, _ = decoder(trg_embed, encoder_output,
                                       encoder_hidden, src_mask, unrol_steps,
                                       decoder_hidden, trg_mask)

        # Expected ``output``: shape (2, 5, 7). The last dim equals
        # vocab_size, so these are presumably per-position vocabulary scores.
        output_target = torch.Tensor(
            [[[0.1946, 0.6144, -0.1925, -0.6967, 0.4466, -0.1085, 0.3400],
              [0.1857, 0.5558, -0.1314, -0.7783, 0.3980, -0.1736, 0.2347],
              [-0.0216, 0.3663, -0.2251, -0.5800, 0.2996, 0.0918, 0.2833],
              [0.0389, 0.4843, -0.1914, -0.6326, 0.3674, -0.0903, 0.2524],
              [0.0373, 0.3276, -0.2835, -0.6210, 0.2297, -0.0367, 0.1962]],
             [[0.0241, 0.4255, -0.2074, -0.6517, 0.3380, -0.0312, 0.2392],
              [0.1577, 0.4292, -0.1792, -0.7406, 0.2696, -0.1610, 0.2233],
              [0.0122, 0.4203, -0.2302, -0.6640, 0.2843, -0.0710, 0.2984],
              [0.0115, 0.3416, -0.2007, -0.6255, 0.2708, -0.0251, 0.2113],
              [0.0094, 0.4787, -0.1730, -0.6124, 0.4650, -0.0382, 0.1910]]])
        self.assertEqual(output_target.shape, output.shape)
        self.assertTensorAlmostEqual(output_target, output)

        # Also require the greedy (argmax) choices to agree exactly.
        greedy_predictions = output.argmax(-1)
        expect_predictions = output_target.argmax(-1)
        self.assertTensorEqual(expect_predictions, greedy_predictions)

        # Expected ``states``: shape (2, 5, 12). The last dim is presumably
        # self.hidden_size; TODO confirm against the test fixture.
        states_target = torch.Tensor([
            [[
                0.0491, 0.5322, 0.0327, -0.9208, -0.5646, -0.1138, 0.3416,
                -0.3235, 0.0350, -0.4339, 0.5837, 0.1022
            ],
             [
                 0.1838, 0.4832, -0.0498, -0.7803, -0.5348, -0.1162, 0.3667,
                 -0.3076, -0.0842, -0.4287, 0.6334, 0.1872
             ],
             [
                 0.0910, 0.3801, 0.0451, -0.7478, -0.4655, -0.1040, 0.6660,
                 -0.2871, 0.0544, -0.4561, 0.5823, 0.1653
             ],
             [
                 0.1064, 0.3970, -0.0691, -0.5924, -0.4410, -0.0984, 0.2759,
                 -0.3108, -0.0127, -0.4857, 0.6074, 0.0979
             ],
             [
                 0.0424, 0.3607, -0.0287, -0.5379, -0.4454, -0.0892, 0.4730,
                 -0.3021, -0.1303, -0.4889, 0.5257, 0.1394
             ]],
            [[
                0.1459, 0.4663, 0.0316, -0.7014, -0.4267, -0.0985, 0.5141,
                -0.2743, -0.0897, -0.4771, 0.5795, 0.1014
            ],
             [
                 0.2450, 0.4507, 0.0958, -0.6684, -0.4726, -0.0926, 0.4593,
                 -0.2969, -0.1612, -0.4224, 0.6054, 0.1698
             ],
             [
                 0.2137, 0.4132, 0.0327, -0.5304, -0.4519, -0.0934, 0.3898,
                 -0.2846, -0.0077, -0.4928, 0.6087, 0.1249
             ],
             [
                 0.1752, 0.3687, 0.0479, -0.5960, -0.4000, -0.0952, 0.5159,
                 -0.2926, -0.0668, -0.4628, 0.6031, 0.1711
             ],
             [
                 0.0396, 0.4577, -0.0789, -0.7109, -0.4049, -0.0989, 0.3596,
                 -0.2966, 0.0044, -0.4571, 0.6315, 0.1103
             ]]
        ])

        self.assertEqual(states_target.shape, states.shape)
        self.assertTensorAlmostEqual(states_target, states)
示例#3
0
    def test_transformer_decoder_forward(self):
        """Golden-value regression test for ``TransformerDecoder`` forward.

        Runs the decoder on random target embeddings (batch 2, target length
        5) and random encoder outputs (source length 4), with every decoder
        parameter re-sampled from U(-0.5, 0.5). It then compares three things
        against hard-coded expected values:

        * the returned ``output`` tensor;
        * its greedy (argmax) predictions;
        * the returned ``states`` tensor.

        The expected values are only reproducible for an identical sequence
        of random draws. Do not reorder, add or remove the ``torch.rand``
        calls, the decoder construction or the parameter initialisation.
        This method does not seed the RNG itself. Presumably the fixture
        (e.g. ``setUp``) does; TODO confirm, otherwise the result depends on
        test execution order.
        """
        batch_size = 2
        src_time_dim = 4
        trg_time_dim = 5
        vocab_size = 7

        # First RNG draw: random target-side embeddings.
        trg_embed = torch.rand(size=(batch_size, trg_time_dim, self.emb_size))

        # Construction presumably consumes RNG too (default parameter init);
        # keep it between the two torch.rand calls. TODO confirm.
        decoder = TransformerDecoder(num_layers=self.num_layers,
                                     num_heads=self.num_heads,
                                     hidden_size=self.hidden_size,
                                     ff_size=self.ff_size,
                                     dropout=self.dropout,
                                     emb_dropout=self.dropout,
                                     vocab_size=vocab_size)

        encoder_output = torch.rand(size=(batch_size, src_time_dim,
                                          self.hidden_size))

        # Re-sample all decoder parameters uniformly. Presumably this is so
        # the expected values do not hinge on the default init scheme.
        for p in decoder.parameters():
            torch.nn.init.uniform_(p, -0.5, 0.5)

        # All-True boolean masks (``== 1`` turns the float ones into bools).
        # src_mask is (batch, 1, src_len); trg_mask is (batch, trg_len, 1).
        src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
        trg_mask = torch.ones(size=(batch_size, trg_time_dim, 1)) == 1

        # Placeholders passed positionally only to fill the decoder's call
        # signature in this version of the API.
        encoder_hidden = None  # unused
        decoder_hidden = None  # unused
        unrol_steps = None  # unused

        output, states, _, _ = decoder(trg_embed, encoder_output,
                                       encoder_hidden, src_mask, unrol_steps,
                                       decoder_hidden, trg_mask)

        # Expected ``output``: shape (2, 5, 7). The last dim equals
        # vocab_size, so these are presumably per-position vocabulary scores.
        output_target = torch.Tensor(
            [[[0.1718, 0.5595, -0.1996, -0.6924, 0.4351, -0.0850, 0.2805],
              [0.0666, 0.4923, -0.1724, -0.6804, 0.3983, -0.1111, 0.2194],
              [-0.0315, 0.3673, -0.2320, -0.6100, 0.3019, 0.0422, 0.2514],
              [-0.0026, 0.3807, -0.2195, -0.6010, 0.3081, -0.0101, 0.2099],
              [-0.0172, 0.3384, -0.2853, -0.5799, 0.2470, 0.0312, 0.2518]],
             [[0.0284, 0.3918, -0.2010, -0.6472, 0.3646, -0.0296, 0.1791],
              [0.1017, 0.4387, -0.2031, -0.7084, 0.3051, -0.1354, 0.2511],
              [0.0155, 0.4274, -0.2061, -0.6702, 0.3085, -0.0617, 0.2830],
              [0.0227, 0.4067, -0.1697, -0.6463, 0.3277, -0.0423, 0.2333],
              [0.0133, 0.4409, -0.1186, -0.5694, 0.4450, 0.0290, 0.1643]]])
        self.assertEqual(output_target.shape, output.shape)
        self.assertTensorAlmostEqual(output_target, output)

        # Also require the greedy (argmax) choices to agree exactly.
        greedy_predictions = output.argmax(-1)
        expect_predictions = output_target.argmax(-1)
        self.assertTensorEqual(expect_predictions, greedy_predictions)

        # Expected ``states``: shape (2, 5, 12). The last dim is presumably
        # self.hidden_size; TODO confirm against the test fixture.
        states_target = torch.Tensor(
            [[[
                3.7535e-02, 5.3508e-01, 4.9478e-02, -9.1961e-01, -5.3966e-01,
                -1.0065e-01, 4.3053e-01, -3.0671e-01, -1.2724e-02, -4.1879e-01,
                5.9625e-01, 1.1887e-01
            ],
              [
                  1.3837e-01, 4.6963e-01, -3.7059e-02, -6.8479e-01,
                  -4.6042e-01, -1.0072e-01, 3.9374e-01, -3.0429e-01,
                  -5.4203e-02, -4.3680e-01, 6.4257e-01, 1.1424e-01
              ],
              [
                  1.0263e-01, 3.8331e-01, -2.5586e-02, -6.4478e-01,
                  -4.5860e-01, -1.0590e-01, 5.8806e-01, -2.8856e-01,
                  1.1084e-02, -4.7479e-01, 5.9094e-01, 1.6089e-01
              ],
              [
                  7.3408e-02, 3.7701e-01, -5.8783e-02, -6.2368e-01,
                  -4.4201e-01, -1.0237e-01, 5.2556e-01, -3.0821e-01,
                  -5.3345e-02, -4.5606e-01, 5.8259e-01, 1.2531e-01
              ],
              [
                  4.1206e-02, 3.6129e-01, -1.2955e-02, -5.8638e-01,
                  -4.6023e-01, -9.4267e-02, 5.5464e-01, -3.0029e-01,
                  -3.3974e-02, -4.8347e-01, 5.4088e-01, 1.2015e-01
              ]],
             [[
                 1.1017e-01, 4.7179e-01, 2.6402e-02, -7.2170e-01, -3.9778e-01,
                 -1.0226e-01, 5.3498e-01, -2.8369e-01, -1.1081e-01,
                 -4.6096e-01, 5.9517e-01, 1.3531e-01
             ],
              [
                  2.1947e-01, 4.6407e-01, 8.4276e-02, -6.3263e-01, -4.4953e-01,
                  -9.7334e-02, 4.0321e-01, -2.9893e-01, -1.0368e-01,
                  -4.5760e-01, 6.1378e-01, 1.3509e-01
              ],
              [
                  2.1437e-01, 4.1372e-01, 1.9859e-02, -5.7415e-01, -4.5025e-01,
                  -9.8621e-02, 4.1182e-01, -2.8410e-01, -1.2729e-03,
                  -4.8586e-01, 6.2318e-01, 1.4731e-01
              ],
              [
                  1.9153e-01, 3.8401e-01, 2.6096e-02, -6.2339e-01, -4.0685e-01,
                  -9.7387e-02, 4.1836e-01, -2.8648e-01, -1.7857e-02,
                  -4.7678e-01, 6.2907e-01, 1.7617e-01
              ],
              [
                  3.1713e-02, 3.7548e-01, -6.3005e-02, -7.9804e-01,
                  -3.6541e-01, -1.0398e-01, 4.2991e-01, -2.9607e-01,
                  2.1376e-04, -4.5897e-01, 6.1062e-01, 1.6142e-01
              ]]])

        self.assertEqual(states_target.shape, states.shape)
        self.assertTensorAlmostEqual(states_target, states)
示例#4
0
    def test_transformer_decoder_forward(self):
        """Golden-value regression test for ``TransformerDecoder`` forward.

        Seeds the global torch RNG with ``self.seed``. It then runs the
        decoder on random target embeddings (batch 2, target length 5) and
        random encoder outputs (source length 4), with every decoder
        parameter re-sampled from U(-0.5, 0.5). Finally it compares three
        things against hard-coded expected values:

        * the returned ``output`` tensor;
        * its greedy (argmax) predictions;
        * the returned ``states`` tensor.

        The expected values are only reproducible for this seed and this
        exact sequence of random draws. Do not reorder, add or remove the
        ``torch.rand`` calls, the decoder construction or the parameter
        initialisation.
        """
        # Pin the global RNG so the golden values below are reproducible.
        torch.manual_seed(self.seed)
        batch_size = 2
        src_time_dim = 4
        trg_time_dim = 5
        vocab_size = 7

        # First RNG draw after seeding: random target-side embeddings.
        trg_embed = torch.rand(size=(batch_size, trg_time_dim, self.emb_size))

        # Construction presumably consumes RNG too (default parameter init);
        # keep it between the two torch.rand calls. TODO confirm.
        decoder = TransformerDecoder(
            num_layers=self.num_layers, num_heads=self.num_heads,
            hidden_size=self.hidden_size, ff_size=self.ff_size,
            dropout=self.dropout, emb_dropout=self.dropout,
            vocab_size=vocab_size)

        encoder_output = torch.rand(
            size=(batch_size, src_time_dim, self.hidden_size))

        # Re-sample all decoder parameters uniformly. Presumably this is so
        # the expected values do not hinge on the default init scheme.
        for p in decoder.parameters():
            torch.nn.init.uniform_(p, -0.5, 0.5)

        # All-True boolean masks (``== 1`` turns the float ones into bools).
        # src_mask is (batch, 1, src_len); trg_mask is (batch, trg_len, 1).
        src_mask = torch.ones(size=(batch_size, 1, src_time_dim)) == 1
        trg_mask = torch.ones(size=(batch_size, trg_time_dim, 1)) == 1

        # This decoder API takes (trg_embed, encoder_output, src_mask,
        # trg_mask) positionally and returns a 4-tuple; only the first two
        # items are checked.
        output, states, _, _ = decoder(
            trg_embed, encoder_output, src_mask, trg_mask)

        # Expected ``output``: shape (2, 5, 7). The last dim equals
        # vocab_size, so these are presumably per-position vocabulary scores.
        output_target = torch.Tensor(
            [[[ 0.1765,  0.4578,  0.2345, -0.5303,  0.3862,  0.0964,  0.6882],
            [ 0.3363,  0.3907,  0.2210, -0.5414,  0.3770,  0.0748,  0.7344],
            [ 0.3275,  0.3729,  0.2797, -0.3519,  0.3341,  0.1605,  0.5403],
            [ 0.3081,  0.4513,  0.1900, -0.3443,  0.3072,  0.0570,  0.6652],
            [ 0.3253,  0.4315,  0.1227, -0.3371,  0.3339,  0.1129,  0.6331]],

            [[ 0.3235,  0.4836,  0.2337, -0.4019,  0.2831, -0.0260,  0.7013],
            [ 0.2800,  0.5662,  0.0469, -0.4156,  0.4246, -0.1121,  0.8110],
            [ 0.2968,  0.4777,  0.0652, -0.2706,  0.3146,  0.0732,  0.5362],
            [ 0.3108,  0.4910,  0.0774, -0.2341,  0.2873,  0.0404,  0.5909],
            [ 0.2338,  0.4371,  0.1350, -0.1292,  0.0673,  0.1034,  0.5356]]]
        )
        self.assertEqual(output_target.shape, output.shape)
        self.assertTensorAlmostEqual(output_target, output)

        # Also require the greedy (argmax) choices to agree exactly.
        greedy_predictions = output.argmax(-1)
        expect_predictions = output_target.argmax(-1)
        self.assertTensorEqual(expect_predictions, greedy_predictions)

        # Expected ``states``: shape (2, 5, 12). The last dim is presumably
        # self.hidden_size; TODO confirm against the test fixture.
        states_target = torch.Tensor(
            [[[ 8.3742e-01, -1.3161e-01,  2.1876e-01, -1.3920e-01, -9.1572e-01,
            2.3006e-01,  3.8328e-01, -1.6271e-01,  3.7370e-01, -1.2110e-01,
            -4.7549e-01, -4.0622e-01],
            [ 8.3609e-01, -2.9161e-02,  2.0583e-01, -1.3571e-01, -8.0510e-01,
            2.7630e-01,  4.8219e-01, -1.8863e-01,  1.1977e-01, -2.0179e-01,
            -4.4314e-01, -4.1228e-01],
            [ 8.5478e-01,  1.1368e-01,  2.0400e-01, -1.3059e-01, -8.1042e-01,
            1.6369e-01,  5.4244e-01, -2.9103e-01,  3.9919e-01, -3.3826e-01,
            -4.5423e-01, -4.2516e-01],
            [ 9.0388e-01,  1.1853e-01,  1.9927e-01, -1.1675e-01, -7.7208e-01,
            2.0686e-01,  4.6024e-01, -9.1610e-02,  3.9778e-01, -2.6214e-01,
            -4.7688e-01, -4.0807e-01],
            [ 8.9476e-01,  1.3646e-01,  2.0298e-01, -1.0910e-01, -8.2137e-01,
            2.8025e-01,  4.2538e-01, -1.1852e-01,  4.1497e-01, -3.7422e-01,
            -4.9212e-01, -3.9790e-01]],

            [[ 8.8745e-01, -2.5798e-02,  2.1483e-01, -1.8219e-01, -6.4821e-01,
            2.6279e-01,  3.9598e-01, -1.0423e-01,  3.0726e-01, -1.1315e-01,
            -4.7201e-01, -3.6979e-01],
            [ 7.5528e-01,  6.8919e-02,  2.2486e-01, -1.6395e-01, -7.9692e-01,
            3.7830e-01,  4.9367e-01,  2.4355e-02,  2.6674e-01, -1.1740e-01,
            -4.4945e-01, -3.6367e-01],
            [ 8.3467e-01,  1.7779e-01,  1.9504e-01, -1.6034e-01, -8.2783e-01,
            3.2627e-01,  5.0045e-01, -1.0181e-01,  4.4797e-01, -4.8046e-01,
            -3.7264e-01, -3.7392e-01],
            [ 8.4359e-01,  2.2699e-01,  1.9721e-01, -1.5768e-01, -7.5897e-01,
            3.3738e-01,  4.5559e-01, -1.0258e-01,  4.5782e-01, -3.8058e-01,
            -3.9275e-01, -3.8412e-01],
            [ 9.6349e-01,  1.6264e-01,  1.8207e-01, -1.6910e-01, -5.9304e-01,
            1.4468e-01,  2.4968e-01,  6.4794e-04,  5.4930e-01, -3.8420e-01,
            -4.2137e-01, -3.8016e-01]]]
        )

        self.assertEqual(states_target.shape, states.shape)
        self.assertTensorAlmostEqual(states_target, states)