def test_christ_flat():
    """Check numerically that the Christoffel form is metric compatible.

    For a vector field ``v`` the metric-compatibility identity reads

        d_xi <v, v>_Y = 2 <v, nabla_xi v>_Y,
        nabla_xi v = D_xi v + g_inv(Y, christoffel_form(Y, xi, v)).

    Here ``v`` is the *flat* (unprojected) affine field
    ``v(Y) = unvec(aa @ vec(Y) + bb)`` with random complex ``aa``, ``bb``;
    no projection onto the tangent space is applied.  Both sides are
    estimated with a forward difference of step ``dlt`` and printed, followed
    by their difference, for manual inspection.  Nothing is asserted.
    """
    # Seed so the printed numbers are reproducible (as test_rhess_02 does).
    np.random.seed(0)
    dvec = np.array([10, 3, 2, 3])
    p = dvec.shape[0] - 1
    alpha = randint(1, 10, (p, p + 1)) * .1
    man = ComplexFlag(dvec, alpha=alpha)
    Y = man.rand()
    n = man.n
    d = man.d

    xi = man.randvec(Y)
    aa = crandn(n * d, n * d)
    bb = crandn(n * d)

    def v_func_flat(Y):
        # affine map on the flattened point, reshaped back to n x d
        return (aa @ Y.reshape(-1) + bb).reshape(n, d)

    vv = v_func_flat(Y)
    dlt = 1e-7
    Ynew = Y + dlt * xi
    vnew = v_func_flat(Ynew)

    # left side: finite difference of <v, v>; man.inner takes the base point,
    # so it is evaluated at Y and at Ynew respectively
    val = man.inner(Y, vv, vv)
    valnew = man.inner(Ynew, vnew, vnew)
    d1 = (valnew - val) / dlt
    # right side: 2 <v, nabla_xi v> with nabla_xi v = D_xi v + Gamma(xi, v)
    dv = (vnew - vv) / dlt
    nabla_xi_v = dv + man.g_inv(Y, man.christoffel_form(Y, xi, vv))
    d2 = man.inner(Y, vv, nabla_xi_v)

    print(d1)
    print(2 * d2)
    # should be O(dlt) if the connection is metric compatible
    print(d1 - 2 * d2)
