def test_christ_flat():
    """Check metric compatibility of the Christoffel form on a flat field.

    An ambient-valued field ``v`` is built as an affine function of the
    point coordinates (no projection, so ``v`` is not horizontal).  The
    script prints a forward-difference estimate of ``D_xi <v, v>`` and
    ``2 <v, nabla_xi v>`` with ``nabla_xi v = dv + g_inv(christoffel_form)``
    so the two can be compared.  It also prints ``check_zero`` of the
    difference between the ``ComplexFixedRank`` Christoffel form and the
    parent-class one; these are not expected to agree for a
    non-horizontal ``v``.
    """
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m, n, p = 4, 5, 3
    man = ComplexFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)

    S = man.rand()
    xi = man.randvec(S)

    # Random affine coefficients.  The draw order is kept fixed since each
    # call advances the shared random state.
    lin_U, off_U = crandn(m*p, m*p), crandn(m*p)
    lin_V, off_V = crandn(n*p, n*p), crandn(n*p)
    lin_P = crandn(p*p, p*p)
    off_P = hsym(crandn(p, p))

    def affine_field(W):
        """Map a point W to an ambient vector affine in its coordinates."""
        sym_P = hsym((lin_P @ W.P.reshape(-1)).reshape(p, p))
        tU = (lin_U @ W.U.reshape(-1) + off_U).reshape(m, p)
        tV = (lin_V @ W.V.reshape(-1) + off_V).reshape(n, p)
        return fr_ambient(tU, tV, sym_P + off_P)

    v_here = affine_field(S)  # not horizontal
    step = 1e-7
    S_step = fr_point(S.U + step * xi.tU,
                      S.V + step * xi.tV,
                      S.P + step * xi.tP)
    v_step = affine_field(S_step)

    norm_here = man.inner(S, v_here)
    norm_step = man.inner(S_step, v_step)
    lhs = (norm_step - norm_here)/step

    dv = (v_step - v_here).scalar_mul(1/step)
    nabla = dv + man.g_inv(S, man.christoffel_form(S, xi, v_here))
    # differs from `nabla` because v_here is not horizontal:
    nabla_parent = dv + man.g_inv(
        S, super(ComplexFixedRank, man).christoffel_form(S, xi, v_here))
    print(check_zero(man._vec(nabla) - man._vec(nabla_parent)))
    rhs = man.inner(S, v_here, nabla)

    print(lhs)
    print(2*rhs)
def calc_covar_numeric(man, S, xi, v_func, dlt=1e-7):
    """Numerically estimate the covariant derivative of a field along ``xi``.

    The differentiated field is ``W -> man.proj(W, v_func(W))``.  Its
    forward-difference ambient derivative ``dv`` is corrected by
    ``man.g_inv(S, man.christoffel_form(S, xi, v))`` and the sum is
    projected back with ``man.proj``.  ``v_func`` is used as given; no
    metric index is lowered.  So to differentiate ``Pi g_inv df``, pass a
    function that returns ``g_inv df``.

    Parameters
    ----------
    man : manifold object providing ``proj``, ``christoffel_form`` and
        ``g_inv``.
    S : base point (``fr_point``).
    xi : direction of differentiation (its ``tU``, ``tV``, ``tP`` are read).
    v_func : callable mapping a point to an ambient vector.
    dlt : forward-difference step (default 1e-7, the previously
        hard-coded value).

    Returns
    -------
    tuple
        ``(nabla_xi_v, dv, cx)``: the projected covariant-derivative
        estimate, the raw finite-difference derivative of the projected
        field, and ``christoffel_form(S, xi, v)``.
    """
    def vv_func(W):
        return man.proj(W, v_func(W))

    vv = vv_func(S)
    Snew = fr_point(S.U + dlt*xi.tU,
                    S.V + dlt*xi.tV,
                    S.P + dlt*xi.tP)
    vnew = vv_func(Snew)

    dv = (vnew - vv).scalar_mul(1/dlt)
    cx = man.christoffel_form(S, xi, vv)
    # nabla = dv + g_inv(Gamma), the convention used throughout this file;
    # project the result back onto the tangent/horizontal space at S.
    nabla_xi_v = man.proj(S, dv + man.g_inv(S, cx))
    return nabla_xi_v, dv, cx
