def energy_func_para_split_modSO3(f_f1, Coe_x, RX, f1, f2, Basis_1forms, a, b,
                                  c, d, Tstps):
    """Return the discrete energy of a perturbed path ending at a rotated f2.

    f2 is rotated by ``torch_expm(RX)``. A straight-line path with ``Tstps``
    samples is built between ``f_to_df(f1)`` and the one-form of the rotated
    f2. The interior samples are then shifted by the combination of
    ``Basis_1forms`` weighted by ``Coe_x`` and the first sample is replaced by
    ``f_to_df(f_f1)``; the last sample stays on the straight path.

    Args:
        f_f1: surface whose one-form is used as the first sample of the path.
        Coe_x: 2-D coefficients indexed (basis element, interior time step).
        RX: argument of ``torch_expm`` (presumably an so(3) element -- confirm).
        f1, f2: boundary surfaces; f2 is indexed (coordinate, m, n).
        Basis_1forms: 5-D basis indexed (basis element, i, j, m, n).
        a, b, c, d: weights forwarded unchanged to ``riemann_metric``.
        Tstps: number of samples along the path.

    Returns:
        Scalar tensor: the mean over the ``Tstps - 1`` segments of
        ``riemann_metric`` evaluated on the finite-difference velocity.
    """
    # Apply the rotation to the target surface.
    rotation = torch_expm(RX)
    rotated_f2 = torch.einsum("ij,jmn->imn", rotation, f2)

    form_start = f_to_df(f1)
    form_end = f_to_df(rotated_f2)

    n_seg = Tstps - 1
    steps = torch.arange(Tstps, out=torch.FloatTensor())
    # Straight line in one-form space: form_start + (form_end - form_start) * t.
    straight = form_start + torch.einsum(
        "ijmn,t->tijmn", form_end - form_start, steps) / n_seg

    # Perturbed path: new first sample, unchanged last sample, shifted interior.
    path = torch.zeros(straight.size())
    first_form = f_to_df(f_f1)
    path[0] = first_form
    path[-1] = straight[-1]
    path[1:n_seg] = straight[1:n_seg] + torch.einsum(
        "lijmn,lt->tijmn", Basis_1forms, Coe_x)

    # Forward finite differences, scaled by the number of segments.
    velocity = (path[1:] - path[:-1]) * n_seg

    energy = torch.zeros(n_seg)
    for k, (point, vel) in enumerate(zip(path[:-1], velocity)):
        energy[k] = riemann_metric(point, vel, vel, a, b, c, d)
    return energy.sum() / n_seg
