def solveNTgN(X, Bo, Co, Do):
    """Solve the blockwise N^T g N system for coordinates (Bo, Co, Do).

    This is the counterpart of ``NTgN_opt``; ``testNTgN`` inside ``test_J``
    feeds this function's output of ``NTgN_opt`` back in and compares it
    with the original (B, C, D).  Bo and Co are divided by al[0] and
    gm[0].  Do is split into its sym/asym parts: the antisymmetric part is
    rescaled, and the symmetric part goes through ``extended_lyapunov``.

    NOTE(review): this top-level copy duplicates the closure in ``test_J``
    and reads ``al``, ``gm`` and ``bt`` as free names; at module scope they
    must exist as globals — TODO confirm or remove this copy.
    """
    sym_part = sym(Do)
    skew_part = asym(Do)
    skew_scale = 1 / (al[1] * gm[1] * (al[1] + gm[1]))
    skew_solved = skew_scale * skew_part
    sym_solved = extended_lyapunov(1 / bt, 1 / (al[1] + gm[1]), X.P,
                                   X.P @ sym_part @ X.P)
    b_out = 1 / al[0] * Bo
    c_out = 1 / gm[0] * Co
    return b_out, c_out, sym_solved + skew_solved
 def xi_func(S):
     """Return the projected affine test field xi evaluated at S.

     Each component is affine in the offset of S from the base point S0;
     the P component is symmetrized, and the result goes through man.proj.

     NOTE(review): stray copy of the closure in test_chris_vectorfields; it
     relies on cc_xi, aa_xiU, aa_xiV, inct_xi, S0, m, n, p and man being in
     scope — TODO confirm or remove.
     """
     off_u = (S.U - S0.U).reshape(-1)
     off_v = (S.V - S0.V).reshape(-1)
     off_p = (S.P - S0.P).reshape(-1)
     p_part = sym((cc_xi @ off_p).reshape(p, p))
     u_part = (aa_xiU @ off_u + inct_xi.tU.reshape(-1)).reshape(m, p)
     v_part = (aa_xiV @ off_v + inct_xi.tV.reshape(-1)).reshape(n, p)
     ambient = fr_ambient(u_part, v_part, p_part + inct_xi.tP)
     return man.proj(S, ambient)
    def v_func_flat(S):
        """Map a manifold point S to an affine ambient vector (unprojected).

        Each component is affine in the raw coordinates of S, and the P
        component is symmetrized before ``dd`` is added.  No projection is
        applied, so the result is in general not horizontal.

        NOTE(review): stray copy of the closure in test_christ_flat; it relies
        on aaU, bbU, aaV, bbV, cc, dd, m, n, p being in scope — TODO confirm.
        """
        u_comp = (aaU @ S.U.reshape(-1) + bbU).reshape(m, p)
        v_comp = (aaV @ S.V.reshape(-1) + bbV).reshape(n, p)
        p_comp = sym((cc @ S.P.reshape(-1)).reshape(p, p)) + dd
        return fr_ambient(u_comp, v_comp, p_comp)
    def NTgN_opt(X, B, C, D):
        """Apply the N^T g N operator blockwise to coordinates (B, C, D).

        B and C are scaled by al[0] and gm[0].  D is split into symmetric
        and antisymmetric parts, and each part is transformed separately.

        NOTE(review): stray copy of the closure in test_J; it reads al, gm, bt
        from an enclosing scope — TODO confirm they are bound here.
        """
        Piv = X.Pinv
        sym_D = sym(D)
        skew_D = asym(D)
        lead = 1 / bt - 2 / (al[1] + gm[1])
        inv_sum = 1 / (al[1] + gm[1])
        # `*` and `@` share precedence (left-assoc), matching the original
        # evaluation order exactly.
        sym_out = (lead * Piv @ sym_D @ Piv
                   + inv_sum * (Piv @ Piv @ sym_D + sym_D @ Piv @ Piv))
        skew_out = al[1] * gm[1] * (al[1] + gm[1]) * skew_D
        return al[0] * B, gm[0] * C, sym_out + skew_out
 def get_D2(omg):
     """Return the symmetrized D-coordinate computed from ambient vector omg.

     NOTE(review): stray copy of the closure in test_J; it relies on X, al
     and gm being in scope — TODO confirm or remove.
     """
     Piv = X.Pinv
     ut = X.U.T @ omg.tU
     vt = X.V.T @ omg.tV
     bracket = (al[1] * (Piv @ ut - ut @ Piv)
                + gm[1] * (Piv @ vt - vt @ Piv))
     return sym(Piv @ omg.tP @ Piv + 1 / (al[1] + gm[1]) * bracket)
    def v_func(S):
        """Return an affine ambient field at S, passed through man.proj.

        The field is affine in the offset of S from the base point S0, with
        a symmetrized P component.

        NOTE(review): stray copy of the closure in test_chris_vectorfields; it
        relies on S0, aaU, aaV, cc, intcU, intcV, p_intc, m, n, p and man —
        TODO confirm they are in scope.
        """
        off_u = (S.U - S0.U).reshape(-1)
        off_v = (S.V - S0.V).reshape(-1)
        off_p = (S.P - S0.P).reshape(-1)
        ambient = fr_ambient(
            (aaU @ off_u + intcU).reshape(m, p),
            (aaV @ off_v + intcV).reshape(n, p),
            sym((cc @ off_p).reshape(p, p)) + p_intc)
        return man.proj(S, ambient)