def _stiefel_geodesic_part(Y, eta_Y, ratio, t):
    """Return ``(Y(t), Y'(t), Y''(t))`` of the closed-form U/V curve.

    With ``A = Y^H eta_Y``, the thin QR ``eta_Y - Y A = Q R`` and
    ``M = [[2*ratio*A, -R^H], [R, 0]]`` the curve is::

        Y(t) = [Y, Q] expm(t M)[:, :p] expm(t (1 - 2 ratio) A)

    where ``p = Y.shape[1]``.  The derivatives follow from
    ``d/dt expm(tM) E expm(tcA) = expm(tM) (M E + E cA) expm(tcA)`` with
    ``E = [I; 0]``.
    """
    from scipy.linalg import expm

    p = Y.shape[1]
    A = Y.T.conj() @ eta_Y
    Q, R = np.linalg.qr(eta_Y - Y @ A)
    c = 1 - 2*ratio

    M = np.block([[2*ratio*A, -R.T.conj()], [R, np.zeros_like(A)]])
    basis = np.block([Y, Q])
    eM = expm(t*M)
    tail = expm(t*c*A)

    # first-derivative middle factor: M E + E cA
    D1 = M[:, :p].copy()
    D1[:p, :] += c*A
    # second-derivative middle factor: M D1 + D1 cA
    D2 = M @ D1 + D1 @ (c*A)

    return (basis @ eM[:, :p] @ tail,
            basis @ eM @ D1 @ tail,
            basis @ eM @ D2 @ tail)


def test_geodesics():
    """Compare a closed-form geodesic with christoffel_gamma and man.exp.

    Prints, for a random point X and tangent vector eta:

    * ``_vec`` of the difference between a Christoffel Gamma assembled
      from ``D_proj`` / ``proj_g_inv(christoffel_form)`` and
      ``man.christoffel_gamma``;
    * ``<Gamma(eta, eta), egrad>``, ``rhess02_alt`` and ``rhess02`` for
      the same arguments, so the three can be compared;
    * ``_vec`` of ``X''(t) + Gamma(X'(t), X'(t))`` along the closed-form
      curve;
    * the U, V, P differences between the closed-form point and
      ``man.exp(X, t*eta)``.
    """
    from scipy.linalg import expm
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m = 5
    n = 6
    p = 3
    man = ComplexFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)
    X = man.rand()

    alf = alpha[1]/alpha[0]
    gmm = gamma[1]/gamma[0]

    def calc_Christoffel_Gamma(man, X, xi, eta):
        """Gamma(xi, eta) = proj_g_inv(christoffel_form) - D_proj."""
        dprj = man.D_proj(X, xi, eta)
        proj_christoffel = man.proj_g_inv(
            X, man.christoffel_form(X, xi, eta))
        return proj_christoffel - dprj

    eta = man.randvec(X)
    g1 = calc_Christoffel_Gamma(man, X, eta, eta)
    g2 = man.christoffel_gamma(X, eta, eta)
    print(man._vec(g1-g2))

    egrad = man._rand_ambient()
    print(man.base_inner_ambient(g1, egrad))
    print(man.rhess02_alt(X, eta, eta, egrad, 0))
    print(man.rhess02(X, eta, eta, egrad, man.zerovec(X)))

    t = 2
    Ut, Udt, Uddt = _stiefel_geodesic_part(X.U, eta.tU, alf, t)
    Vt, Vdt, Vddt = _stiefel_geodesic_part(X.V, eta.tV, gmm, t)

    # P component: P(t) = P^{1/2} expm(t B) P^{1/2}, B = P^{-1/2} eta_P P^{-1/2}
    #   P'(t)  = eta_P P^{-1/2} expm(t B) P^{1/2}
    #   P''(t) = eta_P P^{-1/2} expm(t B) P^{-1/2} eta_P
    sqrtP = X.evec @ np.diag(np.sqrt(X.evl)) @ X.evec.T.conj()
    isqrtP = X.evec @ np.diag(1/np.sqrt(X.evl)) @ X.evec.T.conj()
    ePinn = expm(t*isqrtP @ eta.tP @ isqrtP)
    Pt = sqrtP @ ePinn @ sqrtP
    Pdt = eta.tP @ isqrtP @ ePinn @ sqrtP
    Pddt = eta.tP @ isqrtP @ ePinn @ isqrtP @ eta.tP

    Xt = fr_point(np.array(Ut),
                  np.array(Vt),
                  np.array(Pt))
    Xdt = fr_ambient(np.array(Udt),
                     np.array(Vdt),
                     np.array(Pdt))
    Xddt = fr_ambient(np.array(Uddt),
                      np.array(Vddt),
                      np.array(Pddt))
    # geodesic equation residual: X'' + Gamma(X', X')
    gcheck = Xddt + calc_Christoffel_Gamma(man, Xt, Xdt, Xdt)

    print(man._vec(gcheck))
    Xt1 = man.exp(X, t*eta)
    print((Xt1.U - Xt.U))
    print((Xt1.V - Xt.V))
    print((Xt1.P - Xt.P))
