def xi_func(S):
     """Tangent field at ``S`` built affinely from the offset ``S - S0``.

     The flattened U, V and P offsets are mapped through ``aa_xiU``,
     ``aa_xiV`` and ``cc_xi`` respectively; the P part is symmetrized with
     ``sym``; the components of ``inct_xi`` are added; and the resulting
     ambient vector is projected at ``S`` with ``man.proj``.
     """
     dU = (S.U - S0.U).reshape(-1)
     dV = (S.V - S0.V).reshape(-1)
     dP = (S.P - S0.P).reshape(-1)
     # symmetric P-component of the affine map
     p_part = sym((cc_xi @ dP).reshape(p, p)) + inct_xi.tP
     u_part = (aa_xiU @ dU + inct_xi.tU.reshape(-1)).reshape(m, p)
     v_part = (aa_xiV @ dV + inct_xi.tV.reshape(-1)).reshape(n, p)
     ambient = fr_ambient(u_part, v_part, p_part)
     return man.proj(S, ambient)
    def v_func_flat(S):
        """Affine ambient field evaluated at the point ``S`` (not projected).

        Each factor of ``S`` is flattened and mapped by ``aaU``/``aaV``/``cc``,
        then the offsets ``bbU``/``bbV``/``dd`` are added; the P part is
        symmetrized before ``dd`` is added.  ``bbU``/``bbV`` are added to the
        flattened vectors, so presumably they have length m*p / n*p — confirm.
        """
        u_flat = aaU @ S.U.reshape(-1) + bbU
        v_flat = aaV @ S.V.reshape(-1) + bbV
        p_sym = sym((cc @ S.P.reshape(-1)).reshape(p, p))
        return fr_ambient(u_flat.reshape(m, p),
                          v_flat.reshape(n, p),
                          p_sym + dd)
    def v_func(S):
        """Tangent field at ``S``: affine in ``S - S0``, projected with ``man.proj``.

        The flattened offsets of U, V and P from ``S0`` are mapped by
        ``aaU``/``aaV``/``cc``, shifted by ``intcU``/``intcV``/``p_intc``
        (the P part symmetrized first), and the ambient result is projected
        at ``S``.
        """
        u_vec = aaU @ (S.U - S0.U).reshape(-1) + intcU
        v_vec = aaV @ (S.V - S0.V).reshape(-1) + intcV
        p_mat = sym((cc @ (S.P - S0.P).reshape(-1)).reshape(p, p)) + p_intc
        ambient = fr_ambient(u_vec.reshape(m, p), v_vec.reshape(n, p), p_mat)
        return man.proj(S, ambient)
def N(man, X, B, C, D):
    """Assemble an ambient vector at ``X`` from ``D`` and complement coefficients.

    ``D`` is split into its skew part ``asym(D)`` and symmetric part
    ``sym(D)``.  The U/V components combine ``X.U``/``X.V`` times a mix of the
    skew part and the commutator ``Pinv @ sym(D) - sym(D) @ Pinv``, plus
    ``B``/``C`` expressed in an orthonormal basis of the orthogonal complement
    of ``X._U``/``X._V`` (``scipy.linalg.null_space`` of the transpose).  The
    P component is ``sym(D) / beta``.  Presumably ``D`` is p x p and ``B``/``C``
    are (m - p) x p / (n - p) x p — confirm against callers.

    Parameters
    ----------
    man : manifold providing ``alpha``, ``beta``, ``gamma`` (alpha/gamma are
        indexed at position 1).
    X : point with ``U``, ``V``, ``Pinv``, ``_U``, ``_V`` attributes.
    B, C : coefficient matrices for the complement directions of U and V.
    D : square matrix supplying the skew and symmetric parts.

    Returns
    -------
    fr_ambient
    """
    alpha, beta, gamma = man.alpha, man.beta, man.gamma
    P_inv = X.Pinv
    D_skew = asym(D)
    D_sym = sym(D)
    commutator = P_inv @ D_sym - D_sym @ P_inv
    scale = 1 / (alpha[1] + gamma[1])

    U_perp = null_space(X._U.T)
    V_perp = null_space(X._V.T)

    u_comp = X.U @ (-gamma[1] * D_skew + scale * commutator) + U_perp @ B
    v_comp = X.V @ (alpha[1] * D_skew + scale * commutator) + V_perp @ C
    p_comp = 1 / beta * D_sym
    return fr_ambient(u_comp, v_comp, p_comp)
