def rand_vertical():
    """Return a random ambient vector of the form (U W, V W, P W - W P) at X.

    W is a random p x p matrix made anti-Hermitian as G - G^H.  Relies on
    ``man``, ``X``, ``crandn`` and ``fr_ambient`` from the enclosing scope.
    (The name suggests a "vertical" direction of a quotient structure —
    not verified here.)
    """
    g = crandn(man.p, man.p)
    skew = g - g.T.conj()  # anti-Hermitian by construction
    u_part = X.U @ skew
    v_part = X.V @ skew
    p_part = -skew @ X.P + X.P @ skew
    return fr_ambient(u_part, v_part, p_part)
 def xi_func(S):
     """Affine ambient field in (S - S0) plus intercept ``inct_xi``, projected at S.

     Uses closure variables ``aa_xiU``, ``aa_xiV``, ``cc_xi``, ``inct_xi``,
     ``S0``, ``m``, ``n``, ``p`` and ``man``.
     """
     flat_dU = (S.U - S0.U).reshape(-1)
     flat_dV = (S.V - S0.V).reshape(-1)
     flat_dP = (S.P - S0.P).reshape(-1)
     amb_U = (aa_xiU @ flat_dU + inct_xi.tU.reshape(-1)).reshape(m, p)
     amb_V = (aa_xiV @ flat_dV + inct_xi.tV.reshape(-1)).reshape(n, p)
     # P block is symmetrized with hsym before the intercept is added
     amb_P = hsym((cc_xi @ flat_dP).reshape(p, p)) + inct_xi.tP
     return man.proj(S, fr_ambient(amb_U, amb_V, amb_P))
 def v_func(S):
     """Map a manifold point S to ``man.proj(S, ...)`` of an affine ambient field.

     The field is linear in (S - S0) through ``aaU``, ``aaV``, ``cc`` with
     intercepts ``intcU``, ``intcV``, ``p_intc`` (all closure variables).
     """
     u_flat = aaU @ (S.U - S0.U).reshape(-1) + intcU
     v_flat = aaV @ (S.V - S0.V).reshape(-1) + intcV
     p_block = hsym((cc @ (S.P - S0.P).reshape(-1)).reshape(p, p)) + p_intc
     ambient = fr_ambient(u_flat.reshape(m, p), v_flat.reshape(n, p), p_block)
     return man.proj(S, ambient)
 def v_func_flat(S):
     """Map a manifold point S to an (unprojected) affine ambient vector.

     Linear part ``aaU``/``aaV``/``cc`` applied to S itself, with offsets
     ``bbU``/``bbV``/``dd`` (closure variables).
     """
     u_flat = aaU @ S.U.reshape(-1) + bbU
     v_flat = aaV @ S.V.reshape(-1) + bbV
     p_block = hsym((cc @ S.P.reshape(-1)).reshape(p, p)) + dd
     return fr_ambient(u_flat.reshape(m, p), v_flat.reshape(n, p), p_block)
def N(man, X, B, C, D):
    """Assemble an ambient vector at ``X`` from the blocks ``B``, ``C``, ``D``.

    Parameters
    ----------
    man : manifold object; its ``alpha``, ``beta``, ``gamma`` are read.
    X : point; its ``U``, ``V``, ``Pinv``, ``_U`` and ``_V`` are read.
    B, C : coefficients applied to orthonormal bases of the null spaces of
        ``X._U^H`` and ``X._V^H`` respectively.
    D : square matrix split with ``hsym`` / ``ahsym`` (presumably its
        Hermitian and anti-Hermitian parts).

    Returns
    -------
    fr_ambient with components
        U (-gamma[1] D_a + c (Pinv D_h - D_h Pinv)) + U_perp B,
        V ( alpha[1] D_a + c (Pinv D_h - D_h Pinv)) + V_perp C,
        D_h / beta,
    where c = 1 / (alpha[1] + gamma[1]), D_h = hsym(D), D_a = ahsym(D).
    """
    alpha, beta, gamma = man.alpha, man.beta, man.gamma
    P_inv = X.Pinv
    D_skew = ahsym(D)
    D_herm = hsym(D)
    commutator = P_inv @ D_herm - D_herm @ P_inv
    # orthonormal bases of ker(X._U^H) and ker(X._V^H)
    U_perp = null_space(X._U.T.conj())
    V_perp = null_space(X._V.T.conj())
    shared = 1/(alpha[1] + gamma[1])*commutator
    # NOTE(review): the U block is weighted by gamma[1] and the V block by
    # alpha[1]; confirm this cross pairing is intended.
    u_part = X.U @ (-gamma[1]*D_skew + shared) + U_perp @ B
    v_part = X.V @ (alpha[1]*D_skew + shared) + V_perp @ C
    return fr_ambient(u_part, v_part, 1/beta*D_herm)