def test_christ_flat():
    """Numerically check that the Christoffel form is compatible with the metric.

    Uses the flat (unprojected) affine vector field
        v(W) = aa W + b
    and compares both sides of
        d_xi <v M v> = 2 <v M nabla_xi v>
    by printing the finite-difference left side (d1) and the right side
    (2 * d2).  Results are printed, not asserted.
    """
    # Random metric weights: alpha/gamma are length-2 arrays with entries in
    # {0.1, ..., 0.9}; beta is a scalar from the same set.
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m = 4
    n = 5
    p = 3
    man = RealFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)

    S = man.rand()

    xi = man.randvec(S)
    # Coefficients of the affine field v (flattened linear maps + offsets).
    aaU = randn(m * p, m * p)
    bbU = randn(m * p)

    aaV = randn(n * p, n * p)
    bbV = randn(n * p)

    cc = randn(p * p, p * p)
    dd = sym(randn(p, p))

    def v_func_flat(S):
        # a function from the manifold
        # to ambient (affine in S's coordinates, no projection applied)
        csp = sym((cc @ S.P.reshape(-1)).reshape(p, p))

        return fr_ambient((aaU @ S.U.reshape(-1) + bbU).reshape(m, p),
                          (aaV @ S.V.reshape(-1) + bbV).reshape(n,
                                                                p), csp + dd)

    vv = v_func_flat(S)  # vv is not horizontal
    # Forward-difference step; Snew moves S along xi in raw coordinates.
    dlt = 1e-7
    Snew = fr_point(S.U + dlt * xi.tU, S.V + dlt * xi.tV, S.P + dlt * xi.tP)
    vnew = v_func_flat(Snew)

    # Left side: finite-difference derivative of <v, v> along xi.
    val = man.inner(S, vv)
    valnew = man.inner(Snew, vnew)
    d1 = (valnew - val) / dlt
    # Covariant derivative = directional derivative + g^{-1} Christoffel term.
    dv = (vnew - vv).scalar_mul(1 / dlt)
    nabla_xi_v = dv + man.g_inv(S, man.christoffel_form(S, xi, vv))
    # Same construction with the base-class christoffel_form; the two are
    # not expected to agree because vv is not horizontal:
    nabla_xi_va = dv + man.g_inv(
        S,
        super(RealFixedRank, man).christoffel_form(S, xi, vv))
    print(check_zero(man._vec(nabla_xi_v) - man._vec(nabla_xi_va)))
    d2 = man.inner(S, vv, nabla_xi_v)

    # d1 and 2 * d2 should match if the identity holds.
    print(d1)
    print(2 * d2)