def _factor_geodesic(Y, dY, ratio, t):
    """Closed-form curve for one orthonormal factor and its two t-derivatives.

    With ``A = Y.T @ dY``, the reduced QR ``dY - Y @ A = Q @ R``,
    ``c = 1 - 2 * ratio`` and ``M = [[2 * ratio * A, -R.T], [R, 0]]``::

        Y(t)   = [Y Q] expm(t M) [I; 0] expm(t c A)
        Y'(t)  = [Y Q] expm(t M) D1     expm(t c A),  D1 = M[:, :p] + [c A; 0]
        Y''(t) = [Y Q] expm(t M) D2     expm(t c A),  D2 = M D1 + D1 (c A)

    (``c A`` commutes with ``expm(t c A)``, which is why the derivatives keep
    the same right factor.)

    Returns
    -------
    tuple of ndarray
        ``(Y(t), Y'(t), Y''(t))``.
    """
    from scipy.linalg import expm
    p = Y.shape[1]
    A = Y.T @ dY
    Q, R = np.linalg.qr(dY - Y @ A)
    cA = (1 - 2 * ratio) * A
    # np.block yields plain ndarrays (np.bmat / np.matrix is discouraged).
    M = np.block([[2 * ratio * A, -R.T], [R, np.zeros((p, p))]])

    d1 = M[:, :p].copy()
    d1[:p, :] += cA
    d2 = M @ d1 + d1 @ cA

    left = np.block([Y, Q]) @ expm(t * M)
    right = expm(t * (1 - 2 * ratio) * A)
    return left[:, :p] @ right, left @ d1 @ right, left @ d2 @ right


