Example #1
    def test_correctness(self):
        # Prepare the inputs
        N = 10   # batch size
        H = 4    # attention heads
        E = 25   # query/key feature dimension per head
        M = 64   # value feature dimension per head
        L = 100  # sequence length
        q = torch.rand(N, L, H, E)
        k = torch.rand(N, L, H, E)
        v = torch.rand(N, L, H, M)
        m1 = TriangularCausalMask(L)          # causal mask over the sequence
        m2 = LengthMask(torch.full((N,), L))  # query lengths (all full length)
        m3 = LengthMask(torch.full((N,), L))  # key lengths (all full length)
        att = FullAttention()
        rec_att = RecurrentFullAttention()
        att.eval()      # eval() turns off dropout so both paths are deterministic
        rec_att.eval()

        # Full attention over the whole sequence in one call
        v1 = att(q, k, v, m1, m2, m3)

        # Recurrent attention one timestep at a time, threading the state
        v2 = []
        memory = None
        for i in range(L):
            v2i, memory = rec_att(q[:, i], k[:, i], v[:, i], memory)
            v2.append(v2i)
        v2 = torch.stack(v2, dim=1)
        self.assertLess(torch.abs(v1-v2).max(), 1e-5)
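The snippets on this page are test methods, so they assume a surrounding unittest.TestCase subclass plus an import prelude. A minimal sketch of that prelude, with module paths assumed from the fast_transformers package layout (verify against the installed version):

    import time
    import unittest

    import torch

    from fast_transformers.attention.full_attention import FullAttention
    from fast_transformers.masking import LengthMask, TriangularCausalMask
    from fast_transformers.recurrent.attention import (
        RecurrentAttentionLayer, RecurrentCrossAttentionLayer,
        RecurrentCrossFullAttention, RecurrentFullAttention
    )
    from fast_transformers.recurrent.transformers import (
        RecurrentTransformerDecoder, RecurrentTransformerDecoderLayer,
        RecurrentTransformerEncoder, RecurrentTransformerEncoderLayer
    )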
Example #2
    def test_full_attention_forward(self):
        d_model = 128
        n_heads = 4
        transformer = RecurrentTransformerEncoder([
            RecurrentTransformerEncoderLayer(
                RecurrentAttentionLayer(
                    RecurrentFullAttention(),
                    d_model,
                    n_heads
                ),
                d_model,
                n_heads
            )
            for i in range(6)
        ])

        xs = []
        memory = None
        for i in range(7):
            x, memory = transformer(torch.rand(10, d_model), state=memory)
            xs.append(x)
        for i in range(7):
            self.assertEqual(xs[i].shape, (10, d_model))
        self.assertEqual(len(memory), 6)
        for i in range(6):
            self.assertEqual(len(memory[i]), 2)
            self.assertEqual(memory[i][0].shape, (10, n_heads, 7, 32))
            self.assertEqual(memory[i][1].shape, (10, n_heads, 7, 32))
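Each entry of the returned memory is a per-layer [keys, values] cache whose time axis grows by one per call; the trailing 32 in the assertions is the per-head dimension, d_model // n_heads = 128 // 4. The same API supports autoregressive generation by feeding outputs back in; a hypothetical sketch (the feedback step stands in for a real embedding/projection pipeline):

    # Hypothetical autoregressive loop over the encoder built above.
    steps = 7
    x = torch.rand(10, d_model)  # first input step
    memory = None
    outputs = []
    for t in range(steps):
        x, memory = transformer(x, state=memory)
        outputs.append(x)
    y = torch.stack(outputs, dim=1)  # (10, steps, d_model)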
Example #3
    def test_forward(self):
        # Prepare the inputs
        N = 10
        H = 4
        E = 25
        M = 64
        L = 100
        q = torch.rand(N, H, E)
        k = torch.rand(N, H, E)
        v = torch.rand(N, H, M)
        memory = [
            torch.rand(N, H, L, E),
            torch.rand(N, H, L, M)
        ]

        # Test the attention module
        att = RecurrentFullAttention(softmax_temp=1)
        v_new, mem_new = att(q, k, v)
        self.assertEqual(v_new.shape, (N, H, M))
        self.assertEqual(len(mem_new), 2)
        self.assertEqual(mem_new[0].shape, (N, H, 1, E))
        self.assertEqual(mem_new[1].shape, (N, H, 1, M))
        v_new, mem_new = att(q, k, v, mem_new)
        self.assertEqual(v_new.shape, (N, H, M))
        self.assertEqual(len(mem_new), 2)
        self.assertEqual(mem_new[0].shape, (N, H, 2, E))
        self.assertEqual(mem_new[1].shape, (N, H, 2, M))

        v_new, mem_new = att(q, k, v, memory)
        self.assertEqual(v_new.shape, (N, H, M))
        self.assertEqual(len(mem_new), 2)
        self.assertEqual(mem_new[0].shape, (N, H, L+1, E))
        self.assertEqual(mem_new[1].shape, (N, H, L+1, M))
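For intuition, each recurrent step appends the new key and value to the cache and attends over everything accumulated so far. A pure-PyTorch sketch of that computation (a reference for the math, not the library's actual implementation):

    def recurrent_full_attention_step(q, k, v, state=None, softmax_temp=1.0):
        # q, k: (N, H, E); v: (N, H, M); the cache grows along dim 2.
        N, H, E = k.shape
        M = v.shape[-1]
        if state is None:
            K = k.new_zeros(N, H, 0, E)
            V = v.new_zeros(N, H, 0, M)
        else:
            K, V = state
        K = torch.cat([K, k[:, :, None, :]], dim=2)  # (N, H, t, E)
        V = torch.cat([V, v[:, :, None, :]], dim=2)  # (N, H, t, M)
        A = torch.softmax(softmax_temp * torch.einsum("nhe,nhte->nht", q, K),
                          dim=-1)
        out = torch.einsum("nht,nhtm->nhm", A, V)    # (N, H, M)
        return out, [K, V]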
Example #4
    def test_mask_creation(self):
        N = 10
        L = 42
        S = 100
        D = 1024
        x = torch.rand(N, D)
        m = torch.rand(N, S, D)

        rdec = RecurrentTransformerDecoder([
            RecurrentTransformerDecoderLayer(
                RecurrentAttentionLayer(RecurrentFullAttention(), D, 4),
                RecurrentCrossAttentionLayer(RecurrentCrossFullAttention(), D, 4),
                D
            )
            for i in range(4)
        ])
        rdec(x, m)
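Despite its name, this test mainly checks that the recurrent decoder's forward pass runs with a single query step x of shape (N, D) against an encoder memory m of shape (N, S, D). Assuming the decoder mirrors the encoder's (output, state) return contract, stepping it over several timesteps would look like this sketch:

    # Hypothetical multi-step decoding loop; the state threading is assumed
    # to work as in the encoder examples above.
    state = None
    for t in range(5):
        y, state = rdec(torch.rand(N, D), m, state=state)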
Example #5
    def test_benchmark_cpu(self):
        # Prepare the inputs
        N = 10
        H = 12
        E = 25
        M = 64
        L = 100
        q = torch.rand(N, H, E)
        k = torch.rand(N, H, E)
        v = torch.rand(N, H, M)
        memory = None
        att = RecurrentFullAttention(softmax_temp=1)

        start = time.time()
        for i in range(100):
            v, memory = att(q, k, v, memory)
        end = time.time()
        print("CPU Time taken:", (end-start)*1000, "(ms)")
Example #6
    def test_benchmark_gpu(self):
        # Prepare the inputs
        N = 10
        H = 12
        E = 25
        M = 64
        L = 100
        q = torch.rand(N, H, E).cuda()
        k = torch.rand(N, H, E).cuda()
        v = torch.rand(N, H, M).cuda()
        memory = None
        att = RecurrentFullAttention(softmax_temp=1)

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for i in range(100):
            v, memory = att(q, k, v, memory)
        end.record()
        torch.cuda.synchronize()
        print("GPU time taken:", start.elapsed_time(end), "(ms)")