def N(man, X, B, C, D):
    """Build the ambient vector parametrized by coordinates (B, C, D) at X.

    D is split into its symmetric part Dp and antisymmetric part Dm, and
    ``bkPivDp = Piv @ Dp - Dp @ Piv``.  With ``s = 1 / (alpha[1] + gamma[1])``:

    * U component: ``U @ (-gamma[1] * Dm + s * bkPivDp) + U0 @ B``
    * V component: ``V @ (alpha[1] * Dm + s * bkPivDp) + V0 @ C``
    * P component: ``Dp / beta``

    ``U0`` / ``V0`` are orthonormal bases of the null spaces of
    ``X._U.T`` / ``X._V.T``, as returned by ``scipy.linalg.null_space``.

    Parameters
    ----------
    man : manifold object exposing ``alpha``, ``beta``, ``gamma``.
    X : point exposing ``U``, ``V``, ``Pinv``, ``_U``, ``_V``.
    B : array, shape (m - p, p) as sliced in ``test_J``.
    C : array, shape (n - p, p) as sliced in ``test_J``.
    D : array, shape (p, p).

    Returns
    -------
    The ambient vector produced by ``fr_ambient``.
    """
    # Imported locally so N does not depend on a module-level import;
    # test_J also imports null_space inside its own body.
    from scipy.linalg import null_space

    al, bt, gm = (man.alpha, man.beta, man.gamma)
    U, V, Piv = (X.U, X.V, X.Pinv)
    Dm = asym(D)
    Dp = sym(D)
    inv_sum = 1 / (al[1] + gm[1])
    bkPivDp = Piv @ Dp - Dp @ Piv
    # Bases for the complements: B and C live in these coordinates.
    U0 = null_space(X._U.T)
    V0 = null_space(X._V.T)
    U0B = U0 @ B
    V0C = V0 @ C
    return fr_ambient(U @ (-gm[1] * Dm + inv_sum * bkPivDp) + U0B,
                      V @ (al[1] * Dm + inv_sum * bkPivDp) + V0C,
                      1 / bt * Dp)
 def eta_field(Sin):
     """Return ``eeta`` plus a linear perturbation projected at base point S.

     The perturbation is linear in the offset of Sin from S, with a
     symmetrized P component.

     NOTE(review): the projection uses the fixed point S, not the argument
     Sin; confirm that is intended.  Free names (man, S, mU1, mV1, m2, m_p,
     p, eeta) come from an enclosing scope not visible here.
     """
     du = mU1 @ (Sin.U - S.U) @ m2
     dv = mV1 @ (Sin.V - S.V) @ m2
     dp = sym((m_p @ (Sin.P - S.P).reshape(-1)).reshape(p, p))
     projected = man.proj(S, fr_ambient(du, dv, dp))
     return projected + eeta
 def omg_func(S):
     """Return an affine ambient vector field evaluated at S (unprojected).

     NOTE(review): stray copy of the closure in test_covariance_deriv; it
     relies on aaU, aaV, cc, icpt, m, n, p being in scope — TODO confirm.
     """
     u_lin = aaU @ S.U.reshape(-1) + icpt.tU.reshape(-1)
     v_lin = aaV @ S.V.reshape(-1) + icpt.tV.reshape(-1)
     p_sym = sym((cc @ S.P.reshape(-1)).reshape(p, p))
     return fr_ambient(u_lin.reshape(m, p), v_lin.reshape(n, p),
                       p_sym + icpt.tP)
