    def test_correctness(self):
        # Prepare the inputs
        N = 10
        H = 4
        E = 25
        M = 64
        L = 100
        q = torch.rand(N, L, H, E)
        k = torch.rand(N, L, H, E)
        v = torch.rand(N, L, H, M)
        m1 = TriangularCausalMask(L)
        m2 = LengthMask(torch.full((N, ), L, dtype=torch.long))
        m3 = LengthMask(torch.full((N, ), L, dtype=torch.long))
        att = CausalLinearAttention()
        rec_att = RecurrentLinearAttention()
        att.eval()
        rec_att.eval()

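        # The batched causal attention and the step-by-step recurrent attention
        # compute the same thing, so their outputs should agree to within
        # numerical precision.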
        v1 = att(q, k, v, m1, m2, m3)
        v2 = []
        memory = None
        for i in range(L):
            v2i, memory = rec_att(q[:, i], k[:, i], v[:, i], memory)
            v2.append(v2i)
        v2 = torch.stack(v2, dim=1)
        self.assertLess(torch.abs(v1 - v2).max(), 1e-5)
# Example 2
    def test_linear_attention_forward(self):
        d_model = 128
        n_heads = 4
        transformer = RecurrentTransformerEncoder([
            RecurrentTransformerEncoderLayer(
                RecurrentAttentionLayer(
                    RecurrentLinearAttention(),
                    d_model,
                    n_heads
                ),
                d_model,
                n_heads
            )
            for i in range(6)
        ])

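        # Feed 7 consecutive time steps through the encoder; each call returns
        # the output for that step together with the updated per-layer state.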
        xs = []
        memory = None
        for i in range(7):
            x, memory = transformer(torch.rand(10, d_model), state=memory)
            xs.append(x)
        for i in range(7):
            self.assertEqual(xs[i].shape, (10, d_model))
        self.assertEqual(len(memory), 6)
        for i in range(6):
            self.assertEqual(len(memory[i]), 2)
            self.assertEqual(memory[i][0].shape, (10, n_heads, 32, 32))
            self.assertEqual(memory[i][1].shape, (10, n_heads, 32))

    def test_forward(self):
        # Prepare the inputs
        N = 10
        H = 4
        E = 25
        M = 64
        L = 100
        q = torch.rand(N, H, E)
        k = torch.rand(N, H, E)
        v = torch.rand(N, H, M)
        memory = [torch.rand(N, H, E, M), torch.rand(N, H, E)]
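        # The recurrent state is a pair: memory[0] accumulates the outer
        # products of the feature-mapped keys and the values, shape
        # (N, H, E, M); memory[1] accumulates the feature-mapped keys used as
        # the normalizer, shape (N, H, E).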

        # Test the attention module
        att = RecurrentLinearAttention()
        v_new, mem_new = att(q, k, v)
        self.assertEqual(v_new.shape, (N, H, M))
        self.assertEqual(len(mem_new), 2)
        self.assertEqual(mem_new[0].shape, (N, H, E, M))
        self.assertEqual(mem_new[1].shape, (N, H, E))
        v_new, mem_new = att(q, k, v, mem_new)
        self.assertEqual(v_new.shape, (N, H, M))
        self.assertEqual(len(mem_new), 2)
        self.assertEqual(mem_new[0].shape, (N, H, E, M))
        self.assertEqual(mem_new[1].shape, (N, H, E))

        v_new, mem_new = att(q, k, v, memory)
        self.assertEqual(v_new.shape, (N, H, M))
        self.assertEqual(len(mem_new), 2)
        self.assertEqual(mem_new[0].shape, (N, H, E, M))
        self.assertEqual(mem_new[1].shape, (N, H, E))

    def test_benchmark_cpu(self):
        # Prepare the inputs
        N = 10
        H = 12
        E = 25
        M = 64
        L = 100
        q = torch.rand(N, H, E)
        k = torch.rand(N, H, E)
        v = torch.rand(N, H, M)
        memory = None
        att = RecurrentLinearAttention()

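        # Time 100 recurrent steps; the output is fed back in as the next value
        # tensor, which is fine for a pure throughput measurement.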
        start = time.time()
        for i in range(100):
            v, memory = att(q, k, v, memory)
        end = time.time()
        print("CPU Time taken:", (end - start) * 1000, "(ms)")

    def test_benchmark_gpu(self):
        # Prepare the inputs
        N = 10
        H = 12
        E = 25
        M = 64
        L = 100
        q = torch.rand(N, H, E).cuda()
        k = torch.rand(N, H, E).cuda()
        v = torch.rand(N, H, M).cuda()
        memory = None
        att = RecurrentLinearAttention()

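        # CUDA kernels launch asynchronously, so time the loop with CUDA events
        # and synchronize before reading the elapsed time.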
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for i in range(100):
            v, memory = att(q, k, v, memory)
        end.record()
        torch.cuda.synchronize()
        print("GPU time taken:", start.elapsed_time(end), "(ms)")
# Example 6
    def __init__(self, alphabet_size, hidden_size, num_window_components,
                 num_mixture_components):
        super(HandwritingGenerator, self).__init__()
        self.alphabet_size = alphabet_size
        self.hidden_size = hidden_size
        self.num_window_components = num_window_components
        self.num_mixture_components = num_mixture_components
        # print(num_window_components)
        # print(num_mixture_components)
        # print(hidden_size)
        self.input_size = input_size = 3
        n_heads_1 = 2
        n_heads_2 = 10
        query_dimensions = 1
        self.n_pre_layers = 2
        self.n_layers = 4
        # n_heads_2 = 4
        # First LSTM layer, takes as input a tuple (x, y, eol)
        # self.lstm1_layer = LSTM(input_size=3, hidden_size=hidden_size, batch_first=True)
        # [
        #     TransformerEncoderLayer(
        #         AttentionLayer(FullAttention(), 768, 12),
        #         768,
        #         12,
        #         activation="gelu"
        #     ) for l in range(12)
        # ],
        # norm_layer=torch.nn.LayerNorm(768)

        self.lstm1_layer = LSTM(input_size=input_size,
                                hidden_size=hidden_size,
                                batch_first=True)

        # self.transformers1_layers = [
        #     RecurrentTransformerEncoderLayer(
        #         RecurrentAttentionLayer(RecurrentLinearAttention(query_dimensions), input_size, n_heads_1),
        #         input_size,
        #         hidden_size,
        #         activation="gelu"
        #     ) for l in range(self.n_pre_layers)
        # ]
        # self.norm1_layer = torch.nn.Linear(input_size, hidden_size)

        # Gaussian Window layer
        self.window_layer = GaussianWindow(
            input_size=hidden_size, num_components=num_window_components)
        # Second stage: a stack of recurrent transformer layers (replacing the
        # second LSTM layer, kept below for reference), taking as input the
        # concatenation of the input, the output of the first LSTM layer
        # and the output of the Window layer
        # self.lstm2_layer = LSTM(
        #     input_size=3 + hidden_size + alphabet_size + 1,
        #     hidden_size=hidden_size,
        #     batch_first=True,
        # )
        # Wrap the recurrent transformer layers in a ModuleList so that their
        # parameters are registered with this module (a plain Python list would
        # hide them from .parameters(), .cuda(), state_dict(), etc.).
        self.transformers2_layers = torch.nn.ModuleList([
            RecurrentTransformerEncoderLayer(
                RecurrentAttentionLayer(
                    RecurrentLinearAttention(query_dimensions),
                    3 + hidden_size + alphabet_size + 1, n_heads_2),
                3 + hidden_size + alphabet_size + 1,
                # RecurrentAttentionLayer(RecurrentLinearAttention(query_dimensions), hidden_size, n_heads_2),
                # hidden_size,
                hidden_size,
                activation="gelu") for l in range(self.n_layers)
        ])

        # Third LSTM layer, takes as input the concatenation of the output of the first LSTM layer,
        # the output of the second LSTM layer
        # and the output of the Window layer
        # self.lstm3_layer = LSTM(
        #     input_size=hidden_size, hidden_size=hidden_size, batch_first=True
        # )
        # print( 3 + hidden_size + alphabet_size + 1)
        # print(hidden_size)
        self.norm2_layer = torch.nn.LayerNorm(
            3 + hidden_size + alphabet_size + 1)
        # self.norm2_layer = torch.nn.LayerNorm(hidden_size)
        # self.norm2_layer = torch.nn.Linear(hidden_size)

        # Mixture Density Network Layer
        self.output_layer = MDN(
            input_size=3 + hidden_size + alphabet_size + 1,
            num_mixtures=num_mixture_components
            # input_size=hidden_size, num_mixtures=num_mixture_components
        )

        # Hidden State Variables
        self.prev_kappa = None
        # self.hidden1 = None
        self.hidden1 = None
        # self.hidden1 = [None] * self.n_pre_layers
        self.hidden2 = [None] * self.n_layers
        # self.hidden3 = None

        # Initialize parameters
        self.reset_parameters()
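
    # A minimal sketch (not part of the original class) of how the recurrent
    # states initialized above would typically be threaded through one time
    # step, assuming the fast-transformers recurrent API in which each layer
    # takes (x, state=...) and returns (y, new_state). The helper name
    # `_transformer_step` and its exact placement in forward() are assumptions.
    def _transformer_step(self, x):
        # x: (batch, 3 + hidden_size + alphabet_size + 1) for a single step,
        # i.e. the concatenated vector fed to self.transformers2_layers.
        for i, layer in enumerate(self.transformers2_layers):
            x, self.hidden2[i] = layer(x, state=self.hidden2[i])
        return self.norm2_layer(x)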