def test_rhess_02():
    """Exploratory numerical checks of the Riemannian Hessian of ComplexFlag.

    The cost is f(Y) = sum_r Re tr(U_r W_r W_r^H), where W_r = Y[:, br:er]
    is the column block given by gidx[r + 1] and each U_r comes from
    make_sym_pos(n).  In order, the test prints comparisons of:

    1. forward differences against the Euclidean gradient ``df`` and the
       Euclidean Hessian ``ehess_form`` / ``ehess_vec``;
    2. the identity ehess(xi, eta) = xi(eta f) - <D_xi eta, df(Y)>, with eta
       extended to the affine vector field ``eta_field``;
    3. ``man.ehess2rhess`` against ``calc_covar_numeric`` applied to the
       Riemannian gradient ``proj_g_inv(Y, df(Y))``;
    4. <ehess2rhess(eta1), xi1> against ``man.rhess02(Y, xi1, eta1, ...)``;
    5. the Hessian value rebuilt as D_xi <rgrad, eta> - <rgrad, nabla_xi eta>.

    Results are only printed; nothing is asserted.  The printed values depend
    on the exact order of the random draws below.
    """
    np.random.seed(0)
    dvec = np.array([10, 3, 2, 3])
    p = dvec.shape[0] - 1
    alpha = randint(1, 10, (p, p + 1)) * .1
    man = ComplexFlag(dvec, alpha=alpha)
    n = man.n
    d = man.d

    Y = man.rand()
    UU = {}
    # alpha has p rows, so this reassigns p to the value it already has
    p = alpha.shape[0]
    VV = {}
    # gidx[rr + 1] is unpacked below as a (begin, end) column range of Y
    gidx = man._g_idx

    for rr in range(p):
        UU[rr] = make_sym_pos(n)
        # VV is not used later, but these draws advance the RNG stream and
        # therefore affect every random value drawn afterwards
        VV[rr] = crandn(n, dvec[rr + 1])

    def f(Y):
        # cost: sum over blocks of Re tr(U_r W_r W_r^H)
        ss = 0
        for rr in range(p):
            br, er = gidx[rr + 1]
            wr = Y[:, br:er]
            ss += trace(UU[rr] @ wr @ wr.T.conjugate()).real
        return ss

    def df(W):
        # Euclidean gradient w.r.t. the real inner product Re tr(A B^H)
        ret = np.zeros_like(W)
        for rr in range(p):
            br, er = gidx[rr + 1]
            wr = W[:, br:er]
            ret[:, br:er] += 2 * UU[rr] @ wr
        return ret

    def ehess_form(W, xi, eta):
        # Euclidean Hessian as a bilinear form (W unused: f is quadratic)
        ss = 0
        for rr in range(p):
            br, er = gidx[rr + 1]
            ss += 2 * trace(
                UU[rr] @ xi[:, br:er] @ eta[:, br:er].T.conjugate()).real
        return ss

    def ehess_vec(W, xi):
        # Euclidean Hessian applied to xi (W unused: f is quadratic)
        ret = np.zeros_like(W)
        for rr in range(p):
            br, er = gidx[rr + 1]
            ret[:, br:er] += 2 * UU[rr] @ xi[:, br:er]
        return ret

    # 1a. gradient check: finite difference minus <df(Y), xxi>, expect ~0
    xxi = crandn(n, d)
    dlt = 1e-8
    Ynew = Y + dlt * xxi
    d1 = (f(Ynew) - f(Y)) / dlt
    d2 = df(Y)
    print(d1 - trace(d2 @ xxi.T.conjugate()).real)

    eeta = crandn(n, d)

    # 1b. Hessian check: finite difference of <df, eeta> along xxi
    d1 = trace((df(Ynew) - df(Y)) @ eeta.T.conjugate()).real / dlt
    ehess_val = ehess_form(Y, xxi, eeta)
    # ehess_val2 = ehess_form(Y, eeta, xxi)
    dv2 = ehess_vec(Y, xxi)
    print(trace(dv2 @ eeta.T.conjugate()).real)
    print(d1, ehess_val, d1 - ehess_val)

    # now check the formula: ehess = xi (eta_func(f)) - <D_xi eta, df(Y)>
    # promote eta to a vector field.

    m1 = crandn(n, n)
    m2 = crandn(d, d)

    def eta_field(Yin):
        # affine field through eeta: eta_field(Y) == eeta at the base point
        return m1 @ (Yin - Y) @ m2 + eeta

    # xietaf: should go to ehess(xi, eta) + df(Y) @ etafield)
    xietaf = trace(
        df(Ynew) @ eta_field(Ynew).T.conjugate() -
        df(Y) @ eta_field(Y).T.conjugate()).real / dlt
    # apply eta_func to f: should go to tr(m1 @ xxi @ m2 @ df(Y).T.conjugate())
    Dxietaf = trace(
        (eta_field(Ynew) - eta_field(Y)) @ df(Y).T.conjugate()).real / dlt
    # this is ehess. should be same as d1 or ehess_val
    print(xietaf - Dxietaf)
    print(xietaf - Dxietaf - ehess_val)

    # now check: rhess. Need to make sure xi, eta in the tangent space.
    # first compare this with numerical differentiation
    xi1 = man.proj(Y, xxi)
    eta1 = man.proj(Y, eeta)
    egvec = df(Y)
    ehvec = ehess_vec(Y, xi1)
    rhessvec = man.ehess2rhess(Y, egvec, ehvec, xi1)

    # check it numerically:
    def rgrad_func(Y):
        # Riemannian gradient field built from the Euclidean gradient
        return man.proj_g_inv(Y, df(Y))

    val2, _, _ = calc_covar_numeric(man, Y, xi1, rgrad_func)
    val2_p = man.proj(Y, val2)
    # print(rhessvec)
    # print(val2_p)
    print(check_zero(rhessvec - val2_p))
    rhessval = man.inner(Y, rhessvec, eta1)
    print(man.inner(Y, val2, eta1))
    print(rhessval)

    # check symmetric:
    # <rhess[eta1], xi1> here vs. <rhess[xi1], eta1> printed above
    ehvec_e = ehess_vec(Y, eta1)
    ehess_valp = ehess_form(Y, xi1, eta1)

    rhessvec_e = man.ehess2rhess(Y, egvec, ehvec_e, eta1)
    rhessval_e = man.inner(Y, rhessvec_e, xi1)
    rhessval_e1 = man.rhess02(Y, xi1, eta1, egvec, ehess_valp)
    # NOTE(review): the trace() argument on the second line below appears to
    # have been garbled by an e-mail obfuscator ("[email protected]");
    # recover the original expression from version control before
    # re-enabling this check.
    # rhessval_e2 = man.rhess02_alt(Y, xi1, eta1, egvec,
    #                              trace([email protected]()).real)
    # print(rhessval_e, rhessval_e1, rhessval_e2)
    print(rhessval_e, rhessval_e1, rhessval_e - rhessval_e1)

    print('rhessval_e %f ' % rhessval_e)

    # the above computed inner_prod(Nabla_xi Pi * df, eta)
    # in the following check. Extend eta1 to eta_proj
    # (Pi Nabla_hat Pi g_inv df, g eta)
    # = D_xi (Pi g_inv df, g eta) - (Pi g_inv df g Pi Nabla_hat eta)

    def eta_proj(Y):
        # tangent vector field obtained by projecting eta_field
        return man.proj(Y, eta_field(Y))

    # eta_field(Y) == eeta, so eta_proj(Y) should equal eta1
    print(check_zero(eta1 - eta_proj(Y)))

    e1 = man.inner(Y, man.proj_g_inv(Y, df(Y)), eta_proj(Y))
    e1a = trace(df(Y) @ eta_proj(Y).T.conjugate()).real
    print(e1, e1a, e1 - e1a)
    # Ynew is reassigned: it now moves along the tangent xi1, not xxi
    Ynew = Y + xi1 * dlt
    e2 = man.inner(Ynew, man.proj_g_inv(Ynew, df(Ynew)), eta_proj(Ynew))
    e2a = trace(df(Ynew) @ eta_proj(Ynew).T.conjugate()).real
    print(e2, e2a, e2 - e2a)

    # D_xi <rgrad, eta_proj>, computed two ways
    first = (e2 - e1) / dlt
    first1 = trace(
        df(Ynew) @ eta_proj(Ynew).T.conjugate() -
        df(Y) @ eta_proj(Y).T.conjugate()).real / dlt
    print(first - first1)

    # <rgrad, nabla_xi eta_proj>, with and without projecting the derivative
    val3, _, _ = calc_covar_numeric(man, Y, xi1, eta_proj)
    second = man.inner(Y, man.proj_g_inv(Y, df(Y)), man.proj(Y, val3))
    second2 = man.inner(Y, man.proj_g_inv(Y, df(Y)), val3)
    print(second, second2, second - second2)
    print('same as rhess_val %f' % (first - second))