def test_covariance_deriv():
    """Cross-check man.ehess2rhess against numerical covariant derivatives.

    An affine ambient field omg_func plays the Euclidean gradient; ehess is
    its exact directional derivative along xi (omg_func is affine in the
    coordinates).  The Riemannian Hessian-vector product from ehess2rhess
    is compared, via printed check_zero values, with:
    * a finite-difference derivative of the Riemannian gradient plus a
      Christoffel correction;
    * calc_covar_numeric applied to the Riemannian gradient field;
    * an explicit formula built from D_proj, D_g and christoffel_form.
    """
    # now test full:
    # do covariant derivatives
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m = 4
    n = 5
    p = 3
    man = RealFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)

    S = man.rand()

    # Coefficients of the affine "Euclidean gradient" field.
    aaU = randn(m * p, m * p)
    aaV = randn(n * p, n * p)
    cc = randn(p * p, p * p)
    icpt = man._rand_ambient()

    def omg_func(S):
        csp = sym((cc @ S.P.reshape(-1)).reshape(p, p))
        return fr_ambient(
            (aaU @ S.U.reshape(-1) + icpt.tU.reshape(-1)).reshape(m, p),
            (aaV @ S.V.reshape(-1) + icpt.tV.reshape(-1)).reshape(n, p),
            csp + icpt.tP)

    xi = man.randvec(S)
    egrad = omg_func(S)
    # Directional derivative of omg_func along xi (intercepts drop out).
    ecsp = sym((cc @ xi.tP.reshape(-1)).reshape(p, p))
    ehess = fr_ambient((aaU @ xi.tU.reshape(-1)).reshape(m, p),
                       (aaV @ xi.tV.reshape(-1)).reshape(n, p), ecsp)

    val1 = man.ehess2rhess(S, egrad, ehess, xi)

    # Riemannian gradient field W -> proj_g_inv(W, egrad(W)).
    def rgrad_func(W):
        return man.proj_g_inv(W, omg_func(W))

    # NOTE(review): the `if False:` branch below is dead code (an alternative
    # analytic derivation); only the `elif True:` branch ever runs.
    if False:
        first = ehess
        a = man.J_g_inv(S, egrad)
        rgrad = man.proj_g_inv(S, egrad)
        second = man.D_g(S, xi, man.g_inv(S, egrad)).scalar_mul(-1)
        aout = man.solve_J_g_inv_Jst(S, a)
        third = man.proj(S, man.D_g_inv_Jst(S, xi, aout)).scalar_mul(-1)
        fourth = man.christoffel_form(S, xi, rgrad)
        val1a1 = man.proj_g_inv(S, first + second + fourth) + third
        print(check_zero(man._vec(val1 - val1a1)))
    elif True:
        # Numerical derivative of rgrad along xi plus the Christoffel term.
        d_xi_rgrad = num_deriv_amb(man, S, xi, rgrad_func)
        rgrad = man.proj_g_inv(S, egrad)
        fourth = man.christoffel_form(S, xi, rgrad)
        val1a = man.proj(S, d_xi_rgrad) + man.proj_g_inv(S, fourth)
        print(check_zero(man._vec(val1 - val1a)))

    # nabla_v_xi, dxi, cxxi
    # NOTE(review): val2a is computed but never used below.
    val2a, _, _ = calc_covar_numeric(man, S, xi, omg_func)
    val2, _, _ = calc_covar_numeric(man, S, xi, rgrad_func)
    # val2_p = project(prj, val2)
    val2_p = man.proj(S, val2)
    # print(val1)
    # print(val2_p)
    print(check_zero(man._vec(val1) - man._vec(val2_p)))
    # Explicit formula for the Hessian-vector product; valrangeB is only
    # defined inside this always-true block.
    if True:
        H = xi
        valrangeA_ = ehess + man.g(S, man.D_proj(
            S, H, man.g_inv(S, egrad))) - man.D_g(
                S, H, man.g_inv(S, egrad)) +\
            man.christoffel_form(S, H, man.proj_g_inv(S, egrad))
        valrangeB = man.proj_g_inv(S, valrangeA_)
    # NOTE(review): valrange repeats the val1 call with identical arguments,
    # so the second check below compares ehess2rhess with itself.
    valrange = man.ehess2rhess(S, egrad, ehess, xi)
    print(check_zero(man._vec(valrange) - man._vec(val2_p)))
    print(check_zero(man._vec(valrange) - man._vec(val1)))
    print(check_zero(man._vec(valrange) - man._vec(valrangeB)))
