Example #1
    def test_linear_attention_forward(self):
        d_model = 128
        n_heads = 4
        transformer = RecurrentTransformerEncoder([
            RecurrentTransformerEncoderLayer(
                RecurrentAttentionLayer(
                    RecurrentLinearAttention(),
                    d_model,
                    n_heads
                ),
                d_model,
                n_heads
            )
            for i in range(6)
        ])

        xs = []
        memory = None
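        # Run the encoder one step at a time, passing the returned state back
        # in so each step attends to everything seen so far.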
        for i in range(7):
            x, memory = transformer(torch.rand(10, d_model), state=memory)
            xs.append(x)
        for i in range(7):
            self.assertEqual(xs[i].shape, (10, d_model))
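        # For linear attention the state holds one (S_i, z_i) pair per layer:
        # S_i accumulates key-value outer products, shape (N, heads, d_head, d_head),
        # and z_i accumulates keys, shape (N, heads, d_head); d_head = d_model // n_heads = 32.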
        self.assertEqual(len(memory), 6)
        for i in range(6):
            self.assertEqual(len(memory[i]), 2)
            self.assertEqual(memory[i][0].shape, (10, n_heads, 32, 32))
            self.assertEqual(memory[i][1].shape, (10, n_heads, 32))

    def test_forward(self):
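        # _assert_sizes_attention (defined elsewhere in the test class) presumably
        # returns a stub inner attention that checks the projected query/key/value
        # shapes and passes the recurrent state through unchanged.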
        att = RecurrentAttentionLayer(
            self._assert_sizes_attention((10, 4, 25), (10, 4, 25), (10, 4, 25)),
            100, 4
        )
        v, m = att(torch.rand(10, 100), torch.rand(10, 100),
                   torch.rand(10, 100), "test memory")
        self.assertEqual(v.shape, (10, 100))
        self.assertEqual(m, "test memory")

        att = RecurrentAttentionLayer(
            self._assert_sizes_attention((10, 4, 32), (10, 4, 32), (10, 4, 64)),
            100,
            4,
            d_keys=32,
            d_values=64
        )
        v, m = att(torch.rand(10, 100), torch.rand(10, 100),
                   torch.rand(10, 100), "test memory")
        self.assertEqual(v.shape, (10, 100))
        self.assertEqual(m, "test memory")
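
A minimal standalone sketch of the same step-by-step usage as test_linear_attention_forward, runnable outside the test class. The import paths below are assumptions based on the fast-transformers package layout and may differ between versions; newer releases may also require passing the query dimensionality to RecurrentLinearAttention, as Examples #3 and #5 do.

import torch

# Assumed import paths for the fast-transformers recurrent modules.
from fast_transformers.recurrent.attention import (
    RecurrentAttentionLayer, RecurrentLinearAttention
)
from fast_transformers.recurrent.transformers import (
    RecurrentTransformerEncoder, RecurrentTransformerEncoderLayer
)

d_model, n_heads, n_layers = 128, 4, 6
encoder = RecurrentTransformerEncoder([
    RecurrentTransformerEncoderLayer(
        RecurrentAttentionLayer(RecurrentLinearAttention(), d_model, n_heads),
        d_model,
        n_heads
    )
    for _ in range(n_layers)
])
encoder.eval()

# Process a batch of 10 sequences one element at a time; the recurrent state
# replaces the growing attention history, so the cost per step stays constant.
state = None
outputs = []
with torch.no_grad():
    for _ in range(7):
        y, state = encoder(torch.rand(10, d_model), state=state)
        outputs.append(y)  # each y has shape (10, d_model)
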
Example #3
    def test_compare_with_batch(self):
        N = 10
        L = 42
        S = 100
        D = 1024
        E = D // 4
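        # E matches the per-head dimension (D split across the 4 heads used below)
        # and is passed to the linear attention variants.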
        x = torch.rand(N, L, D)
        m = torch.rand(N, S, D)

        tests = [
            ("full",
             FullAttention, FullAttention,
             RecurrentFullAttention, RecurrentCrossFullAttention),
            ("linear",
             partial(CausalLinearAttention, E), partial(LinearAttention, E),
             partial(RecurrentLinearAttention, E),
             partial(RecurrentCrossLinearAttention, E)),
        ]

        for name, a1, a2, a3, a4 in tests:
            dec = TransformerDecoder([
                TransformerDecoderLayer(AttentionLayer(a1(), D, 4),
                                        AttentionLayer(a2(), D, 4), D)
                for i in range(4)
            ])
            rdec = RecurrentTransformerDecoder([
                RecurrentTransformerDecoderLayer(
                    RecurrentAttentionLayer(a3(), D, 4),
                    RecurrentCrossAttentionLayer(a4(), D, 4), D)
                for i in range(4)
            ])
            dec.eval()
            rdec.eval()
            rdec.load_state_dict(dec.state_dict())

            x_mask = TriangularCausalMask(L)
            x_length = LengthMask(torch.full((N, ), L, dtype=torch.int64))
            m_mask = FullMask(L, S)
            m_length = LengthMask(torch.full((N, ), S, dtype=torch.int64))

            y1 = dec(x,
                     m,
                     x_mask=x_mask,
                     x_length_mask=x_length,
                     memory_mask=m_mask,
                     memory_length_mask=m_length)
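            # Decode the same inputs one position at a time with the recurrent
            # decoder, carrying its state forward; the outputs should match the
            # batch pass above.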
            state = None
            y2 = []
            for i in range(L):
                y2i, state = rdec(x[:, i],
                                  m,
                                  memory_length_mask=m_length,
                                  state=state)
                y2.append(y2i)
            y2 = torch.stack(y2, dim=1)

            self.assertLess(torch.abs(y1 - y2).max(), 1e-5)

    def test_mask_creation(self):
        N = 10
        L = 42
        S = 100
        D = 1024
        x = torch.rand(N, D)
        m = torch.rand(N, S, D)

        rdec = RecurrentTransformerDecoder([
            RecurrentTransformerDecoderLayer(
                RecurrentAttentionLayer(RecurrentFullAttention(), D, 4),
                RecurrentCrossAttentionLayer(RecurrentCrossFullAttention(), D, 4),
                D
            )
            for i in range(4)
        ])
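        # Called without explicit masks, the decoder is expected to build
        # suitable default masks internally.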
        rdec(x, m)
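
In practice the recurrent decoder is driven autoregressively, feeding each step's output (or the embedding of a sampled token) back in as the next input. A minimal sketch with full attention, assuming the import paths shown below (they may differ between fast-transformers versions) and a placeholder feedback policy:

import torch

# Assumed import paths; they may differ between fast-transformers versions.
from fast_transformers.masking import LengthMask
from fast_transformers.recurrent.attention import (
    RecurrentAttentionLayer, RecurrentCrossAttentionLayer,
    RecurrentFullAttention, RecurrentCrossFullAttention
)
from fast_transformers.recurrent.transformers import (
    RecurrentTransformerDecoder, RecurrentTransformerDecoderLayer
)

N, S, D, H = 10, 100, 1024, 4
memory = torch.rand(N, S, D)  # encoder output to cross-attend over
memory_lengths = LengthMask(torch.full((N,), S, dtype=torch.int64))

decoder = RecurrentTransformerDecoder([
    RecurrentTransformerDecoderLayer(
        RecurrentAttentionLayer(RecurrentFullAttention(), D, H),
        RecurrentCrossAttentionLayer(RecurrentCrossFullAttention(), D, H),
        D
    )
    for _ in range(4)
])
decoder.eval()

# Autoregressive decoding: each call consumes one D-dimensional input per
# sequence plus the state from the previous step, so no causal mask is needed.
x_t = torch.zeros(N, D)  # e.g. a start-of-sequence embedding
state = None
with torch.no_grad():
    for _ in range(20):
        y_t, state = decoder(x_t, memory,
                             memory_length_mask=memory_lengths,
                             state=state)
        x_t = y_t  # placeholder feedback; a real model would embed a sampled token
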
Example #5
    def __init__(self, alphabet_size, hidden_size, num_window_components,
                 num_mixture_components):
        super(HandwritingGenerator, self).__init__()
        self.alphabet_size = alphabet_size
        self.hidden_size = hidden_size
        self.num_window_components = num_window_components
        self.num_mixture_components = num_mixture_components
        # print(num_window_components)
        # print(num_mixture_components)
        # print(hidden_size)
        self.input_size = input_size = 3
        n_heads_1 = 2
        n_heads_2 = 10
        query_dimensions = 1
        self.n_pre_layers = 2
        self.n_layers = 4
        # n_heads_2 = 4
        # First LSTM layer, takes as input a tuple (x, y, eol)
        # self.lstm1_layer = LSTM(input_size=3, hidden_size=hidden_size, batch_first=True)
        # [
        #     TransformerEncoderLayer(
        #         AttentionLayer(FullAttention(), 768, 12),
        #         768,
        #         12,
        #         activation="gelu"
        #     ) for l in range(12)
        # ],
        # norm_layer=torch.nn.LayerNorm(768)

        self.lstm1_layer = LSTM(input_size=input_size,
                                hidden_size=hidden_size,
                                batch_first=True)

        # self.transformers1_layers = [
        #     RecurrentTransformerEncoderLayer(
        #         RecurrentAttentionLayer(RecurrentLinearAttention(query_dimensions), input_size, n_heads_1),
        #         input_size,
        #         hidden_size,
        #         activation="gelu"
        #     ) for l in range(self.n_pre_layers)
        # ]
        # self.norm1_layer = torch.nn.Linear(input_size, hidden_size)

        # Gaussian Window layer
        self.window_layer = GaussianWindow(
            input_size=hidden_size, num_components=num_window_components)
        # Second LSTM layer, takes as input the concatenation of the input,
        # the output of the first LSTM layer
        # and the output of the Window layer
        # self.lstm2_layer = LSTM(
        #     input_size=3 + hidden_size + alphabet_size + 1,
        #     hidden_size=hidden_size,
        #     batch_first=True,
        # )
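        # Stack of recurrent transformer layers replacing the second LSTM.
        # NOTE: a plain Python list does not register submodules with nn.Module,
        # so torch.nn.ModuleList is usually needed for these parameters to show
        # up in .parameters() and be trained.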
        self.transformers2_layers = [
            RecurrentTransformerEncoderLayer(
                RecurrentAttentionLayer(
                    RecurrentLinearAttention(query_dimensions),
                    3 + hidden_size + alphabet_size + 1, n_heads_2),
                3 + hidden_size + alphabet_size + 1,
                # RecurrentAttentionLayer(RecurrentLinearAttention(query_dimensions), hidden_size, n_heads_2),
                # hidden_size,
                hidden_size,
                activation="gelu") for l in range(self.n_layers)
        ]

        # Third LSTM layer, takes as input the concatenation of the output of the first LSTM layer,
        # the output of the second LSTM layer
        # and the output of the Window layer
        # self.lstm3_layer = LSTM(
        #     input_size=hidden_size, hidden_size=hidden_size, batch_first=True
        # )
        # print( 3 + hidden_size + alphabet_size + 1)
        # print(hidden_size)
        self.norm2_layer = torch.nn.LayerNorm(
            3 + hidden_size + alphabet_size + 1)
        # self.norm2_layer = torch.nn.LayerNorm(hidden_size)
        # self.norm2_layer = torch.nn.Linear(hidden_size)

        # Mixture Density Network Layer
        self.output_layer = MDN(
            input_size=3 + hidden_size + alphabet_size + 1,
            num_mixtures=num_mixture_components
            # input_size=hidden_size, num_mixtures=num_mixture_components
        )

        # Hidden State Variables
        self.prev_kappa = None
        self.hidden1 = None
        # self.hidden1 = [None] * self.n_pre_layers
        self.hidden2 = [None] * self.n_layers
        # self.hidden3 = None

        # Initialize parameters
        self.reset_parameters()
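
The forward pass of HandwritingGenerator is not part of this excerpt. The hypothetical helper below sketches how the per-layer recurrent states kept in self.hidden2 could be advanced by one timestep, assuming each layer follows the same (x, state) -> (x, state) calling convention as the encoder layers in Example #1; _step_transformer_stack is not a method of the original class.

    def _step_transformer_stack(self, features):
        # Hypothetical helper, not part of the original class: advance the
        # recurrent transformer stack by one timestep.
        # features: (batch, 3 + hidden_size + alphabet_size + 1) for a single step.
        x = features
        for i, layer in enumerate(self.transformers2_layers):
            # Each layer returns its output and an updated per-layer state.
            x, self.hidden2[i] = layer(x, state=self.hidden2[i])
        return self.norm2_layer(x)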