def matrix_power(M, n):
    r"""Raise a square matrix, ``M``, to the (integer) power ``n``.

    This implementation uses exponentiation by squaring which is
    significantly faster than the naive implementation.
    The time complexity for exponentiation by squaring is
    :math:`\mathcal{O}((n \log M)^k)`

    Parameters
    ----------
    M: TensorVariable
        The matrix to exponentiate. Its second-to-last dimension gives the
        size of the identity returned when ``n == 0``.
    n: int
        The exponent. For a negative value, ``pinv(M)`` is raised to
        ``abs(n)`` instead.

    Returns
    -------
    An identity matrix with ``M``'s dtype when ``n == 0``, ``M`` itself when
    ``n == 1``, and otherwise a product of ``tm.dot`` calls.

    """
    if n < 0:
        M = pinv(M)
        n = abs(n)

    # Shortcuts when 0 <= n <= 3
    if n == 0:
        # Pass M's dtype explicitly rather than relying on ``at.eye``'s
        # default, so the result dtype does not depend on the exponent.
        return at.eye(M.shape[-2], dtype=M.dtype)
    elif n == 1:
        return M
    elif n == 2:
        return tm.dot(M, M)
    elif n == 3:
        return tm.dot(tm.dot(M, M), M)

    # ``z`` holds M, M^2, M^4, ... (one squaring per loop iteration);
    # ``result`` accumulates the powers selected by the set bits of ``n``.
    result = z = None

    while n > 0:
        z = M if z is None else tm.dot(z, z)
        n, bit = divmod(n, 2)
        if bit:
            result = z if result is None else tm.dot(result, z)

    return result
def test_jax_eye():
    """Build an ``aet.eye(3)`` graph with no inputs and pass it to ``compare_jax_and_py``."""
    identity = aet.eye(3)
    # The graph has no inputs, so the comparison helper gets an empty
    # list of test values.
    graph = FunctionGraph([], [identity])
    compare_jax_and_py(graph, [])