def test_all_projections():
    dvec = np.array([10, 3, 2, 3])
    p = dvec.shape[0] - 1
    alpha = randint(1, 10, (p, p + 1)) * .1
    man = ComplexFlag(dvec, alpha=alpha)
    Y = man.rand()
    U = man._rand_ambient()
    Upr = man.proj(Y, U)

    test_inner(man, Y)
    test_J(man, Y)

    # now check metric, Jst etc
    # check Jst: vectorize the operator J then compare Jst with jmat.T.conjugate()
    jmat = make_j_mat(man, Y)
    test_Jst(man, Y, jmat)
    ginv_mat = make_g_inv_mat(man, Y)
    # test g_inv_Jst
    for ii in range(10):
        a = man._rand_range_J()
        avec = man._vec_range_J(a)
        jtout = man._unvec(ginv_mat @ jmat.T @ avec)

        jtout2 = man.g_inv_Jst(Y, a)
        diff = check_zero(jtout - jtout2)
        print(diff)
    # test projection
    test_projection(man, Y)

    for i in range(20):
        Uran = man._rand_ambient()
        Upr = man.proj(Y, man.g_inv(Y, Uran))
        Upr2 = man.proj_g_inv(Y, Uran)
        print(check_zero(Upr - Upr2))

    for ii in range(10):
        a = man._rand_range_J()
        xi = man._rand_ambient()
        jtout2 = man.Jst(Y, a)
        dlt = 1e-7
        Ynew = Y + dlt * xi
        jtout2a = man.Jst(Ynew, a)
        d1 = (jtout2a - jtout2) / dlt
        d2 = man.D_Jst(Y, xi, a)
        print(check_zero(d2 - d1))

    for ii in range(10):
        Y = man.rand()
        eta = man._rand_ambient()
        xi = man.randvec(Y)
        a1 = man.J(Y, eta)
        dlt = 1e-7
        Ynew = Y + dlt * xi
        a2 = man.J(Ynew, eta)
        d1 = (man._vec_range_J(a2) - man._vec_range_J(a1)) / dlt
        d2 = man._vec_range_J(man.D_J(Y, xi, eta))
        print(check_zero(d2 - d1))

    for ii in range(10):
        a = man._rand_range_J()
        xi = man._rand_ambient()
        jtout2 = man.g_inv_Jst(Y, a)
        dlt = 1e-7
        Ynew = Y + dlt * xi
        jtout2a = man.g_inv_Jst(Ynew, a)
        d1 = (jtout2a - jtout2) / dlt
        d2 = man.D_g_inv_Jst(Y, xi, a)
        print(check_zero(d2 - d1))

    for ii in range(10):
        arand = man._rand_range_J()
        a2 = man.solve_J_g_inv_Jst(Y, arand)
        a1 = man.J(Y, man.g_inv_Jst(Y, a2))
        print(check_zero(man._vec_range_J(a1) - man._vec_range_J(arand)))

    # derives
    for ii in range(10):
        Y1 = man.rand()
        xi = man.randvec(Y1)
        omg1 = man._rand_ambient()
        omg2 = man._rand_ambient()
        dlt = 1e-7
        Y2 = Y1 + dlt * xi
        p1 = man.inner(Y1, omg1, omg2)
        p2 = man.inner(Y2, omg1, omg2)
        der1 = (p2 - p1) / dlt
        der2 = man.base_inner_ambient(man.D_g(Y1, xi, omg2), omg1)
        print(check_zero(der1 - der2))

    # cross term for christofel
    for i in range(10):
        Y1 = man.rand()
        xi = man.randvec(Y1)
        omg1 = man._rand_ambient()
        omg2 = man._rand_ambient()
        dr1 = man.D_g(Y1, xi, omg1)
        x12 = man.contract_D_g(Y1, omg1, omg2)

        p1 = trace(dr1 @ omg2.T.conjugate()).real
        p2 = trace(x12 @ xi.T.conjugate()).real
        print(p1, p2, p1 - p2)

    # now test christofel:
    # two things: symmetric on vector fields
    # and christofel relation
    # in the case metric
    for i in range(10):
        Y1 = man.rand()
        xi = man.randvec(Y1)
        eta1 = man.randvec(Y1)
        eta2 = man.randvec(Y1)
        p1 = man.proj_g_inv(Y1, man.christoffel_form(Y1, xi, eta1))
        p2 = man.proj_g_inv(Y1, man.christoffel_form(Y1, eta1, xi))
        print(check_zero(p1 - p2))
        v1 = man.base_inner_ambient(man.christoffel_form(Y1, eta1, eta2), xi)
        v2 = man.base_inner_ambient(man.D_g(Y1, eta1, eta2), xi)
        v3 = man.base_inner_ambient(man.D_g(Y1, eta2, eta1), xi)
        v4 = man.base_inner_ambient(man.D_g(Y1, xi, eta1), eta2)
        print(v1, 0.5 * (v2 + v3 - v4), v1 - 0.5 * (v2 + v3 - v4))
        """