Beispiel #1
0
def svd_jvp_rule(primals, tangents, full_matrices, compute_uv):
    """Forward-mode (JVP) rule for the reduced singular value decomposition.

    Args:
        primals: one-element sequence holding the matrix ``A``.
        tangents: one-element sequence holding the tangent ``dA`` of ``A``.
        full_matrices: must be falsy; a truthy value raises.
        compute_uv: not read by this function. The decomposition is always
            bound with ``compute_uv=True``.

    Returns:
        A pair ``(core.pack((s, U, Vt)), core.pack((ds, dU, dVt)))`` with the
        primal outputs first and their tangents second, where ``dVt`` is the
        plain transpose of the computed ``dV``.

    Raises:
        NotImplementedError: if ``full_matrices`` is truthy.

    NOTE(review): the ``.T`` / ``np.diag`` usage looks written for a single
    2-D matrix; behaviour on batched input is unverified.
    """
    if full_matrices:
        # TODO: implement full matrices case, documented here:
        # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
        raise NotImplementedError(
            "Singular value decomposition JVP not implemented for full matrices"
        )

    (operand,) = primals
    (operand_dot,) = tangents
    sing_vals, left, right_h = svd_p.bind(
        operand, full_matrices=False, compute_uv=True
    )

    num_sing = sing_vals.shape[-1]
    left_h = np.conj(left).T
    right = np.conj(right_h).T
    s_row = sing_vals[..., None, :]
    s_col = s_row.T
    eye_k = np.eye(num_sing)

    # Input tangent expressed in the singular-vector bases: U^H dA V.
    core_dot = np.dot(np.dot(left_h, operand_dot), right)
    sing_vals_dot = np.real(np.diag(core_dot))

    # Adding the identity makes every diagonal denominator equal to 1, and
    # subtracting it afterwards leaves an exactly zero diagonal.
    gap_inv = 1 / (np.square(s_row) - np.square(s_col) + eye_k) - eye_k

    scaled_cols = s_row * core_dot
    scaled_rows = s_col * core_dot
    left_dot = np.dot(left, gap_inv * (scaled_cols + scaled_cols.T))
    right_dot = np.dot(right, gap_inv * (scaled_rows + scaled_rows.T))

    num_rows = operand.shape[-2]
    num_cols = operand.shape[-1]
    if num_rows > num_cols:
        # Extra term built from (I - U U^H) for tall inputs.
        complement = np.eye(num_rows) - np.dot(left, left_h)
        left_dot = left_dot + np.dot(
            complement, np.dot(operand_dot, right)
        ) / s_row
    if num_cols > num_rows:
        # Extra term built from (I - V V^H) for wide inputs.
        complement = np.eye(num_cols) - np.dot(right, right_h)
        right_dot = right_dot + np.dot(
            complement, np.dot(np.conj(operand_dot).T, left)
        ) / s_row

    primal_out = core.pack((sing_vals, left, right_h))
    tangent_out = core.pack((sing_vals_dot, left_dot, right_dot.T))
    return primal_out, tangent_out
Beispiel #2
0
def multi_dot(arrays, *, precision=None):
    """Multiply a chain of arrays, choosing the association order for n > 2.

    Args:
        arrays: sequence of at least two array-likes. For chains longer than
            two, a 1-D first operand is promoted to a row and a 1-D last
            operand to a column before the helpers run.
        precision: passed through to ``jnp.dot`` and the helper routines.

    Returns:
        For exactly two operands, ``jnp.dot`` of them. Otherwise the chained
        product: element ``[0, 0]`` when both end operands were 1-D, the
        raveled product when exactly one was, and the 2-D product as-is
        otherwise.

    Raises:
        ValueError: if fewer than two arrays are supplied.
    """
    count = len(arrays)
    if count < 2:
        raise ValueError("Expecting at least two arrays.")
    if count == 2:
        # Nothing to reorder with a single product.
        return jnp.dot(arrays[0], arrays[1], precision=precision)

    operands = list(map(jnp.asarray, arrays))

    # Remember which ends were vectors so the result can be reshaped back.
    first_is_vector = operands[0].ndim == 1
    last_is_vector = operands[-1].ndim == 1

    # Promote vectors to 2-D so the helpers only ever deal with matrices.
    if first_is_vector:
        operands[0] = jnp.atleast_2d(operands[0])
    if last_is_vector:
        operands[-1] = jnp.atleast_2d(operands[-1]).T
    _assert2d(*operands)

    if count == 3:
        # Dedicated path that skips building the general ordering table.
        product = _multi_dot_three(*operands, precision)
    else:
        split_table = _multi_dot_matrix_chain_order(operands)
        product = _multi_dot(operands, split_table, 0, count - 1, precision)

    if first_is_vector and last_is_vector:
        return product[0, 0]  # scalar
    if first_is_vector or last_is_vector:
        return product.ravel()  # 1-D
    return product
Beispiel #3
0
def _multi_dot(arrays, order, i, j, precision):
    """Actually do the multiplication with the given order."""
    if i == j:
        return arrays[i]
    else:
        return jnp.dot(_multi_dot(arrays, order, i, order[i, j], precision),
                      _multi_dot(arrays, order, order[i, j] + 1, j, precision),
                      precision=precision)
Beispiel #4
0
def _multi_dot_three(A, B, C, precision):
    """
    Find the best order for three arrays and do the multiplication.
    For three arguments `_multi_dot_three` is approximately 15 times faster
    than `_multi_dot_matrix_chain_order`
    """
    a0, a1b0 = A.shape
    b1c0, c1 = C.shape
    # cost1 = cost((AB)C) = a0*a1b0*b1c0 + a0*b1c0*c1
    cost1 = a0 * b1c0 * (a1b0 + c1)
    # cost2 = cost(A(BC)) = a1b0*b1c0*c1 + a0*a1b0*c1
    cost2 = a1b0 * c1 * (a0 + b1c0)

    if cost1 < cost2:
        return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
    else:
        return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)