def test_rhess_02():
    """Numerically check the Euclidean and Riemannian Hessian pipeline.

    Uses ``f(S) = ||A - U P V^H||_F^2`` for a random matrix ``A`` and
    prints several finite-difference comparisons:

    * ``df`` against a forward difference of ``f``;
    * ``ehess_vec`` against a forward difference of ``df``;
    * the identity ``ehess(xi, eta) = xi(<df, eta>) - <D_xi eta, df>``
      for ``eta`` extended to a vector field;
    * ``ehess2rhess`` against ``calc_covar_numeric`` applied to the
      Riemannian gradient ``proj_g_inv(df)``, plus the symmetry of the
      resulting Hessian form in ``xi`` and ``eta``.
    """
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m = 4
    n = 5
    p = 3
    man = ComplexFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)

    S = man.rand()
    # simple function: squared Frobenius distance to a given matrix,
    # || A - U P V^H ||_F^2 (basically a truncated-SVD objective)
    A = crandn(m, n)

    def f(S):
        diff = (A - S.U @ S.P @ S.V.T.conj())
        return rtrace(diff @ diff.T.conj())

    def df(S):
        # (U, V, P) components of the ambient vector used as df here
        return fr_ambient(-2*A @ S.V @ S.P,
                          -2*A.T.conj() @ S.U @ S.P,
                          2*(S.P - S.U.T.conj() @ A @ S.V))

    def ehess_vec(S, xi):
        # directional derivative of df along xi (product rule per term)
        return fr_ambient(-2*A @ (xi.tV @ S.P + S.V @ xi.tP),
                          -2*A.T.conj() @ (xi.tU @ S.P + S.U @ xi.tP),
                          2*(xi.tP - xi.tU.T.conj() @ A @ S.V -
                             S.U.T.conj() @ A @ xi.tV))

    def ehess_form(S, xi, eta):
        ev = ehess_vec(S, xi)
        return rtrace(ev.tU.T.conj() @ eta.tU) +\
            rtrace(ev.tV.T.conj() @ eta.tV) +\
            rtrace(ev.tP.T.conj() @ eta.tP)

    xxi = man.randvec(S)
    dlt = 1e-8
    Snew = fr_point(
        S.U + dlt*xxi.tU,
        S.V + dlt*xxi.tV,
        S.P + dlt*xxi.tP)
    d1 = (f(Snew) - f(S))/dlt
    d2 = df(S)
    print(d1 - man.base_inner_ambient(d2, xxi))

    dv1 = (df(Snew) - df(S)).scalar_mul(1/dlt)
    dv2 = ehess_vec(S, xxi)
    print(man._vec(dv1-dv2))

    eeta = man.randvec(S)
    d1 = man.base_inner_ambient((df(Snew) - df(S)), eeta) / dlt
    ehess_val = ehess_form(S, xxi, eeta)

    print(man.base_inner_ambient(dv2, eeta))
    print(d1, ehess_val, d1-ehess_val)

    # now check the formula: ehess = xi (eta_func(f)) - <D_xi eta, df(Y)>
    # promote eta to a vector field.
    mU1 = crandn(m, m)
    mV1 = crandn(n, n)
    m2 = crandn(p, p)
    m_p = crandn(p*p, p*p)

    def eta_field(Sin):
        # equals eeta at S; its variation is affine in (Sin - S)
        return man.proj(S, fr_ambient(
            mU1 @ (Sin.U - S.U) @ m2,
            mV1 @ (Sin.V - S.V) @ m2,
            hsym((m_p @ (Sin.P - S.P).reshape(-1)).reshape(p, p)))) + eeta

    # xietaf: should go to ehess(xi, eta) + df(Y) @ etafield
    xietaf = (man.base_inner_ambient(df(Snew), eta_field(Snew)) -
              man.base_inner_ambient(df(S), eta_field(S))) / dlt
    # apply eta_func to f: should go to tr(m1 @ xxi @ m2 @ df(Y).T.conj())
    Dxietaf = man.base_inner_ambient(
        (eta_field(Snew) - eta_field(S)), df(S))/dlt
    # this is ehess; should match d1 and ehess_val
    print(xietaf-Dxietaf)
    print(xietaf-Dxietaf-ehess_val)

    # now check rhess; xi and eta must be in the tangent space.
    # first compare with numerical differentiation
    xi1 = man.proj(S, xxi)
    eta1 = man.proj(S, eeta)
    egvec = df(S)
    ehvec = ehess_vec(S, xi1)
    rhessvec = man.ehess2rhess(S, egvec, ehvec, xi1)

    # check it numerically:
    def rgrad_func(Y):
        return man.proj_g_inv(Y, df(Y))

    val2, _, _ = calc_covar_numeric(man, S, xi1, rgrad_func)
    val2_p = man.proj(S, val2)
    print(man._vec(rhessvec-val2_p))
    rhessval = man.inner(S, rhessvec, eta1)
    print(man.inner(S, val2, eta1))
    print(rhessval)

    # check symmetry:
    ehvec_e = ehess_vec(S, eta1)
    rhessvec_e = man.ehess2rhess(S, egvec, ehvec_e, eta1)
    rhessval_e = man.inner(S, rhessvec_e, xi1)
    print(rhessval_e)
    # the above computed inner_prod(Nabla_xi Pi * df, eta)
    # in the following check. Extend eta1 to eta_proj
    # (Pi Nabla_hat Pi g_inv df, g eta)
    # = D_xi (Pi g_inv df, g eta) - (Pi g_inv df g Pi Nabla_hat eta)

    def eta_proj(S):
        return man.proj(S, eta_field(S))
    print(check_zero(man._vec(eta1-eta_proj(S))))

    e1 = man.inner(S, man.proj_g_inv(S, df(S)), eta_proj(S))
    e1a = man.base_inner_ambient(df(S), eta_proj(S))
    print(e1, e1a, e1-e1a)
    Snew = fr_point(
        S.U + dlt*xi1.tU,
        S.V + dlt*xi1.tV,
        S.P + dlt*xi1.tP)
    e2 = man.inner(Snew, man.proj_g_inv(Snew, df(Snew)), eta_proj(Snew))
    e2a = man.base_inner_ambient(df(Snew), eta_proj(Snew))
    print(e2, e2a, e2-e2a)

    first = (e2 - e1)/dlt
    first1 = (man.base_inner_ambient(df(Snew), eta_proj(Snew)) -
              man.base_inner_ambient(df(S), eta_proj(S)))/dlt
    print(first-first1)

    val3, _, _ = calc_covar_numeric(man, S, xi1, eta_proj)
    second = man.inner(S, man.proj_g_inv(S, df(S)), man.proj(S, val3))
    second2 = man.inner(S, man.proj_g_inv(S, df(S)), val3)
    print(second, second2, second-second2)
    print('same as rhess_val %f' % (first-second))