def test_chris_vectorfields():
    """Check the symmetry of the connection on two projected vector fields.

    Builds two affine fields, v (v_func) and xi (xi_func), projected with
    man.proj.  It compares nabla_xi v - nabla_v xi (from calc_covar_numeric)
    with the projected finite-difference Lie bracket D_xi v - D_v xi.
    Results are printed, not asserted.
    """
    # Checks nabla_xi v - nabla_v xi == proj(D_xi v - D_v xi), i.e. that the
    # connection is symmetric (torsion-free) on these fields.
    alpha = randint(1, 10, 2) * .1
    gamma = randint(1, 10, 2) * .1
    beta = randint(1, 10, 2)[0] * .1

    m = 4
    n = 5
    p = 3
    man = RealFixedRank(m, n, p, alpha=alpha, beta=beta, gamma=gamma)

    # Base point and coefficients of the affine field v.
    S0 = man.rand()
    aaU = randn(m * p, m * p)
    intcU = randn(m * p)
    aaV = randn(n * p, n * p)
    intcV = randn(n * p)

    cc = randn(p * p, p * p)
    p_intc = sym(randn(p, p))

    # Coefficients of the affine field xi.
    inct_xi = man._rand_ambient()
    aa_xiU = randn(m * p, m * p)
    aa_xiV = randn(n * p, n * p)
    cc_xi = randn(p * p, p * p)

    def v_func(S):
        # a function from the manifold
        # to ambient, affine in the offset from S0, then projected at S
        csp = sym((cc @ (S.P - S0.P).reshape(-1)).reshape(p, p))

        return man.proj(
            S,
            fr_ambient((aaU @ (S.U - S0.U).reshape(-1) + intcU).reshape(m, p),
                       (aaV @ (S.V - S0.V).reshape(-1) + intcV).reshape(n, p),
                       csp + p_intc))

    # Evaluate at a copy of the base point; at S0 the xi field reduces to
    # proj(SS, inct_xi), which is what xi is set to here.
    SS = fr_point(S0.U, S0.V, S0.P)
    xi = man.proj(SS, inct_xi)

    nabla_xi_v, dv, cxv = calc_covar_numeric(man, SS, xi, v_func)

    def xi_func(S):
        csp_xi = sym((cc_xi @ (S.P - S0.P).reshape(-1)).reshape(p, p))
        xi_amb = fr_ambient((aa_xiU @ (S.U - S0.U).reshape(-1) +
                             inct_xi.tU.reshape(-1)).reshape(m, p),
                            (aa_xiV @ (S.V - S0.V).reshape(-1) +
                             inct_xi.tV.reshape(-1)).reshape(n, p),
                            csp_xi + inct_xi.tP)
        return man.proj(S, xi_amb)

    vv = v_func(SS)

    nabla_v_xi, dxi, cxxi = calc_covar_numeric(man, SS, vv, xi_func)
    diff = nabla_xi_v - nabla_v_xi
    print(diff.tU, diff.tV, diff.tP)
    # now do Lie bracket via forward differences in raw coordinates:
    dlt = 1e-7
    SnewXi = fr_point(SS.U + dlt * xi.tU, SS.V + dlt * xi.tV,
                      SS.P + dlt * xi.tP)
    Snewvv = fr_point(SS.U + dlt * vv.tU, SS.V + dlt * vv.tV,
                      SS.P + dlt * vv.tP)
    vnewxi = v_func(SnewXi)
    xnewv = xi_func(Snewvv)
    dxiv = (vnewxi - vv).scalar_mul(1 / dlt)
    dvxi = (xnewv - xi).scalar_mul(1 / dlt)
    diff2 = man.proj(SS, dxiv - dvxi)
    print(check_zero(man._vec(diff) - man._vec(diff2)))