def test_geodesics():
    """Numerically check the closed-form geodesic of ``ComplexFixedRank``.

    Prints (does not assert) the following diagnostics:

    * ``calc_Christoffel_Gamma`` (built from ``D_proj`` and
      ``christoffel_form``) minus ``man.christoffel_gamma``;
    * ``base_inner_ambient(Gamma(eta, eta), egrad)`` next to ``rhess02_alt``
      and ``rhess02`` (presumably expected to agree);
    * the geodesic-equation residual ``X'' + Gamma(X', X')`` on the explicit
      curve ``Xt`` at time ``t``;
    * the difference between ``man.exp(X, t*eta)`` and the explicit curve.
    """
    from scipy.linalg import expm
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m = 5
    n = 6
    p = 3
    man = ComplexFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)
    X = man.rand()

    alf = alpha[1]/alpha[0]
    gmm = gamma[1]/gamma[0]

    def calc_Christoffel_Gamma(man, X, xi, eta):
        # Gamma(xi, eta) = proj_g_inv(christoffel_form(xi, eta)) - D_proj(xi, eta)
        dprj = man.D_proj(X, xi, eta)
        proj_christoffel = man.proj_g_inv(
            X, man.christoffel_form(X, xi, eta))
        return proj_christoffel - dprj

    def factor_curve(Y, eta_Y, ratio, t):
        """Return (Y(t), Y'(t), Y''(t)) for the closed-form factor curve.

        Y(t) = [Y, Q] expm(t M)[:, :k] expm(t (1 - 2 ratio) A), where
        A = Y^H eta_Y, (Q, R) is the reduced QR of eta_Y - Y A and
        M = [[2 ratio A, -R^H], [R, 0]] (k = number of columns of Y).
        """
        k = Y.shape[1]
        A = Y.T.conj() @ eta_Y
        Q, R = np.linalg.qr(eta_Y - Y @ A)
        # np.block instead of np.bmat: avoids the discouraged np.matrix class
        M = np.block([[2*ratio*A, -R.T.conj()], [R, np.zeros((k, k))]])
        basis = np.block([Y, Q])
        eM = expm(t*M)
        eA = expm(t*(1 - 2*ratio)*A)
        # Y'(t) = basis expm(tM) (M[:, :k] + [cA; 0]) eA with c = 1 - 2 ratio
        M_d = M[:, :k].copy()
        M_d[:k, :] += (1 - 2*ratio) * A
        M_dd = M @ M_d + M_d @ ((1 - 2*ratio)*A)
        return (basis @ eM[:, :k] @ eA,
                basis @ eM @ M_d @ eA,
                basis @ eM @ M_dd @ eA)

    eta = man.randvec(X)
    g1 = calc_Christoffel_Gamma(man, X, eta, eta)
    g2 = man.christoffel_gamma(X, eta, eta)
    print(man._vec(g1-g2))

    egrad = man._rand_ambient()
    print(man.base_inner_ambient(g1, egrad))
    print(man.rhess02_alt(X, eta, eta, egrad, 0))
    print(man.rhess02(X, eta, eta, egrad, man.zerovec(X)))

    t = 2
    Ut, Udt, Uddt = factor_curve(X.U, eta.tU, alf, t)
    Vt, Vdt, Vddt = factor_curve(X.V, eta.tV, gmm, t)

    # P(t) = P^{1/2} expm(t P^{-1/2} eta_P P^{-1/2}) P^{1/2}.
    # Assumes X.evec / X.evl are the eigendecomposition of X.P (TODO confirm).
    sqrtP = X.evec @ np.diag(np.sqrt(X.evl)) @ X.evec.T.conj()
    isqrtP = X.evec @ np.diag(1/np.sqrt(X.evl)) @ X.evec.T.conj()
    Pinn = t*isqrtP @ eta.tP @ isqrtP
    ePinn = expm(Pinn)
    Pt = sqrtP @ ePinn @ sqrtP
    Pdt = eta.tP @ isqrtP @ ePinn @ sqrtP
    # Pinn commutes with expm(Pinn), hence
    # P'' = eta_P P^{-1/2} expm(Pinn) P^{-1/2} eta_P
    Pddt = eta.tP @ isqrtP @ ePinn @ isqrtP @ eta.tP

    Xt = fr_point(np.array(Ut),
                  np.array(Vt),
                  np.array(Pt))
    Xdt = fr_ambient(np.array(Udt),
                     np.array(Vdt),
                     np.array(Pdt))
    Xddt = fr_ambient(np.array(Uddt),
                      np.array(Vddt),
                      np.array(Pddt))
    # geodesic-equation residual; expected ~0 if Xt is a geodesic
    gcheck = Xddt + calc_Christoffel_Gamma(man, Xt, Xdt, Xdt)

    print(man._vec(gcheck))
    Xt1 = man.exp(X, t*eta)
    print((Xt1.U - Xt.U))
    print((Xt1.V - Xt.V))
    print((Xt1.P - Xt.P))
 def ehess(S, xi):
     """Euclidean Hessian of the objective at ``S`` applied to ``xi``.

     Directional derivative along ``xi`` of the gradient components
     (-2 A V P, -2 A^H U P, 2 (P - U^H A V)); ``A`` is a closure variable.
     The original source had these terms corrupted by e-mail obfuscation;
     they are restored here from that derivative.
     """
     return fr_ambient(-2*A @ (xi.tV @ S.P + S.V @ xi.tP),
                       -2*A.T.conj() @ (xi.tU @ S.P + S.U @ xi.tP),
                       2*(xi.tP - xi.tU.T.conj() @ A @ S.V -
                          S.U.T.conj() @ A @ xi.tV))
 def egrad(S):
     """Euclidean gradient at ``S`` as an ambient vector.

     Components: (-2 A V P, -2 A^H U P, 2 (P - U^H A V)), with ``A`` taken
     from the enclosing scope.
     """
     grad_U = -2*A @ S.V @ S.P
     grad_V = -2*A.T.conj() @ S.U @ S.P
     grad_P = 2*(S.P - S.U.T.conj() @ A @ S.V)
     return fr_ambient(grad_U, grad_V, grad_P)
 def eta_field(Sin):
     """Vector field equal to ``eeta`` at the outer point ``S``.

     Adds to ``eeta`` the projection at ``S`` of an ambient vector linear in
     (Sin - S).  NOTE(review): the projection uses the enclosing ``S``, not
     ``Sin``; confirm that is intended.
     """
     delta_U = mU1 @ (Sin.U - S.U) @ m2
     delta_V = mV1 @ (Sin.V - S.V) @ m2
     delta_P = hsym((m_p @ (Sin.P - S.P).reshape(-1)).reshape(p, p))
     projected = man.proj(S, fr_ambient(delta_U, delta_V, delta_P))
     return projected + eeta
 def omg_func(S):
     """Affine ambient field: linear maps ``aaU``/``aaV``/``cc`` plus ``icpt``."""
     u_flat = aaU @ S.U.reshape(-1) + icpt.tU.reshape(-1)
     v_flat = aaV @ S.V.reshape(-1) + icpt.tV.reshape(-1)
     p_block = hsym((cc @ S.P.reshape(-1)).reshape(p, p)) + icpt.tP
     return fr_ambient(u_flat.reshape(m, p), v_flat.reshape(n, p), p_block)