def test_geodesics():
    """Numerically inspect Christoffel terms, the geodesic ODE and ``man.exp``.

    Everything is printed for visual inspection (nothing is asserted):

    * ``_vec`` of the difference between a Christoffel term assembled from
      ``proj_g_inv(christoffel_form) - D_proj`` and ``man.christoffel_gamma``;
    * ``base_inner_ambient`` of that term with a random ambient vector, next to
      ``rhess02_alt`` and ``rhess02`` evaluated with a zero Euclidean Hessian;
    * the residual ``X'' + Gamma(X', X')`` along a closed-form candidate
      geodesic through ``X`` with initial velocity ``eta``;
    * the factor-wise difference between that curve and ``man.exp(X, t*eta)``.
    """
    from scipy.linalg import expm
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m = 4
    n = 5
    p = 3
    man = RealFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)
    X = man.rand()

    alf = alpha[1] / alpha[0]
    gmm = gamma[1] / gamma[0]

    def calc_Christoffel_Gamma(man, X, xi, eta):
        dprj = man.D_proj(X, xi, eta)
        proj_christoffel = man.proj_g_inv(X, man.christoffel_form(X, xi, eta))
        return proj_christoffel - dprj

    eta = man.randvec(X)
    g1 = calc_Christoffel_Gamma(man, X, eta, eta)
    g2 = man.christoffel_gamma(X, eta, eta)
    print(man._vec(g1 - g2))

    egrad = man._rand_ambient()
    print(man.base_inner_ambient(g1, egrad))
    print(man.rhess02_alt(X, eta, eta, egrad, 0))
    print(man.rhess02(X, eta, eta, egrad, man.zerovec(X)))

    t = 1
    # U and V follow the same closed form with their own metric ratios.
    Ut, Udt, Uddt = _factor_geodesic(X.U, eta.tU, alf, t)
    Vt, Vdt, Vddt = _factor_geodesic(X.V, eta.tV, gmm, t)

    # P(t) = P^{1/2} expm(t P^{-1/2} eta_P P^{-1/2}) P^{1/2}.  Assumes
    # X.evec / X.evl are the eigenvectors / eigenvalues of X.P — confirm.
    # The exponent commutes with its own exponential, which gives the
    # compact derivative formulas below.
    sqrtP = X.evec @ np.diag(np.sqrt(X.evl)) @ X.evec.T
    isqrtP = X.evec @ np.diag(1 / np.sqrt(X.evl)) @ X.evec.T
    ePinn = expm(t * isqrtP @ eta.tP @ isqrtP)
    Pt = sqrtP @ ePinn @ sqrtP
    Pdt = eta.tP @ isqrtP @ ePinn @ sqrtP
    Pddt = eta.tP @ isqrtP @ ePinn @ isqrtP @ eta.tP

    Xt = fr_point(Ut, Vt, Pt)
    Xdt = fr_ambient(Udt, Vdt, Pdt)
    Xddt = fr_ambient(Uddt, Vddt, Pddt)
    # geodesic equation residual: X'' + Gamma(X', X') should vanish
    gcheck = Xddt + calc_Christoffel_Gamma(man, Xt, Xdt, Xdt)

    print(man._vec(gcheck))
    Xt1 = man.exp(X, t * eta)
    print((Xt1.U - Xt.U))
    print((Xt1.V - Xt.V))
    print((Xt1.P - Xt.P))
 def ehess(S, xi):
     """Directional derivative of the companion ``egrad`` along ``xi``.

     Differentiates each factor of ``egrad`` (which is bilinear in the factors
     of ``S``) by the product rule, using the components ``xi.tU``, ``xi.tV``
     and ``xi.tP``.
     """
     neg2A = -2 * A
     hess_U = neg2A @ (xi.tV @ S.P + S.V @ xi.tP)
     hess_V = neg2A.T @ (xi.tU @ S.P + S.U @ xi.tP)
     hess_P = 2 * (xi.tP - xi.tU.T @ A @ S.V - S.U.T @ A @ xi.tV)
     return fr_ambient(hess_U, hess_V, hess_P)
 def egrad(S):
     """Euclidean gradient with components (-2 A V P, -2 A^T U P, 2 (P - U^T A V)).

     NOTE(review): this matches the gradient of ||A - U P V^T||_F^2 when U
     and V have orthonormal columns and P is symmetric — presumably the test
     objective; confirm.
     """
     neg2A = -2 * A
     grad_U = neg2A @ S.V @ S.P
     grad_V = neg2A.T @ S.U @ S.P
     grad_P = 2 * (S.P - S.U.T @ A @ S.V)
     return fr_ambient(grad_U, grad_V, grad_P)
 def eta_field(Sin):
     """Vector field: linear maps of ``Sin - S``, projected at ``S``, plus ``eeta``.

     U and V offsets are sandwiched as ``mU1 @ dU @ m2`` / ``mV1 @ dV @ m2``;
     the flattened P offset is mapped by ``m_p`` and symmetrized.

     NOTE(review): the projection is taken at the enclosing ``S``, not at the
     argument ``Sin`` — confirm that is intended.
     """
     u_part = mU1 @ (Sin.U - S.U) @ m2
     v_part = mV1 @ (Sin.V - S.V) @ m2
     p_part = sym((m_p @ (Sin.P - S.P).reshape(-1)).reshape(p, p))
     projected = man.proj(S, fr_ambient(u_part, v_part, p_part))
     return projected + eeta
 def omg_func(S):
     """Affine ambient field: flattened factors of ``S`` mapped by aaU/aaV/cc plus ``icpt``.

     The P part is symmetrized before ``icpt.tP`` is added; the result is not
     projected.
     """
     u_flat = aaU @ S.U.reshape(-1) + icpt.tU.reshape(-1)
     v_flat = aaV @ S.V.reshape(-1) + icpt.tV.reshape(-1)
     p_mat = sym((cc @ S.P.reshape(-1)).reshape(p, p)) + icpt.tP
     return fr_ambient(u_flat.reshape(m, p), v_flat.reshape(n, p), p_mat)