def test_J(man, X):
    """Diagnostic checks of the (B, C, D) parametrization at point X.

    Compares the dense matrices from make_N_mat / make_g_mat with closed-form
    expressions for the coordinates of N^T g omg.  It also checks the
    blockwise helpers NTgN_opt / solveNTgN and reconstructs proj(X, omg)
    through N.  Differences are printed (many via check_zero), not asserted.

    Parameters
    ----------
    man : manifold object (alpha, beta, gamma, m, n, p, tdim, ...).
    X : point on the manifold (U, V, P, Pinv, _U, _V, evl, evec, ...).
    """
    from scipy.linalg import null_space
    al = man.alpha
    bt = man.beta
    gm = man.gamma
    p = man.p
    # U, V, P, Piv = (X.U, X.V, X.P, X.Pinv)

    # Column i of jjmat: stacked stU/stV/symP/Hz outputs for the i-th unit
    # ambient vector.  NOTE(review): jjmat is only used by the commented-out
    # null-space lines below, so this loop currently has no visible effect.
    jjmat = np.zeros((4 * p * p, man.tdim))
    for i in range(man.tdim):
        Ux = zeros(man.tdim)
        Ux[i] = 1
        omg = man._unvec(Ux)
        jjmat[:, i] = np.concatenate([
            stU(X, omg).reshape(-1),
            stV(X, omg).reshape(-1),
            symP(X, omg).reshape(-1),
            Hz(man, X, omg).reshape(-1)
        ])

    # nsp = null_space(jjmat)
    # prj = nsp @ la.solve(nsp.T @ nsp, nsp.T)

    omg = man._rand_ambient()
    eta = man.proj(X, omg)

    def rand_vertical():
        # Vector generated by a random antisymmetric oo:
        # (U oo, V oo, -oo P + P oo).
        oo = randn(man.p, man.p)
        oo = oo - oo.T
        return fr_ambient(X.U @ oo, X.V @ oo, -oo @ X.P + X.P @ oo)

    vv = rand_vertical()
    # Presumably ~0 if proj maps onto the space g-orthogonal to vertical
    # vectors — TODO confirm.
    print(man.inner(X, vv, eta))
    nmat = make_N_mat(man, X)
    gmat = make_g_mat(man, X)
    NTgN = nmat.T @ gmat @ nmat
    # bcd = N^T g omg, split into B ((m-p) x p), C ((n-p) x p), D (p x p).
    bcd = nmat.T @ gmat @ man._vec(omg)
    m, n, p = (man.m, man.n, man.p)
    B = bcd[:(m - p) * p].reshape(m - p, p)
    C = bcd[(m - p) * p:(m + n - 2 * p) * p].reshape(n - p, p)
    D = bcd[(m + n - 2 * p) * p:].reshape(p, p)
    Dp = sym(D)
    Dm = asym(D)
    # Closed forms for the antisymmetric part of D and for B and C.
    Dm2 = al[1] * gm[1] * asym(X.V.T @ omg.tV - X.U.T @ omg.tU)
    print(Dm - Dm2)
    U0 = null_space(X._U.T)
    V0 = null_space(X._V.T)
    B2 = al[0] * U0.T @ omg.tU
    print(B - B2)
    C2 = gm[0] * V0.T @ omg.tV
    print(C - C2)
    Dmsolve = Dm2 / (al[1] + gm[1]) / al[1] / gm[1]
    print(Dmsolve - (1 / (al[1] + gm[1])) *
          (asym(X.V.T @ omg.tV - X.U.T @ omg.tU)))
    # Solve for the symmetric part, then re-apply the forward map; Drecv
    # should reproduce Dp.
    Dpsolve = extended_lyapunov(1 / bt, 1 / (al[1] + gm[1]), X.P,
                                X.P @ Dp @ X.P)
    Drecv = (1 / bt - 2 / (al[1] + gm[1])) * X.Pinv @ Dpsolve @ X.Pinv + 1 / (
        al[1] + gm[1]) * (X.Pinv @ X.Pinv @ Dpsolve +
                          Dpsolve @ X.Pinv @ X.Pinv)
    print(Drecv - Dp)

    def get_D2(omg):
        # Closed form for the symmetric part of D computed directly from omg.
        UTomg = X.U.T @ omg.tU
        VTomg = X.V.T @ omg.tV
        Piv = X.Pinv
        D2 = sym(X.Pinv @ omg.tP @ X.Pinv + 1 / (al[1] + gm[1]) *
                 (al[1] * (Piv @ UTomg - UTomg @ Piv) + gm[1] *
                  (Piv @ VTomg - VTomg @ Piv)))
        return D2

    Dp2 = get_D2(omg)
    print(check_zero(Dp - Dp2))

    def NTgN_opt(X, B, C, D):
        # Blockwise (matrix-free) application of NTgN to (B, C, D).
        Piv = X.Pinv
        Dp, Dm = sym(D), asym(D)
        Dp_ = (1 / bt - 2 / (al[1] + gm[1])) * Piv @ Dp @ Piv + 1 / (
            al[1] + gm[1]) * (X.Pinv @ X.Pinv @ Dp + Dp @ X.Pinv @ X.Pinv)
        Dm_ = al[1] * gm[1] * (al[1] + gm[1]) * Dm

        return al[0] * B, gm[0] * C, Dp_ + Dm_

    def solveNTgN(X, Bo, Co, Do):
        # Blockwise solve of NTgN; testNTgN checks it undoes NTgN_opt.
        Dp, Dm = sym(Do), asym(Do)
        Dm_ = 1 / (al[1] * gm[1] * (al[1] + gm[1])) * Dm
        Dp_ = extended_lyapunov(1 / bt, 1 / (al[1] + gm[1]), X.P,
                                X.P @ Dp @ X.P)

        return 1 / al[0] * Bo, 1 / gm[0] * Co, Dp_ + Dm_

    # NOTE(review): testNTgN is defined but never called in this function.
    def testNTgN(man, X):
        # Compare dense NTgN with NTgN_opt, then check solveNTgN round-trips.
        m, n, p = (man.m, man.n, man.p)
        B0 = randn(m - p, p)
        C0 = randn(n - p, p)
        D0 = randn(p, p)
        out1 = NTgN @ np.concatenate(
            [B0.reshape(-1), C0.reshape(-1),
             D0.reshape(-1)])
        out2a = NTgN_opt(X, B0, C0, D0)
        out2 = np.concatenate(
            [out2a[0].reshape(-1), out2a[1].reshape(-1), out2a[2].reshape(-1)])
        print(check_zero(out1 - out2))
        out2b = solveNTgN(X, *out2a)
        print(check_zero(out2b[2] - D0))
        print(check_zero(out2b[1] - C0))
        print(check_zero(out2b[0] - B0))

    # Solve for the coordinates of omg and compare with closed forms.
    Bs, Cs, Ds = solveNTgN(X, B, C, D)
    # Dsm = asym(Ds)
    # Dsp = sym(Ds)
    Bf = U0.T @ omg.tU
    Cf = V0.T @ omg.tV
    print(check_zero(Bf - Bs))
    print(check_zero(Cf - Cs))
    Dfm = 1 / (al[1] + gm[1]) * asym(X.V.T @ omg.tV - X.U.T @ omg.tU)
    UTomg = X.U.T @ omg.tU
    VTomg = X.V.T @ omg.tV
    Dfp = extended_lyapunov(
        1 / bt, 1 / (al[1] + gm[1]), X.P,
        sym(omg.tP + 1 / (al[1] + gm[1]) *
            (al[1] * (UTomg @ X.P - X.P @ UTomg) + gm[1] *
             (VTomg @ X.P - X.P @ VTomg))), X.evl, X.evec)
    # Rebuild the projected vector from the closed-form coordinates and
    # compare with man.proj(X, omg).
    ee1 = N(man, X, Bf, Cf, Dfp + Dfm)
    # print(check_zero(Dsm + Dsp - ee1.tP))
    print(check_zero(Ds - Dfp - Dfm))
    print(check_zero(man._vec(ee1 - eta)))