def test_apply_stride(self):
        """A zero stride in both stride slots must leave ``tokens`` unchanged.

        ``apply_stride`` is called only for its effect on ``tokens`` (its
        return value is ignored), so the tensor is compared afterwards with
        a freshly built copy of the same 2x5 ``0..9`` values.
        """
        def fresh_tokens():
            return torch.arange(10).long().reshape((2, 5))

        tokens = fresh_tokens()

        # Both rows get (100, 0, 0): the two stride slots are zero.
        apply_stride(tokens, [(100, 0, 0)] * 2)

        self.assertEqual(fresh_tokens().tolist(), tokens.tolist())

    def test_apply_stride_real_stride(self):
        """Non-zero strides overwrite edge tokens, rounded to whole tokens.

        Every case starts from a fresh 2x5 tensor holding ``0..9``. Row 0
        carries its stride in the second tuple slot and row 1 in the third;
        per the expectations, the first token of row 0 and the last token of
        row 1 are replaced by their neighbours when the stride is 20 or 15,
        and nothing changes when it is 5.
        """
        # NOTE(review): the tuples look like (length, stride_a, stride_b) with
        # strides in the same units as length (20/100 of 5 tokens == 1 token,
        # 15 rounds up to 1, 5 rounds down to 0) -- inferred from the expected
        # values below; confirm against apply_stride.
        shifted = [[1, 1, 2, 3, 4], [5, 6, 7, 8, 8]]
        unchanged = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
        cases = (
            (20, shifted),    # stride aligned
            (15, shifted),    # stride rounded up to one token
            (5, unchanged),   # stride rounded down to no stride
        )
        for stride, expected in cases:
            tokens = torch.arange(10).long().reshape((2, 5))
            apply_stride(tokens, [(100, stride, 0), (100, 0, stride)])
            self.assertEqual(expected, tokens.tolist())

 def test_apply_stride_with_padding(self):
     """A row declared shorter than the full length is handled as padded.

     Row 0 uses length 100 with a stride of 20 in its second slot. Row 1
     declares length 60 with a stride of 20 in its third slot. Per the
     expectation, every position of row 1 from index 2 onward ends up
     holding the value at index 1 (6).
     """
     # NOTE(review): presumably 60/100 of 5 tokens leaves 3 real tokens, the
     # last of which is replaced by the stride while indices 3-4 are treated
     # as padding -- inferred from the expected values; confirm against
     # apply_stride.
     row_specs = [(100, 20, 0), (60, 0, 20)]
     tokens = torch.arange(10).long().reshape((2, 5))
     apply_stride(tokens, row_specs)
     expected = [
         [1, 1, 2, 3, 4],
         [5, 6, 6, 6, 6],
     ]
     self.assertEqual(expected, tokens.tolist())