def test_covariance_deriv():
    """Check ``man.ehess2rhess`` against independent covariant derivatives.

    An affine ambient field ``omg_func`` is used as the Euclidean gradient,
    so its Euclidean Hessian along ``xi`` is just the linear part applied to
    ``xi``.  The Riemannian Hessian from ``man.ehess2rhess`` is compared with:

    1. the projected numerical derivative of the Riemannian gradient plus
       the projected Christoffel term;
    2. ``calc_covar_numeric`` applied to the Riemannian gradient field;
    3. an explicit formula built from ``D_proj``, ``D_g`` and
       ``christoffel_form``.

    Each comparison is printed via ``check_zero`` (not asserted).
    """
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m = 4
    n = 5
    p = 3
    man = ComplexFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)

    S = man.rand()

    aaU = crandn(m*p, m*p)
    aaV = crandn(n*p, n*p)
    cc = crandn(p*p, p*p)
    icpt = man._rand_ambient()

    def omg_func(S):
        # affine ambient field: linear maps aaU/aaV/cc plus intercept icpt
        csp = hsym((cc @ S.P.reshape(-1)).reshape(p, p))
        return fr_ambient(
            (aaU @ S.U.reshape(-1) + icpt.tU.reshape(-1)).reshape(m, p),
            (aaV @ S.V.reshape(-1) + icpt.tV.reshape(-1)).reshape(n, p),
            csp + icpt.tP)

    def rgrad_func(W):
        return man.proj_g_inv(W, omg_func(W))

    xi = man.randvec(S)
    egrad = omg_func(S)
    # omg_func is affine, so its directional derivative along xi is the
    # linear part applied to xi (the intercept drops out).
    ecsp = hsym((cc @ xi.tP.reshape(-1)).reshape(p, p))
    ehess = fr_ambient(
        (aaU @ xi.tU.reshape(-1)).reshape(m, p),
        (aaV @ xi.tV.reshape(-1)).reshape(n, p),
        ecsp)

    val1 = man.ehess2rhess(S, egrad, ehess, xi)
    rgrad = man.proj_g_inv(S, egrad)

    # 1) projected numerical derivative of rgrad + projected Christoffel term
    d_xi_rgrad = num_deriv_amb(man, S, xi, rgrad_func)
    christoffel_term = man.christoffel_form(S, xi, rgrad)
    val1a = man.proj(S, d_xi_rgrad) + man.proj_g_inv(S, christoffel_term)
    print(check_zero(man._vec(val1-val1a)))

    # 2) numerical covariant derivative (first return value is nabla_xi)
    val2, _, _ = calc_covar_numeric(man, S, xi, rgrad_func)
    val2_p = man.proj(S, val2)
    print(check_zero(man._vec(val1)-man._vec(val2_p)))

    # 3) explicit formula
    g_inv_egrad = man.g_inv(S, egrad)
    valrangeA_ = ehess + man.g(S, man.D_proj(S, xi, g_inv_egrad)) \
        - man.D_g(S, xi, g_inv_egrad) \
        + man.christoffel_form(S, xi, rgrad)
    valrangeB = man.proj_g_inv(S, valrangeA_)
    print(check_zero(man._vec(val1)-man._vec(valrangeB)))