def test_covariance_deriv():
    """Cross-check ``man.ehess2rhess`` against independent computations.

    An affine ambient field ``omg_func`` plays the Euclidean gradient; since
    it is affine, its directional derivative along ``xi`` (``ehess``) is the
    linear part applied to ``xi``.  The Riemannian Hessian ``val1`` from
    ``ehess2rhess`` is compared (each result printed via ``check_zero``) with:

    1. the projected numeric derivative of the Riemannian gradient plus the
       ``proj_g_inv`` of the Christoffel form;
    2. the projected numeric covariant derivative of the Riemannian gradient
       field (``calc_covar_numeric``);
    3. the closed-form expression built from ``D_proj``, ``D_g`` and
       ``christoffel_form``.
    """
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m = 4
    n = 5
    p = 3
    man = RealFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)

    S = man.rand()

    # coefficients of the affine ambient field omg_func
    aaU = randn(m * p, m * p)
    aaV = randn(n * p, n * p)
    cc = randn(p * p, p * p)
    icpt = man._rand_ambient()

    def omg_func(S):
        csp = sym((cc @ S.P.reshape(-1)).reshape(p, p))
        return fr_ambient(
            (aaU @ S.U.reshape(-1) + icpt.tU.reshape(-1)).reshape(m, p),
            (aaV @ S.V.reshape(-1) + icpt.tV.reshape(-1)).reshape(n, p),
            csp + icpt.tP)

    def rgrad_func(W):
        return man.proj_g_inv(W, omg_func(W))

    xi = man.randvec(S)
    egrad = omg_func(S)
    # omg_func is affine, so its derivative along xi drops the intercept.
    ecsp = sym((cc @ xi.tP.reshape(-1)).reshape(p, p))
    ehess = fr_ambient((aaU @ xi.tU.reshape(-1)).reshape(m, p),
                       (aaV @ xi.tV.reshape(-1)).reshape(n, p), ecsp)

    val1 = man.ehess2rhess(S, egrad, ehess, xi)

    # 1. numeric directional derivative of rgrad + Christoffel correction
    d_xi_rgrad = num_deriv_amb(man, S, xi, rgrad_func)
    rgrad = man.proj_g_inv(S, egrad)
    fourth = man.christoffel_form(S, xi, rgrad)
    val1a = man.proj(S, d_xi_rgrad) + man.proj_g_inv(S, fourth)
    print(check_zero(man._vec(val1 - val1a)))

    # 2. numeric covariant derivative of the Riemannian gradient field
    val2, _, _ = calc_covar_numeric(man, S, xi, rgrad_func)
    val2_p = man.proj(S, val2)
    print(check_zero(man._vec(val1) - man._vec(val2_p)))

    # 3. closed form: proj_g_inv(ehess + g(D_proj) - D_g + christoffel_form)
    g_inv_egrad = man.g_inv(S, egrad)
    valrangeA_ = (ehess
                  + man.g(S, man.D_proj(S, xi, g_inv_egrad))
                  - man.D_g(S, xi, g_inv_egrad)
                  + fourth)
    valrangeB = man.proj_g_inv(S, valrangeA_)
    print(check_zero(man._vec(val1) - man._vec(valrangeB)))
 def rand_vertical():
     """Random vector (U W, V W, P W - W P) at ``X`` for a random skew ``W``.

     NOTE(review): for skew W this is presumably the "vertical" direction of
     the (U, P, V) factorization, i.e. one that leaves U P V^T unchanged to
     first order — confirm.
     """
     skew = randn(man.p, man.p)
     skew = skew - skew.T
     return fr_ambient(X.U @ skew, X.V @ skew, X.P @ skew - skew @ X.P)