def num_deriv_amb(man, S, xi, func, dlt=1e-7):
    """Forward-difference derivative of ``func`` at ``S`` along ``xi``.

    The point is moved componentwise (``U``, ``V``, ``P``) by ``dlt``
    times the matching component of ``xi``; the difference
    ``func(moved) - func(S)`` is then scaled by ``1/dlt`` through its
    ``scalar_mul`` method.  ``man`` is accepted for signature
    compatibility and is not used.
    """
    moved = fr_point(*[base + dlt*delta
                       for base, delta in ((S.U, xi.tU),
                                           (S.V, xi.tV),
                                           (S.P, xi.tP))])
    forward = func(moved)
    change = forward - func(S)
    return change.scalar_mul(1/dlt)
def test_chris_vectorfields():
    """Check the numerical connection against a finite-difference Lie bracket.

    Two fields ``v`` and ``xi`` are built by projecting (``man.proj``)
    ambient vectors that are affine in the point coordinates.  The script
    prints the components of ``nabla_xi v - nabla_v xi`` (both from
    ``calc_covar_numeric``).  It then prints ``check_zero`` of that
    difference minus ``proj(D_xi v - D_v xi)``, the projected
    forward-difference Lie bracket.
    """
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m, n, p = 4, 5, 3
    man = ComplexFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)

    base = man.rand()
    # Random coefficients.  The draw order is kept fixed because every
    # call advances the shared random state.
    v_lin_U, v_off_U = crandn(m*p, m*p), crandn(m*p)
    v_lin_V, v_off_V = crandn(n*p, n*p), crandn(n*p)
    v_lin_P = crandn(p*p, p*p)
    v_off_P = hsym(crandn(p, p))

    xi_off = man._rand_ambient()
    xi_lin_U = crandn(m*p, m*p)
    xi_lin_V = crandn(n*p, n*p)
    xi_lin_P = crandn(p*p, p*p)

    def projected_affine(W, lin_U, off_U, lin_V, off_V, lin_P, off_P):
        """proj(W, .) of an ambient vector affine in (W - base)."""
        sym_P = hsym((lin_P @ (W.P-base.P).reshape(-1)).reshape(p, p))
        amb = fr_ambient(
            (lin_U @ (W.U-base.U).reshape(-1) + off_U).reshape(m, p),
            (lin_V @ (W.V-base.V).reshape(-1) + off_V).reshape(n, p),
            sym_P + off_P)
        return man.proj(W, amb)

    def v_func(W):
        return projected_affine(W, v_lin_U, v_off_U, v_lin_V, v_off_V,
                                v_lin_P, v_off_P)

    def xi_func(W):
        return projected_affine(W, xi_lin_U, xi_off.tU.reshape(-1),
                                xi_lin_V, xi_off.tV.reshape(-1),
                                xi_lin_P, xi_off.tP)

    pt = fr_point(base.U, base.V, base.P)
    xi = man.proj(pt, xi_off)
    nabla_xi_v = calc_covar_numeric(man, pt, xi, v_func)[0]

    vv = v_func(pt)
    nabla_v_xi = calc_covar_numeric(man, pt, vv, xi_func)[0]

    torsion = nabla_xi_v - nabla_v_xi
    print(torsion.tU, torsion.tV, torsion.tP)

    # finite-difference Lie bracket
    dlt = 1e-7

    def shifted(direction):
        return fr_point(pt.U+dlt*direction.tU,
                        pt.V+dlt*direction.tV,
                        pt.P+dlt*direction.tP)

    v_along_xi = v_func(shifted(xi))
    xi_along_v = xi_func(shifted(vv))
    d_xi_v = (v_along_xi - vv).scalar_mul(1/dlt)
    d_v_xi = (xi_along_v - xi).scalar_mul(1/dlt)
    bracket = man.proj(pt, d_xi_v-d_v_xi)
    print(check_zero(man._vec(torsion) - man._vec(bracket)))
