Example #1
0
def matrix_power(M, n):
    r"""Raise a square matrix, ``M``, to the (integer) power ``n``.

    This implementation uses exponentiation by squaring, which needs only
    :math:`\mathcal{O}(\log n)` matrix multiplications instead of the
    :math:`n - 1` multiplications of the naive implementation.

    Parameters
    ----------
    M: TensorVariable
        The square matrix to raise to a power. Presumably a 2-D matrix:
        the ``n == 0`` case builds a 2-D identity of size ``M.shape[-2]``.
    n: int
        The exponent. For negative values, ``pinv(M)`` is raised to the
        power ``abs(n)``.

    Returns
    -------
    TensorVariable
        ``M`` multiplied by itself ``n`` times; for ``n == 0`` an identity
        matrix with the same dtype as ``M``.

    Raises
    ------
    TypeError
        If ``n`` is not an integer.

    """
    import operator

    # Reject non-integral exponents up front: the squaring loop below would
    # otherwise silently accept e.g. 2.5 and return a wrong product.
    try:
        n = operator.index(n)
    except TypeError as err:
        raise TypeError("exponent must be an integer") from err

    if n < 0:
        # NOTE(review): ``pinv`` is presumably the pseudo-inverse, so a
        # singular ``M`` is not rejected here — confirm this is intended.
        M = pinv(M)
        n = -n

    # Shortcuts when 0 <= n <= 3
    if n == 0:
        # Pass M's dtype explicitly so M**0 agrees with M**1 instead of
        # taking whatever default dtype ``eye`` would otherwise use.
        return at.eye(M.shape[-2], dtype=M.dtype)

    elif n == 1:
        return M

    elif n == 2:
        return tm.dot(M, M)

    elif n == 3:
        return tm.dot(tm.dot(M, M), M)

    result = z = None

    # Walk the binary digits of n from least to most significant; on the
    # i-th iteration z is M**(2**i) and is folded into the result when the
    # corresponding bit of n is set.
    while n > 0:
        z = M if z is None else tm.dot(z, z)
        n, bit = divmod(n, 2)
        if bit:
            result = z if result is None else tm.dot(result, z)

    return result
Example #2
0
def test_jax_eye():
    """Check that the ``Eye`` operator gives matching results under JAX and Python."""
    # A 3x3 identity graph with no inputs, so no input values are supplied.
    eye_graph = FunctionGraph([], [aet.eye(3)])
    compare_jax_and_py(eye_graph, [])