def compute_optCoe_Imm(Coe_x, f1, f2, a, b, c, d, Basis_1forms, Tpts, Max_ite):
    """Optimise the basis coefficients of a path between f1 and f2 with BFGS.

    A straight-line path with ``Tpts`` samples is built between
    ``f_to_df(f1)`` and ``f_to_df(f2)``; ``energy_func_Imm`` is then minimised
    over the coefficient matrix with ``scipy.optimize.minimize``.

    Args:
        Coe_x: initial coefficients; flattened before being handed to SciPy.
        f1, f2: boundary surfaces.
        a, b, c, d: weights forwarded unchanged to ``energy_func_Imm``.
        Basis_1forms: basis whose first dimension is the number of elements.
        Tpts: number of samples along the path.
        Max_ite: maximum number of BFGS iterations.

    Returns:
        ``(coe_x_opt, AllEnergy)``: the optimised coefficients as a float
        tensor of shape ``(Num_basis, -1)`` and the list filled by the
        per-iteration callback.

    NOTE(review): the objective's return statement and the callback body were
    truncated in the reviewed source. They are reconstructed here from the
    ``jac=True`` contract of ``optimize.minimize`` and from the fact that
    ``AllEnergy`` is returned -- confirm against the upstream implementation.
    """
    Num_basis = Basis_1forms.size(0)

    df1 = f_to_df(f1)
    df2 = f_to_df(f2)

    Time_points = torch.arange(Tpts, out=torch.FloatTensor())

    # Straight line between the two one-forms: df1 + (df2 - df1) * t.
    curve0 = df1 + torch.einsum("ijmn,t->tijmn",
                                [df2 - df1, Time_points]) / (Tpts - 1)

    def Opts_Imm(Coe_x):
        """Objective for SciPy: return (energy, flattened gradient)."""
        Coe_x = torch.from_numpy(Coe_x).float()
        # Track gradients on the reshaped leaf so .grad is populated.
        Coe_x = Coe_x.view(Num_basis, -1).requires_grad_()
        y = energy_func_Imm(Coe_x, curve0, Basis_1forms, a, b, c, d)
        y.backward()
        # jac=True requires float64 (value, gradient) with a 1-D gradient.
        return (np.double(y.detach().numpy()),
                np.double(Coe_x.grad.numpy().flatten()))

    AllEnergy = []

    def printx(x):
        """Per-iteration callback: record the energy at the current iterate."""
        # NOTE(review): reconstructed -- presumably the original logged the
        # energy here; this costs one extra objective evaluation per iteration.
        AllEnergy.append(Opts_Imm(x)[0])

    res = optimize.minimize(Opts_Imm, Coe_x.flatten(), method='BFGS', jac=True,
                            callback=printx,
                            options={'gtol': 1e-02, 'disp': False,
                                     'maxiter': Max_ite})
    coe_x_opt = torch.from_numpy(res.x).float().view(Num_basis, -1)

    return coe_x_opt, AllEnergy
def energy_reg(X, c_f, idty, Basis_vecFields, a, b, c, d):
    """Return the metric energy between a reparametrised f1 and f2.

    ``idty`` is displaced by the combination of ``Basis_vecFields`` weighted
    by ``X`` and renormalised pointwise to unit length; ``f1`` is composed
    with the resulting map and ``riemann_metric`` is evaluated at the
    one-form of the composed surface on the one-form of ``f2 - f1(gamma)``.

    Args:
        X: 1-D coefficients of the basis vector fields.
        c_f: indexable pair; ``c_f[0]`` is f1 and ``c_f[1]`` is f2.
        idty: 3-D map indexed (coordinate, m, n) that is displaced
            (presumably the identity map of the sphere -- confirm).
        Basis_vecFields: 4-D basis indexed (basis element, coordinate, m, n).
        a, b, c, d: weights forwarded unchanged to ``riemann_metric``.

    Returns:
        Whatever ``riemann_metric`` returns for the arguments above.
    """
    f1, f2 = c_f[0], c_f[1]

    # Displace the base map and project back to unit length pointwise.
    gamma = idty + torch.einsum("i, ijkl->jkl", [X, Basis_vecFields])
    gamma = gamma / torch.einsum("ijk, ijk->jk", [gamma, gamma]).sqrt()

    # Spherical-coordinate representation; the 1e-7 offsets are kept from the
    # original code (presumably to stay away from coordinate singularities).
    gammaSph = torch.stack(
        Cartesian_to_spherical(gamma[0] + 1e-7, gamma[1],
                               (1 - 1e-7) * gamma[2]))
    f1_gamma = compose_gamma(f1, gammaSph)

    df1_gamma = f_to_df(f1_gamma)
    Diff_12 = f_to_df(f2 - f1_gamma)
    return riemann_metric(df1_gamma, Diff_12, Diff_12, a, b, c, d)
