def test_apply_stride(self):
    """Check that all-zero stride entries leave the token tensor unchanged."""
    zero_strides = [(100, 0, 0), (100, 0, 0)]
    tokens = torch.arange(10).long().reshape((2, 5))

    # No stride: the return value is ignored, the outcome is read from `tokens`.
    apply_stride(tokens, zero_strides)

    # Rebuild the starting values independently to compare against.
    untouched = torch.arange(10).long().reshape((2, 5))
    self.assertEqual(untouched.tolist(), tokens.tolist())
def test_apply_stride_real_stride(self):
    """Run apply_stride with non-zero stride values on a 2x5 token tensor.

    Each scenario pairs the stride list handed to apply_stride with the rows
    expected in the tensor afterwards. Scenarios run in order and the first
    mismatch fails the test.
    """
    # NOTE(review): the tuples look like (length, left stride, right stride)
    # -- confirm against apply_stride's definition.
    scenarios = (
        # Stride aligned
        ([(100, 20, 0), (100, 0, 20)], [[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]]),
        # Stride rounded
        ([(100, 15, 0), (100, 0, 15)], [[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]]),
        # No stride rounded
        ([(100, 5, 0), (100, 0, 5)], [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
    )

    for strides, expected_rows in scenarios:
        # Start from a fresh 0..9 tensor because the result is read back from it.
        tokens = torch.arange(10).long().reshape((2, 5))
        apply_stride(tokens, strides)
        self.assertEqual(expected_rows, tokens.tolist())
def test_apply_stride_with_padding(self):
    """Check apply_stride when the two stride entries have different first values."""
    # Stride aligned
    # NOTE(review): presumably the first tuple element is a length, so 60 vs
    # 100 models a padded second row -- confirm against apply_stride.
    strides = [(100, 20, 0), (60, 0, 20)]
    expected_rows = [[1, 1, 2, 3, 4], [5, 6, 6, 6, 6]]

    tokens = torch.arange(10).long().reshape((2, 5))
    apply_stride(tokens, strides)

    self.assertEqual(expected_rows, tokens.tolist())