def test_all_projections():
    """Run numerical consistency checks on ComplexFixedRank operators.

    After the shared ``test_inner`` / ``test_J`` / ``test_projection``
    helpers, the script prints:

    * a forward-difference check of ``D_proj``;
    * agreement of ``proj(g_inv(.))`` with ``proj_g_inv``;
    * forward-difference checks of ``D_g``;
    * the cross-term identity
      ``<D_g(xi, eta1), eta2> == <contract_D_g(eta1, eta2), xi>``;
    * symmetry of ``proj_g_inv(christoffel_form)`` in its two vector
      arguments, and
      ``<Gamma(eta1, eta2), xi> == 0.5*(<D_g(eta1) eta2, xi> +
      <D_g(eta2) eta1, xi> - <D_g(xi) eta1, eta2>)``.
    """
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 1)[0] * .02
    m = 4
    n = 5
    d = 3
    man = ComplexFixedRank(m, n, d, alpha=alpha, beta=beta, gamma=gamma)
    X = man.rand()

    test_inner(man, X)
    # check Jst: vectorize the operator J then compare Jst with jmat.T.conj()
    test_J(man, X)
    test_projection(man, X)

    dlt = 1e-7

    def step(Y, xi):
        """Point Y moved by dlt*xi componentwise."""
        return fr_point(Y.U + dlt*xi.tU, Y.V + dlt*xi.tV, Y.P + dlt*xi.tP)

    # derivative of the projection: forward difference vs D_proj
    e = man._rand_ambient()
    X1 = man.rand()
    xi = man.randvec(X1)
    X2 = step(X1, xi)
    d1 = (man.proj(X2, e) - man.proj(X1, e)).scalar_mul(1/dlt)
    d2 = man.D_proj(X1, xi, e)
    print(check_zero(man._vec(d1-d2)))

    # proj(g_inv(.)) must coincide with proj_g_inv(.)
    for _ in range(20):
        Uran = man._rand_ambient()
        Upr = man.proj(X, man.g_inv(X, Uran))
        Upr2 = man.proj_g_inv(X, Uran)
        print(check_zero(man._vec(Upr)-man._vec(Upr2)))

    # derivative of the metric: forward difference vs D_g
    for _ in range(10):
        X1 = man.rand()
        xi = man.randvec(X1)
        omg1 = man._rand_ambient()
        omg2 = man._rand_ambient()
        X2 = step(X1, xi)
        p1 = man.inner(X1, omg1, omg2)
        p2 = man.inner(X2, omg1, omg2)
        der1 = (p2-p1)/dlt
        der2 = man.base_inner_ambient(
            man.D_g(X1, xi, omg2), omg1)
        print(der1-der2)

    # cross term for the Christoffel form
    for _ in range(10):
        X1 = man.rand()
        xi = man.randvec(X1)
        eta1 = man.randvec(X1)
        eta2 = man.randvec(X1)
        dr1 = man.D_g(X1, xi, eta1)
        x12 = man.contract_D_g(X1, eta1, eta2)

        p1 = man.base_inner_ambient(dr1, eta2)
        p2 = man.base_inner_ambient(x12, xi)
        print(p1, p2, p1-p2)

    # Christoffel form: symmetric on vector fields, and the
    # Christoffel relation in the base metric
    for _ in range(10):
        S1 = man.rand()
        xi = man.randvec(S1)
        eta1 = man.randvec(S1)
        eta2 = man.randvec(S1)
        p1 = man.proj_g_inv(S1, man.christoffel_form(S1, xi, eta1))
        p2 = man.proj_g_inv(S1, man.christoffel_form(S1, eta1, xi))
        print(check_zero(man._vec(p1)-man._vec(p2)))
        v1 = man.base_inner_ambient(
            man.christoffel_form(S1, eta1, eta2), xi)
        v2 = man.base_inner_ambient(man.D_g(S1, eta1, eta2), xi)
        v3 = man.base_inner_ambient(man.D_g(S1, eta2, eta1), xi)
        v4 = man.base_inner_ambient(man.D_g(S1, xi, eta1), eta2)
        print(v1, 0.5*(v2+v3-v4), v1-0.5*(v2+v3-v4))