def length_func_surf_Imm(curve_f, a, b, c, d):
    """Return the discrete length of a path of surfaces.

    Each sample ``curve_f[i]`` is converted with ``f_to_df`` into a
    ``3 x 2 x m x n`` one-form, so the internal path has size
    ``T x 3 x 2 x m x n``. Velocities are forward finite differences scaled
    by ``T - 1``.

    Args:
        curve_f: path of surfaces; the first dimension is the time index and
            the last two dimensions are the grid.
        a, b, c, d: weights forwarded unchanged to ``riemann_metric``.

    Returns:
        Scalar tensor: the mean over the ``T - 1`` segments of the square
        root of ``riemann_metric`` evaluated on the velocity.

    NOTE(review): the end of the ``riemann_metric`` call was truncated in the
    reviewed source. ``.sqrt()`` is restored because this is a length (not an
    energy) and because the same call without a suffix fits on one line in
    the sibling energy functions -- confirm against upstream.
    """
    n_pts = curve_f.size(0)
    n_seg = n_pts - 1

    curve = torch.zeros(n_pts, 3, 2, *curve_f.shape[-2:])
    for i in range(n_pts):
        curve[i] = f_to_df(curve_f[i])

    d_curve = (curve[1:] - curve[:-1]) * n_seg

    L = torch.zeros(n_seg)
    for i in range(n_seg):
        L[i] = riemann_metric(curve[i], d_curve[i], d_curve[i], a, b, c,
                              d).sqrt()
    return torch.sum(L) / n_seg
def energy_fun_combined_modSO3(f, X, Coe_x, RX, f1, f2, Tstps, Basis_vecFields,
                               Basis_1forms, a, b, c, d):
    """Return the path energy with reparametrisation, rotation and perturbation.

    f1 and f2 are the original boundary surfaces. f2 is rotated by
    ``torch_expm(RX)`` and a straight-line path with ``Tstps`` samples is
    built between ``f_to_df(f1)`` and the one-form of the rotated f2. ``f``
    is composed with a map obtained by displacing ``get_idty_S2`` along
    ``Basis_vecFields`` (weights ``X``) and renormalising to unit length; its
    one-form replaces the first sample. Interior samples are shifted by the
    combination of ``Basis_1forms`` weighted by ``Coe_x``.

    Args:
        f: surface that is reparametrised and used as the first sample.
        X: 1-D coefficients of the basis vector fields.
        Coe_x: 2-D coefficients indexed (basis element, interior time step).
        RX: argument of ``torch_expm`` (presumably an so(3) element -- confirm).
        f1, f2: original boundary surfaces.
        Tstps: number of samples along the path.
        Basis_vecFields: 4-D basis; its last two dimensions give the grid size.
        Basis_1forms: 5-D basis indexed (basis element, i, j, m, n).
        a, b, c, d: weights forwarded unchanged to ``riemann_metric``.

    Returns:
        Scalar tensor: the mean over the segments of ``riemann_metric``
        evaluated on the finite-difference velocity.
    """
    # Rotated target surface and the two boundary one-forms.
    rotation = torch_expm(RX)
    rotated_f2 = torch.einsum("ij,jmn->imn", rotation, f2)
    form_start, form_end = f_to_df(f1), f_to_df(rotated_f2)

    # Straight line in one-form space: form_start + (form_end - form_start) * t.
    steps = torch.arange(Tstps, out=torch.FloatTensor())
    straight = form_start + torch.einsum(
        "ijmn,t->tijmn", form_end - form_start, steps) / (Tstps - 1)
    n_pts = straight.size(0)
    n_seg = n_pts - 1

    # Base map on the grid of the vector-field basis.
    grid_shape = Basis_vecFields.shape[-2:]
    base_map = get_idty_S2(*grid_shape)

    # Displace the base map, then renormalise pointwise to unit length.
    warped = base_map + torch.einsum("i, ijkl->jkl", X, Basis_vecFields)
    warped = warped / torch.einsum("ijk, ijk->jk", warped, warped).sqrt()

    # Spherical-coordinate representation of the warped map.
    warped_sph = torch.stack(
        Cartesian_to_spherical(warped[0] + 1e-7, warped[1],
                               (1 - 1e-7) * warped[2]))
    f_warped = compose_gamma(f, warped_sph)

    # Perturbed path: new first sample, unchanged last sample, shifted interior.
    path = torch.zeros(straight.size())
    first_form = f_to_df(f_warped)
    path[0] = first_form
    path[-1] = straight[-1]
    path[1:n_seg] = straight[1:n_seg] + torch.einsum(
        "lijmn,lt->tijmn", Basis_1forms, Coe_x)

    velocity = (path[1:] - path[:-1]) * n_seg

    energy = torch.zeros(n_seg)
    for k, (point, vel) in enumerate(zip(path[:-1], velocity)):
        energy[k] = riemann_metric(point, vel, vel, a, b, c, d)
    return energy.sum() / n_seg
