def test_covariance_deriv():
    """Check ``man.ehess2rhess`` against numerically computed covariant derivatives.

    A random affine field ``omg_func(W) = unvec(aa @ vec(W) + slp)`` is used
    as the Euclidean gradient. Its Euclidean Hessian along ``xi`` is therefore
    ``unvec(aa @ vec(xi))``. The Riemannian Hessian ``val1`` returned by
    ``man.ehess2rhess`` is compared with two references:

    * ``val1b``: the projected numerical directional derivative of the
      Riemannian gradient ``W -> man.proj_g_inv(W, omg_func(W))``, plus the
      ``proj_g_inv`` of the Christoffel form ``man.christoffel_form(Y, xi, rgrad)``;
    * ``val2_p``: the projection of the first value returned by
      ``calc_covar_numeric`` for the same gradient field.

    ``check_zero`` of each difference is printed. Presumably both should be
    close to zero (NOTE(review): confirm the tolerance); nothing is asserted.
    """
    dvec = np.array([10, 3, 2, 3])
    p = dvec.shape[0] - 1
    alpha = randint(1, 10, (p, p + 1)) * .1
    man = ComplexFlag(dvec, alpha=alpha)
    n = man.n
    d = man.d
    Y = man.rand()

    # Random affine field: omg_func(W) = unvec(aa @ vec(W) + slp).
    slp = np.random.randn(n * d)
    aa = np.random.randn(n * d, n * d)

    def omg_func(Y):
        return (aa @ Y.reshape(-1) + slp).reshape(n, d)

    xi = man.randvec(Y)

    # The field is affine, so its directional derivative along xi is aa @ vec(xi).
    egrad = omg_func(Y)
    ehess = (aa @ xi.reshape(-1)).reshape(n, d)

    val1 = man.ehess2rhess(Y, egrad, ehess, xi)

    def rgrad_func(W):
        return man.proj_g_inv(W, omg_func(W))

    # Reference 1: numerical derivative of the Riemannian gradient along xi,
    # projected, plus the metric-raised and projected Christoffel term.
    d_xi_rgrad = num_deriv(man, Y, xi, rgrad_func)
    rgrad = man.proj_g_inv(Y, egrad)
    fourth = man.christoffel_form(Y, xi, rgrad)
    val1b = man.proj(Y, d_xi_rgrad) + man.proj_g_inv(Y, fourth)
    print(check_zero(val1 - val1b))

    # Reference 2: calc_covar_numeric returns three values (noted originally
    # as nabla_v_xi, dxi, cxxi); only the first one is compared, after
    # projecting it onto the tangent space at Y.
    val2, _, _ = calc_covar_numeric(man, Y, xi, rgrad_func)
    val2_p = man.proj(Y, val2)
    print(check_zero(val1 - val2_p))
def test_christ_flat():
    """Check that the Christoffel term is compatible with the metric.

    The check is done in the flat ambient space: the field is NOT projected
    onto the tangent space. For the complex affine field
    ``v(W) = unvec(aa @ vec(W) + bb)`` and a random tangent direction ``xi``,
    the test compares, using forward differences with step ``dlt``::

        d_xi <v, v>_W        (d1, difference of man.inner at Y and Y + dlt*xi)
        2 <v, nabla_xi v>_Y  (2 * d2)

    where ``nabla_xi v = D_xi v + man.g_inv(Y, man.christoffel_form(Y, xi, v))``
    and ``D_xi v`` is also taken by forward difference.

    Prints ``d1``, ``2 * d2`` and their difference. The difference should be
    small, of the order of the finite-difference error (NOTE(review): confirm
    the tolerance); nothing is asserted.
    """
    dvec = np.array([10, 3, 2, 3])
    p = dvec.shape[0] - 1
    alpha = randint(1, 10, (p, p + 1)) * .1
    man = ComplexFlag(dvec, alpha=alpha)
    Y = man.rand()
    n = man.n
    d = man.d

    xi = man.randvec(Y)
    aa = crandn(n * d, n * d)
    bb = crandn(n * d)

    def v_func_flat(Y):
        # Unprojected ("flat") affine field on the ambient space.
        return (aa @ Y.reshape(-1) + bb).reshape(n, d)

    vv = v_func_flat(Y)
    dlt = 1e-7
    # Step off Y along xi in the ambient space (no retraction is needed here).
    Ynew = Y + dlt * xi
    vnew = v_func_flat(Ynew)

    # Left side: directional derivative of the squared metric norm of v.
    val = man.inner(Y, vv, vv)
    valnew = man.inner(Ynew, vnew, vnew)
    d1 = (valnew - val) / dlt

    # Right side: covariant derivative = plain derivative + raised Christoffel term.
    dv = (vnew - vv) / dlt
    nabla_xi_v = dv + man.g_inv(Y, man.christoffel_form(Y, xi, vv))
    d2 = man.inner(Y, vv, nabla_xi_v)

    print(d1, 2 * d2, d1 - 2 * d2)
