def test_compare_with_full(self):
    """Local attention with window 17 should match full attention
    restricted to a band of half-width 8 around the diagonal."""
    local_att = LocalAttention(17, softmax_temp=1).eval()
    full_att = FullAttention(softmax_temp=1).eval()
    q, k, v, m1, m2, m3 = self._get_inputs(N=10, L=128, S=128, D=32)

    # Banded boolean mask: True wherever |i - j| <= 8, i.e. inside the
    # 17-wide local window.
    idx = torch.arange(128)
    band = FullMask(torch.abs(idx[:, None] - idx[None]) < 9)

    out_full = full_att(q, k, v, band, m2, m3)
    out_local = local_att(q, k, v, m1, m2, m3)
    self.assertTrue(
        torch.allclose(out_full, out_local, atol=1e-5, rtol=1e-5)
    )
def test_benchmark_cpu(self):
    """Rough CPU benchmark of LocalAttention; prints the elapsed time.

    Not a correctness test — it only reports wall time for 10 forward
    passes after a warm-up phase.
    """
    q, k, v, m1, m2, m3 = self._get_inputs(L=1024, S=1024, E=64, D=64)
    att = LocalAttention(128)

    # Warm up any lazily-built caches so they don't skew the measurement.
    for _ in range(10):
        att(q, k, v, m1, m2, m3)

    # time.perf_counter() is monotonic and high-resolution, unlike
    # time.time(), which can jump if the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(10):
        att(q, k, v, m1, m2, m3)
    end = time.perf_counter()
    print("CPU Time taken:", (end - start) * 1000, "(ms)")
def test_benchmark_gpu(self):
    """Rough GPU benchmark of LocalAttention using CUDA events.

    Skips (rather than errors) on machines without a CUDA device.
    """
    # Guard: the original crashed on CPU-only machines when building
    # tensors with device="cuda".
    if not torch.cuda.is_available():
        self.skipTest("CUDA device not available")

    q, k, v, m1, m2, m3 = self._get_inputs(
        L=1024, S=1024, E=64, D=64, device="cuda")
    att = LocalAttention(128)

    # Warm up kernels/caches so one-time compilation doesn't skew timing.
    for _ in range(10):
        att(q, k, v, m1, m2, m3)

    # CUDA events time work on the device stream itself; synchronize
    # before reading the elapsed time so both events have completed.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(10):
        att(q, k, v, m1, m2, m3)
    end.record()
    torch.cuda.synchronize()
    print("GPU time taken:", start.elapsed_time(end), "(ms)")
def test_masked(self):
    """Per-sequence length masks must not introduce NaNs in the output."""
    att = LocalAttention(16, softmax_temp=1)
    q, k, v, m1, m2, m3 = self._get_inputs(N=3, L=64, S=64, D=32)

    # Use the same length mask for both queries and keys.
    lengths = LengthMask(torch.tensor([8, 16, 64], dtype=torch.long))
    out = att(q, k, v, m1, lengths, lengths)
    self.assertFalse(torch.any(torch.isnan(out)))
def test_forward(self):
    """A plain forward pass must return a contiguous tensor."""
    att = LocalAttention(3, softmax_temp=1)
    q, k, v, m1, m2, m3 = self._get_inputs()
    out = att(q, k, v, m1, m2, m3)
    self.assertTrue(out.is_contiguous())