def initialize_overSO3(f1, f2):
    """Pick, among four candidate rotations of f2, the one closest to f1.

    The fields returned by ``oneForm_to_n`` for both surfaces are flattened
    to ``3 x N`` and centred row-wise. A rotation ``R`` is assembled from the
    left singular vectors of the two centred fields, with the last diagonal
    entry of the middle factor set to the sign of ``det(U1 @ inv(U2))``.
    ``R`` and its products with three diagonal +/-1 matrices form the
    candidates.

    Args:
        f1: reference surface, indexed (coordinate, m, n).
        f2: surface to be rotated, indexed (coordinate, m, n).

    Returns:
        ``(best_f2, dists)``: the rotated copy of f2 with the smallest norm
        of ``f1 - R f2``, and a NumPy array with the four norms.
    """
    form1 = f_to_df(f1)
    form2 = f_to_df(f2)

    # Flatten to 3 x N and subtract the row means.
    normals1 = oneForm_to_n(form1).view(3, -1)
    normals1 = normals1 - normals1.sum(dim=1, keepdim=True) / normals1.size(1)
    normals2 = oneForm_to_n(form2).view(3, -1)
    normals2 = normals2 - normals2.sum(dim=1, keepdim=True) / normals2.size(1)

    # Half-turn matrices about each coordinate axis.
    flip_z = torch.tensor([[-1, 0., 0], [0, -1, 0], [0, 0, 1]])  # z fixed
    flip_x = torch.tensor([[1, 0., 0], [0, -1, 0], [0, 0, -1]])  # x fixed
    flip_y = torch.tensor([[-1, 0., 0], [0, 1, 0], [0, 0, -1]])  # y fixed

    left1, _, _ = torch.svd(normals1)
    left2, _, _ = torch.svd(normals2)
    left2_inv = torch.inverse(left2)

    # Fix the sign of the last axis from the determinant of the raw product.
    sign_fix = torch.eye(3)
    sign_fix[2, 2] = torch.sign(torch.det(left1 @ left2_inv))
    base_rot = left1 @ sign_fix @ left2_inv

    candidates = torch.stack(
        [base_rot] + [flip @ base_rot for flip in (flip_z, flip_x, flip_y)])
    rotated = torch.einsum("lij,jmn->limn", candidates, f2)

    dists = np.zeros(4)
    for k, candidate in enumerate(rotated):
        dists[k] = torch.norm((f1 - candidate).flatten()).numpy()

    best = np.argmin(dists)
    return rotated[best], dists