def test_all_projections():
    dvec = np.array([10, 3, 2, 3])
    p = dvec.shape[0] - 1
    alpha = randint(1, 10, (p, p + 1)) * .1
    man = ComplexFlag(dvec, alpha=alpha)
    Y = man.rand()
    U = man._rand_ambient()
    Upr = man.proj(Y, U)

    test_inner(man, Y)
    test_J(man, Y)

    # now check metric, Jst etc
    # check Jst: vectorize the operator J then compare Jst with jmat.T.conjugate()
    jmat = make_j_mat(man, Y)
    test_Jst(man, Y, jmat)
    ginv_mat = make_g_inv_mat(man, Y)
    # test g_inv_Jst
    for ii in range(10):
        a = man._rand_range_J()
        avec = man._vec_range_J(a)
        jtout = man._unvec(ginv_mat @ jmat.T @ avec)

        jtout2 = man.g_inv_Jst(Y, a)
        diff = check_zero(jtout - jtout2)
        print(diff)
    # test projection
    test_projection(man, Y)

    for i in range(20):
        Uran = man._rand_ambient()
        Upr = man.proj(Y, man.g_inv(Y, Uran))
        Upr2 = man.proj_g_inv(Y, Uran)
        print(check_zero(Upr - Upr2))

    for ii in range(10):
        a = man._rand_range_J()
        xi = man._rand_ambient()
        jtout2 = man.Jst(Y, a)
        dlt = 1e-7
        Ynew = Y + dlt * xi
        jtout2a = man.Jst(Ynew, a)
        d1 = (jtout2a - jtout2) / dlt
        d2 = man.D_Jst(Y, xi, a)
        print(check_zero(d2 - d1))

    for ii in range(10):
        Y = man.rand()
        eta = man._rand_ambient()
        xi = man.randvec(Y)
        a1 = man.J(Y, eta)
        dlt = 1e-7
        Ynew = Y + dlt * xi
        a2 = man.J(Ynew, eta)
        d1 = (man._vec_range_J(a2) - man._vec_range_J(a1)) / dlt
        d2 = man._vec_range_J(man.D_J(Y, xi, eta))
        print(check_zero(d2 - d1))

    for ii in range(10):
        a = man._rand_range_J()
        xi = man._rand_ambient()
        jtout2 = man.g_inv_Jst(Y, a)
        dlt = 1e-7
        Ynew = Y + dlt * xi
        jtout2a = man.g_inv_Jst(Ynew, a)
        d1 = (jtout2a - jtout2) / dlt
        d2 = man.D_g_inv_Jst(Y, xi, a)
        print(check_zero(d2 - d1))

    for ii in range(10):
        arand = man._rand_range_J()
        a2 = man.solve_J_g_inv_Jst(Y, arand)
        a1 = man.J(Y, man.g_inv_Jst(Y, a2))
        print(check_zero(man._vec_range_J(a1) - man._vec_range_J(arand)))

    # derives
    for ii in range(10):
        Y1 = man.rand()
        xi = man.randvec(Y1)
        omg1 = man._rand_ambient()
        omg2 = man._rand_ambient()
        dlt = 1e-7
        Y2 = Y1 + dlt * xi
        p1 = man.inner(Y1, omg1, omg2)
        p2 = man.inner(Y2, omg1, omg2)
        der1 = (p2 - p1) / dlt
        der2 = man.base_inner_ambient(man.D_g(Y1, xi, omg2), omg1)
        print(check_zero(der1 - der2))

    # cross term for christofel
    for i in range(10):
        Y1 = man.rand()
        xi = man.randvec(Y1)
        omg1 = man._rand_ambient()
        omg2 = man._rand_ambient()
        dr1 = man.D_g(Y1, xi, omg1)
        x12 = man.contract_D_g(Y1, omg1, omg2)

        p1 = trace(dr1 @ omg2.T.conjugate()).real
        p2 = trace(x12 @ xi.T.conjugate()).real
        print(p1, p2, p1 - p2)

    # now test christofel:
    # two things: symmetric on vector fields
    # and christofel relation
    # in the case metric
    for i in range(10):
        Y1 = man.rand()
        xi = man.randvec(Y1)
        eta1 = man.randvec(Y1)
        eta2 = man.randvec(Y1)
        p1 = man.proj_g_inv(Y1, man.christoffel_form(Y1, xi, eta1))
        p2 = man.proj_g_inv(Y1, man.christoffel_form(Y1, eta1, xi))
        print(check_zero(p1 - p2))
        v1 = man.base_inner_ambient(man.christoffel_form(Y1, eta1, eta2), xi)
        v2 = man.base_inner_ambient(man.D_g(Y1, eta1, eta2), xi)
        v3 = man.base_inner_ambient(man.D_g(Y1, eta2, eta1), xi)
        v4 = man.base_inner_ambient(man.D_g(Y1, xi, eta1), eta2)
        print(v1, 0.5 * (v2 + v3 - v4), v1 - 0.5 * (v2 + v3 - v4))
        """