def initialize_over_paraSO3(f1, f2, idty, a, b, c, d):
    """Initialise a rotational reparametrisation of f1 that brings it near f2.

    Sixty candidates are loaded from ``Bases/skewIcosahedral.mat`` (key
    ``'X'``). For each one, ``idty`` is rotated by ``torch_expm`` of the
    candidate, f1 is composed with the rotated map, and the length of the
    two-sample straight path to ``f_to_df(f2)`` is measured with
    ``length_dfunc_Imm``. The best candidate then seeds a BFGS refinement.

    Args:
        f1: surface to be reparametrised.
        f2: target surface.
        idty: map indexed (coordinate, m, n) that is rotated
            (presumably the identity map of the sphere -- confirm).
        a, b, c, d: weights forwarded unchanged to ``length_dfunc_Imm``.

    Returns:
        ``(f1_gamma, L, f1_gamma_best_candidate, length_linear)``: f1 composed
        with the refined rotation, the list filled by the optimiser callback,
        f1 composed with the best discrete candidate, and the tensor of the
        60 candidate lengths.

    NOTE(review): in the reviewed source the objective's return statement, the
    callback body and the head of the ``optimize.minimize`` call (x0, method,
    jac, callback) were truncated. They are reconstructed from the surviving
    ``options`` dict and from the parallel BFGS call in ``compute_optCoe_Imm``
    -- confirm against the upstream implementation.
    """
    Tstps = 2
    n_ico = 60  # number of candidates read from the .mat file

    # Load the elements of the icosahedral group.
    XIco_mat = sio.loadmat('Bases/skewIcosahedral.mat')
    XIco = torch.from_numpy(XIco_mat['X']).float()

    df2 = f_to_df(f2)
    Time_points = torch.arange(Tstps, out=torch.FloatTensor())

    def _reparametrize(X):
        """Compose f1 with idty rotated by torch_expm(X)."""
        R = torch.einsum("ij,jmn->imn", [torch_expm(X), idty])
        gamma = torch.stack(
            Cartesian_to_spherical(R[0] + 1e-7, R[1], (1 - 1e-7) * R[2]))
        return compose_gamma(f1, gamma)

    def _linear_length(f1_gamma):
        """Length of the straight path from f_to_df(f1_gamma) to df2."""
        df1_gamma = f_to_df(f1_gamma)
        lin_path = df1_gamma + torch.einsum(
            "ijmn,t->tijmn", [df2 - df1_gamma, Time_points]) / (Tstps - 1)
        return length_dfunc_Imm(lin_path, a, b, c, d)

    # Score every discrete candidate.
    f1_gammaIco = torch.zeros(n_ico, *f1.size())
    length_linear = torch.zeros(n_ico)
    for i in range(n_ico):
        f1_gammaIco[i] = _reparametrize(XIco[i])
        length_linear[i] = _linear_length(f1_gammaIco[i])

    # Index of the shortest path (first one on ties).
    Ind = int(torch.argmin(length_linear).item())
    X = XIco[Ind]

    L = []

    def opt(X):
        """Objective for SciPy: return (length, gradient) as float64."""
        X = torch.from_numpy(X).float().requires_grad_()
        y = _linear_length(_reparametrize(X))
        y.backward()
        return np.double(y.detach().numpy()), np.double(X.grad.numpy())

    def printx(x):
        """Per-iteration callback: record the length at the current iterate."""
        # NOTE(review): reconstructed -- presumably the original logged the
        # objective here; this costs one extra evaluation per iteration.
        L.append(opt(x)[0])

    # x0 is passed as float64 so SciPy iterates in double precision and
    # opt()'s .float() conversion never aliases SciPy's buffer.
    res = optimize.minimize(opt,
                            X.double().numpy(),
                            method='BFGS',
                            jac=True,
                            callback=printx,
                            options={
                                'gtol': 1e-02,
                                'disp': False
                            })  # True

    X_opt = torch.from_numpy(res.x).float()
    f1_gamma = _reparametrize(X_opt)
    return f1_gamma, L, f1_gammaIco[Ind], length_linear
def rescale_surf(f):
    """Divide a surface by the square root of its integrated normal norm.

    The cross product of the two columns of ``f_to_df(f)`` is taken along the
    coordinate axis, its pointwise Euclidean norm is integrated with
    ``integral_over_s2``, and ``f`` is divided by the square root of that
    integral (presumably normalising the surface area to one -- confirm).

    Args:
        f: surface whose one-form is indexed (coordinate, 2, m, n).

    Returns:
        The rescaled surface, same shape as ``f``.
    """
    one_form = f_to_df(f)
    col_u, col_v = one_form[:, 0], one_form[:, 1]
    normal = torch.cross(col_u, col_v, dim=0)

    # Pointwise norm of the (unnormalised) normal field.
    density = torch.einsum("imn,imn->mn", normal, normal).sqrt()
    total = integral_over_s2(density)

    return f / total.sqrt()