def energy_func_para_split_modSO3(f_f1, Coe_x, RX, f1, f2, Basis_1forms, a, b,
                                  c, d, Tstps):
    """Discrete path energy for the split scheme, with f2 rotated by exp(RX).

    The path is built in the space returned by ``f_to_df``:

    * the endpoints of the reference straight line are ``f_to_df(f1)`` and
      ``f_to_df(R f2)`` with ``R = torch_expm(RX)``;
    * interior points of that line are shifted by ``Basis_1forms`` weighted
      with ``Coe_x`` (contracted over the basis index);
    * the first point is replaced by ``f_to_df(f_f1)``, the last point is the
      end of the straight line.

    Returns the sum over the ``Tstps - 1`` segments of ``riemann_metric``
    (evaluated at the segment's start point on the forward finite-difference
    velocity), divided by ``Tstps - 1``.  ``a, b, c, d`` are forwarded
    unchanged to ``riemann_metric``.
    """
    n_seg = Tstps - 1

    # endpoints of the reference line: df1 and d(R f2)
    rot = torch_expm(RX)
    rotated_f2 = torch.einsum("ij,jmn->imn", [rot, f2])
    start_df = f_to_df(f1)
    end_df = f_to_df(rotated_f2)

    # straight line start_df + (end_df - start_df) * t / (Tstps - 1)
    steps = torch.arange(Tstps, out=torch.FloatTensor())
    increment = torch.einsum("ijmn,t->tijmn", [end_df - start_df, steps])
    straight = start_df + increment / n_seg

    # perturbed path: fixed endpoints, interior moved by the coefficients
    path = torch.zeros(straight.size())
    path[0] = f_to_df(f_f1)
    path[-1] = straight[-1]
    path[1:n_seg] = straight[1:n_seg] + torch.einsum(
        "lijmn,lt->tijmn", [Basis_1forms, Coe_x])

    # forward differences scaled by 1 / dt, with dt = 1 / (Tstps - 1)
    velocity = (path[1:Tstps] - path[0:n_seg]) * n_seg

    seg_energy = torch.zeros(n_seg)
    for k in range(n_seg):
        seg_energy[k] = riemann_metric(path[k], velocity[k], velocity[k],
                                       a, b, c, d)
    return torch.sum(seg_energy) / n_seg
Exemple #2
0
    def opt(X):
        """Objective for the rotation search based on ``f_to_q``.

        ``X`` is the NumPy parameter array handed over by the optimiser.  The
        closure variable ``idty`` is rotated by ``torch_expm(X)``, converted
        with ``Cartesian_to_spherical`` and used to reparametrise ``f1`` via
        ``compose_gamma``.  The value is ``integral_over_s2`` of the pointwise
        squared difference between ``f_to_q(f2)`` and ``f_to_q`` of the
        reparametrised ``f1``.

        Returns ``(value, gradient w.r.t. X)``, both converted with
        ``np.double``.
        """
        params = torch.from_numpy(X).float().requires_grad_()
        rotated = torch.einsum("ij,jmn->imn", [torch_expm(params), idty])

        # the 1e-7 offsets are kept exactly as in the other callers of
        # Cartesian_to_spherical in this file
        sph = Cartesian_to_spherical(rotated[0] + 1e-7, rotated[1],
                                     (1 - 1e-7) * rotated[2])
        warp = torch.stack(sph)

        q_moving = f_to_q(compose_gamma(f1, warp))
        q_target = f_to_q(f2)
        diff = q_target - q_moving

        y = integral_over_s2(torch.einsum("imn,imn->mn", [diff, diff]))
        y.backward()

        value = np.double(y.data.numpy())
        gradient = np.double(params.grad.data.numpy())
        return value, gradient
Exemple #3
0
    def opt(X):
        """Objective for the rotation search based on a linear-path length.

        ``X`` is the NumPy parameter array handed over by the optimiser.  The
        closure variable ``idty`` is rotated by ``torch_expm(X)``, converted
        with ``Cartesian_to_spherical`` and used to reparametrise ``f1`` via
        ``compose_gamma``.  A straight line with ``Tstps`` points is drawn
        between ``f_to_df`` of the reparametrised ``f1`` and ``f_to_df(f2)``,
        and its ``length_dfunc_Imm`` (with closure constants a, b, c, d) is
        the objective value.

        Returns ``(value, gradient w.r.t. X)``, both converted with
        ``np.double``.
        """
        params = torch.from_numpy(X).float().requires_grad_()
        rotated = torch.einsum("ij,jmn->imn", [torch_expm(params), idty])

        # the 1e-7 offsets are kept exactly as in the other callers of
        # Cartesian_to_spherical in this file
        sph = Cartesian_to_spherical(rotated[0] + 1e-7, rotated[1],
                                     (1 - 1e-7) * rotated[2])
        warp = torch.stack(sph)

        start_df = f_to_df(compose_gamma(f1, warp))
        end_df = f_to_df(f2)

        # start_df + (end_df - start_df) * t / (Tstps - 1)
        steps = torch.arange(Tstps, out=torch.FloatTensor())
        increment = torch.einsum("ijmn,t->tijmn", [end_df - start_df, steps])
        straight = start_df + increment / (Tstps - 1)

        y = length_dfunc_Imm(straight, a, b, c, d)
        y.backward()

        value = np.double(y.data.numpy())
        gradient = np.double(params.grad.data.numpy())
        return value, gradient
def energy_fun_combined_modSO3(f, X, Coe_x, RX, f1, f2, Tstps, Basis_vecFields,
                               Basis_1forms, a, b, c, d):
    """Discrete path energy for the combined scheme.

    ``f1`` and ``f2`` are the original boundary surfaces; ``f`` is the surface
    that gets reparametrised.  Three unknowns enter the energy:

    * ``RX``: ``f2`` is rotated by ``torch_expm(RX)`` before the reference
      straight line between ``f_to_df(f1)`` and ``f_to_df(R f2)`` is built;
    * ``X``: coefficients of ``Basis_vecFields`` that are added to the map
      returned by ``get_idty_S2``; the sum is normalised pointwise along its
      first axis and converted with ``Cartesian_to_spherical`` before being
      passed to ``compose_gamma(f, ...)``;
    * ``Coe_x``: coefficients of ``Basis_1forms`` that shift the interior
      points of the straight line.

    The first path point is ``f_to_df`` of the reparametrised ``f``; the last
    is the end of the straight line.  The result is the sum of
    ``riemann_metric`` over all segments (forward finite-difference velocity)
    divided by the number of segments.  ``a, b, c, d`` are forwarded unchanged
    to ``riemann_metric``.
    """
    # reference straight line between df1 and d(R f2)
    rot = torch_expm(RX)
    rotated_f2 = torch.einsum("ij,jmn->imn", [rot, f2])
    start_df = f_to_df(f1)
    end_df = f_to_df(rotated_f2)

    steps = torch.arange(Tstps, out=torch.FloatTensor())
    increment = torch.einsum("ijmn,t->tijmn", [end_df - start_df, steps])
    straight = start_df + increment / (Tstps - 1)

    n_pts = straight.size(0)
    last = n_pts - 1

    # reparametrisation: identity map plus a combination of vector fields,
    # renormalised pointwise
    sphere_id = get_idty_S2(*Basis_vecFields.shape[-2:])
    warp = sphere_id + torch.einsum("i, ijkl->jkl", [X, Basis_vecFields])
    warp_norm = torch.einsum("ijk, ijk->jk", [warp, warp]).sqrt()
    warp = warp / warp_norm

    # spherical-coordinate form expected by compose_gamma
    warp_sph = torch.stack(
        Cartesian_to_spherical(warp[0] + 1e-7, warp[1],
                               (1 - 1e-7) * warp[2]))
    f_warped = compose_gamma(f, warp_sph)

    # perturbed path: reparametrised start, straight-line end, shifted interior
    path = torch.zeros(straight.size())
    path[0] = f_to_df(f_warped)
    path[last] = straight[-1]
    path[1:last] = straight[1:last] + torch.einsum(
        "lijmn,lt->tijmn", [Basis_1forms, Coe_x])

    # forward differences scaled by the number of segments
    velocity = (path[1:n_pts] - path[0:last]) * last

    seg_energy = torch.zeros(last)
    for k in range(last):
        seg_energy[k] = riemann_metric(path[k], velocity[k], velocity[k],
                                       a, b, c, d)
    return torch.sum(seg_energy) / last
def get_optCoe_shapes_split(f_f1,
                            Coe_x,
                            RX,
                            f1,
                            f2,
                            Basis_vecFields,
                            Basis_Sph,
                            Basis_1forms,
                            a,
                            b,
                            c,
                            d,
                            Tstps,
                            Max_ite,
                            side,
                            *,
                            multires=False):
    """Run one BFGS pass over (Coe_x, RX), then re-register the first surface.

    The path coefficients ``Coe_x`` and the three rotation parameters ``RX``
    are flattened into a single vector and minimised with
    ``scipy.optimize.minimize(method='BFGS')`` against
    ``energy_func_para_split_modSO3``.  Afterwards the first two points of the
    optimised path are assembled and handed to ``compute_optimal_reg_1st`` to
    update the first boundary surface.

    Args:
        f_f1: current first boundary surface (first point of the path).
        Coe_x: initial coefficients; viewed as ``(Basis_Sph.size(0), -1)``.
        RX: initial rotation parameters (the last three optimisation
            variables), passed to ``torch_expm``.
        f1, f2: original boundary surfaces.
        Basis_vecFields: basis forwarded to ``compute_optimal_reg_1st``; its
            last two dimensions size ``get_idty_S2``.
        Basis_Sph: basis used to rebuild the second path point.
        Basis_1forms: basis used inside the energy.
        a, b, c, d: metric constants forwarded to the energy and registration.
        Tstps: number of time points of the path.
        Max_ite: ``maxiter`` for BFGS.
        side: unused; kept for interface compatibility (the two-sided update
            it used to select is disabled).
        multires: when true, the optimisation runs on surfaces reduced with
            ``reduce_resolution`` to 25 x 49 and on the matching
            low-resolution basis loaded from disk.

    Returns:
        ``(updated first surface, optimal Coe_x, optimal RX, Energy)`` where
        ``Energy`` lists the objective value after every BFGS iteration.
    """
    Num_basis = Basis_Sph.size(0)

    # scipy optimises one flat vector: all coefficients followed by RX
    CoeRX = torch.cat((Coe_x.flatten(), RX))

    def make_objective(first_surf, src, tgt, basis_1forms):
        """Build a scipy-style ``x -> (value, gradient)`` objective."""

        def objective(flat):
            coe = torch.from_numpy(flat[:-3]).float().view(
                Num_basis, -1).requires_grad_()
            rx = torch.from_numpy(flat[-3:]).float().requires_grad_()

            y = energy_func_para_split_modSO3(first_surf, coe, rx, src, tgt,
                                              basis_1forms, a, b, c, d, Tstps)
            y.backward()

            grad = torch.cat((coe.grad.flatten(), rx.grad))
            return np.double(y.data.numpy()), np.double(grad.data.numpy())

        return objective

    if multires:
        # low-resolution basis of exact 1-forms; cast to float32 like every
        # other basis in this file so it can be contracted with the float32
        # coefficients
        mat_basis = sio.loadmat('Data/basis_exact_1forms_deg25_25_49.mat')
        Basis_1forms_low = torch.from_numpy(
            mat_basis['Basis'])[:Num_basis].float()

        m, n = 25, 49  # grid size matching the low-resolution basis file
        objective = make_objective(reduce_resolution(f_f1, m, n),
                                   reduce_resolution(f1, m, n),
                                   reduce_resolution(f2, m, n),
                                   Basis_1forms_low)
    else:
        objective = make_objective(f_f1, f1, f2, Basis_1forms)

    Energy = []

    def record_energy(x):
        # callback: store the objective value reached after each iteration
        Energy.append(objective(x)[0])

    res = optimize.minimize(objective,
                            CoeRX,
                            method='BFGS',
                            jac=True,
                            callback=record_energy,
                            options={
                                'gtol': 1e-02,
                                'disp': False,
                                'maxiter': Max_ite
                            })

    Coe_x_opt = torch.from_numpy(res.x[:-3]).float().view(Num_basis, -1)
    RX_opt = torch.from_numpy(res.x[-3:]).float()

    # Rotate f2 with the rotation that was just optimised.  Coe_x_opt was
    # found jointly with RX_opt, so the second path point below must be built
    # on the same rotated target (the input RX is one iteration stale).
    R = torch_expm(RX_opt)
    Rf2 = torch.einsum("ij,jmn->imn", [R, f2])

    c_f = torch.zeros(2, 3, *f1.shape[-2:])
    c_f[0] = f_f1

    # second point of the discrete linear path: f1 + (R f2 - f1) / (Tstps - 1)
    lin_f = f1 + torch.einsum("imn,t->timn",
                              [Rf2 - f1, torch.tensor([1.])]) / (Tstps - 1)
    # ... shifted by the first column of the optimal coefficients
    c_f[1] = lin_f[0] + torch.einsum("limn,l->imn",
                                     [Basis_Sph, Coe_x_opt[:, 0]])

    idty = get_idty_S2(*Basis_vecFields.shape[-2:])

    # only the first boundary surface is updated
    c_f[0] = compute_optimal_reg_1st(c_f, idty, Basis_vecFields, a, b, c, d)

    return c_f[0], Coe_x_opt, RX_opt, Energy
Exemple #6
0
def initialize_over_paraSO3(f1, f2, idty, a, b, c, d):
    """Initialise the reparametrisation of ``f1`` by a rotation of the domain.

    Step 1 scans the 60 rows of ``X`` stored in
    ``Bases/skewIcosahedral.mat``: for each, ``idty`` is rotated by
    ``torch_expm`` of that row, ``f1`` is reparametrised with
    ``compose_gamma`` and the ``length_dfunc_Imm`` of the two-point straight
    line to ``f2`` (in ``f_to_df`` space) is recorded.  Step 2 starts BFGS
    from the best row and minimises the same length over the rotation
    parameters.

    Args:
        f1, f2: the two surfaces.
        idty: map that is rotated and converted with
            ``Cartesian_to_spherical`` to obtain the reparametrisation.
        a, b, c, d: constants forwarded to ``length_dfunc_Imm``.

    Returns:
        ``(f1_gamma, L, f1_best_scan, length_linear)``: ``f1``
        reparametrised with the optimised rotation, the objective value after
        each BFGS iteration, ``f1`` reparametrised with the best scanned
        rotation, and the 60 scanned lengths.
    """
    Tstps = 2
    n_rot = 60  # number of rotations read from the file

    # load the elements in the icosahedral group
    XIco_mat = sio.loadmat('Bases/skewIcosahedral.mat')
    XIco = torch.from_numpy(XIco_mat['X']).float()

    # quantities that do not depend on the rotation: computed once and shared
    # by the scan and by every objective evaluation
    df2 = f_to_df(f2)
    df2_const = df2.detach()  # keep repeated backward() calls off f2's graph
    Time_points = torch.arange(Tstps, out=torch.FloatTensor())

    def rotate_f1(X):
        """Reparametrise f1 by the rotation torch_expm(X) applied to idty."""
        R = torch.einsum("ij,jmn->imn", [torch_expm(X), idty])
        gamma = torch.stack(
            Cartesian_to_spherical(R[0] + 1e-7, R[1], (1 - 1e-7) * R[2]))
        return compose_gamma(f1, gamma)

    def path_length(f1_gamma, target_df):
        """Length of the straight line from d(f1_gamma) to target_df."""
        df1_gamma = f_to_df(f1_gamma)
        lin_path = df1_gamma + torch.einsum(
            "ijmn,t->tijmn", [target_df - df1_gamma, Time_points]) / (
                Tstps - 1)  # df1 + (df2-df1)t
        return length_dfunc_Imm(lin_path, a, b, c, d)

    # step 1: exhaustive scan over the stored rotations
    f1_gammaIco = torch.zeros(n_rot, *f1.size())
    length_linear = torch.zeros(n_rot)
    for i in range(n_rot):
        f1_gammaIco[i] = rotate_f1(XIco[i])
        length_linear[i] = path_length(f1_gammaIco[i], df2)

    # index of the shortest straight line
    Ind = int(torch.argmin(length_linear))

    # step 2: local refinement with BFGS, starting from the best rotation
    L = []

    def opt(X):
        X = torch.from_numpy(X).float().requires_grad_()
        y = path_length(rotate_f1(X), df2_const)
        y.backward()
        return np.double(y.data.numpy()), np.double(X.grad.data.numpy())

    def printx(x):
        L.append(opt(x)[0])

    res = optimize.minimize(opt,
                            XIco[Ind].numpy(),
                            method='BFGS',
                            jac=True,
                            callback=printx,
                            options={
                                'gtol': 1e-02,
                                'disp': False
                            })

    X_opt = torch.from_numpy(res.x).float()
    f1_gamma = rotate_f1(X_opt)
    return f1_gamma, L, f1_gammaIco[Ind], length_linear
def compute_geodesic_unparaModSO3(f1, f2, *, MaxDegVecFS2, MaxDegHarmSurf,
                           Cmetric=(), Tpts, method='split', maxiter=(10, 30), **kwargs):
    """Compute a discrete path from ``f1`` to a rotated copy of ``f2``.

    Bases are loaded from ``Bases/*.mat`` files whose names contain the last
    two dimensions of ``f1``.  The path coefficients, the rotation of ``f2``
    and the reparametrisation of ``f1`` are then optimised in an outer loop,
    either with ``get_optCoe_shapes_split`` (``method='split'``) or with
    ``get_optCoe_shapes_combined`` (``method='combined'``).

    Args:
        f1, f2: boundary surfaces.
        MaxDegVecFS2: maximal degree used to truncate the vector-field basis.
        MaxDegHarmSurf: maximal degree used to truncate the surface and
            1-form bases.
        Cmetric: the four metric constants ``(a, b, c, d)``; must contain
            exactly four entries.
        Tpts: number of time points of the returned path.
        method: ``'split'`` or ``'combined'``.
        maxiter: ``(outer iterations, inner BFGS iterations)``.  The last
            outer iteration always uses 100 inner iterations.
        **kwargs: ``multires`` (truthy) enables the multi-resolution variant
            of the split method: 5 time points until the last outer
            iteration, and low-resolution optimisation during the first half
            of the outer iterations.

    Returns:
        ``(geo_f, energies)``: the path with ``Tpts`` points and the flat
        list of energies recorded by the inner optimisers.

    Raises:
        ValueError: if ``method`` is not ``'split'`` or ``'combined'``.
    """
    # fail early and clearly instead of hitting an unbound `geo_f` at the end
    if method not in ('split', 'combined'):
        raise ValueError(
            "method must be 'split' or 'combined', got {!r}".format(method))

    # load the bases for surfaces and exact 1forms
    mat_basis = sio.loadmat('Bases/basis_exact_1forms_deg25_{0}_{1}.mat'.format(*f1.shape[-2:]))

    Num_basis = ((MaxDegHarmSurf + 1) ** 2 - 1) * 3  # the number of basis for 1forms
    Basis_Sph = torch.from_numpy(mat_basis['Basis_Sph'])[: Num_basis].float()
    Basis_1forms = torch.from_numpy(mat_basis['Basis'])[: Num_basis].float()

    # load the basis for tangent vector fields on S2
    mat_vecF = sio.loadmat('Bases/basis_vecFieldsS2_deg25_{0}_{1}.mat'.format(*f1.shape[-2:]))

    N_basis_vec = (MaxDegVecFS2 + 1) ** 2 - 1  # half the number of basis for the vector fields on S2
    Basis0_vec = torch.from_numpy(mat_vecF['Basis'])[: N_basis_vec].float()
    Basis_vecFields = torch.cat((Basis0_vec[:, 0], Basis0_vec[:, 1]))  # get a basis of the tangent fields on S2

    # get the coefficients for the split metric
    a, b, c, d = Cmetric

    EnergyAll = []

    # number of outer iterations and of inner BFGS iterations
    N_ite, Max_ite_in = maxiter

    RX = torch.zeros(3)

    if method == 'split':

        f_f1 = f1
        side = 0

        multires = bool(kwargs.get('multires'))
        # multiresolution in time: start with a coarse path of 5 points
        Tpts0 = 5 if multires else Tpts

        Coe_x = torch.zeros(Basis_Sph.size(0), Tpts0 - 2)

        for i in range(N_ite):

            if i == N_ite - 1:
                # last pass: more inner iterations at the full time resolution
                Max_ite_in = 100
                if Tpts0 != Tpts:
                    Tpts0 = Tpts
                    Coe_x = up_sample(Coe_x, Tpts - 2)

            if i > N_ite // 2:
                # second half of the outer loop runs at full spatial resolution
                multires = False

            f_f1, Coe_x, RX, Energy = get_optCoe_shapes_split(
                f_f1, Coe_x, RX, f1, f2, Basis_vecFields, Basis_Sph,
                Basis_1forms, a, b, c, d, Tpts0, Max_ite_in, side,
                multires=multires)
            EnergyAll.append(Energy)

    else:  # method == 'combined'

        f = f1
        Coe_x = torch.zeros(Basis_Sph.size(0), Tpts - 2)

        idty = get_idty_S2(*Basis_vecFields.shape[-2:])

        for i in range(N_ite):

            if i == N_ite - 1:
                Max_ite_in = 100

            X_new, Coe_x, RX, Energy = get_optCoe_shapes_combined(
                f, Coe_x, RX, f1, f2, Basis_vecFields, Basis_1forms,
                a, b, c, d, Tpts, Max_ite_in)

            # update f with the reparametrisation described by X_new
            gamma = idty + torch.einsum("i, ijkl->jkl", [X_new, Basis_vecFields])
            gamma = gamma / torch.einsum("ijk, ijk->jk", [gamma, gamma]).sqrt()  # project to S2
            # get the spherical coordinate representation
            gammaSph = torch.stack((Cartesian_to_spherical(gamma[0], gamma[1], gamma[2])))

            f = compose_gamma(f, gammaSph)
            EnergyAll.append(Energy)

    # rotate f2 with the optimised rotation
    R = torch_expm(RX)
    Rf2 = torch.einsum("ij,jmn->imn", [R, f2])

    # linear path between f1 and R f2
    Time_points = torch.arange(Tpts, out=torch.FloatTensor())
    lin_f = f1 + torch.einsum("imn,t->timn", [Rf2 - f1, Time_points]) / (Tpts - 1)

    # perturb the linear path using the optimal coefficients
    geo_f = torch.zeros(Tpts, 3, *f1.shape[-2:])
    if method == 'split':
        geo_f[0], geo_f[-1] = f_f1, Rf2
    else:
        geo_f[0], geo_f[Tpts - 1] = f, lin_f[-1]
    geo_f[1: Tpts - 1] = lin_f[1: Tpts - 1] + torch.einsum("limn,lt->timn", [Basis_Sph, Coe_x])

    EnergyAll0 = [item for sublist in EnergyAll for item in sublist]
    return geo_f, EnergyAll0
Exemple #8
0
def initialize_over_paraSO3_SRNF(f1, f2, idty):
    """Initialise the reparametrisation of ``f1`` using the ``f_to_q`` distance.

    Step 1 scans the 60 rows of ``X`` stored in
    ``Bases/skewIcosahedral.mat``: for each, ``idty`` is rotated by
    ``torch_expm`` of that row, ``f1`` is reparametrised with
    ``compose_gamma`` and ``integral_over_s2`` of the pointwise squared
    difference between ``f_to_q(f2)`` and ``f_to_q`` of the reparametrised
    ``f1`` is recorded.  Step 2 starts BFGS from the best row and minimises
    the same quantity over the rotation parameters.

    Args:
        f1, f2: the two surfaces.
        idty: map that is rotated and converted with
            ``Cartesian_to_spherical`` to obtain the reparametrisation.

    Returns:
        ``(f1_gamma, L2_ESO3, f1_best_scan, EIco)``: ``f1`` reparametrised
        with the optimised rotation, the objective value after each BFGS
        iteration, ``f1`` reparametrised with the best scanned rotation, and
        the 60 scanned values.
    """
    n_rot = 60  # number of rotations read from the file

    # load the elements in the icosahedral group
    XIco_mat = sio.loadmat('Bases/skewIcosahedral.mat')
    XIco = torch.from_numpy(XIco_mat['X']).float()

    # independent of the rotation: computed once and shared by the scan and
    # by every objective evaluation
    q2 = f_to_q(f2)
    q2_const = q2.detach()  # keep repeated backward() calls off f2's graph

    def rotate_f1(X):
        """Reparametrise f1 by the rotation torch_expm(X) applied to idty."""
        R = torch.einsum("ij,jmn->imn", [torch_expm(X), idty])
        gamma = torch.stack(
            Cartesian_to_spherical(R[0] + 1e-7, R[1], (1 - 1e-7) * R[2]))
        return compose_gamma(f1, gamma)

    def squared_distance(q1_gamma, q_target):
        """Integral over the sphere of |q_target - q1_gamma|^2."""
        diff = q_target - q1_gamma
        return integral_over_s2(torch.einsum("imn,imn->mn", [diff, diff]))

    # step 1: exhaustive scan over the stored rotations
    EIco = torch.zeros(n_rot)
    f1_gammaIco = torch.zeros(n_rot, *f1.size())
    q1_gammaIco = torch.zeros(n_rot, *f1.size())
    for i in range(n_rot):
        f1_gammaIco[i] = rotate_f1(XIco[i])
        q1_gammaIco[i] = f_to_q(f1_gammaIco[i])
        EIco[i] = squared_distance(q1_gammaIco[i], q2)

    # index of the smallest value
    Ind = int(torch.argmin(EIco))

    # step 2: local refinement with BFGS, starting from the best rotation
    L2_ESO3 = []

    def opt(X):
        X = torch.from_numpy(X).float().requires_grad_()
        y = squared_distance(f_to_q(rotate_f1(X)), q2_const)
        y.backward()
        return np.double(y.data.numpy()), np.double(X.grad.data.numpy())

    def printx(x):
        L2_ESO3.append(opt(x)[0])

    res = optimize.minimize(opt,
                            XIco[Ind].numpy(),
                            method='BFGS',
                            jac=True,
                            callback=printx,
                            options={
                                'gtol': 1e-02,
                                'disp': False
                            })

    X_opt = torch.from_numpy(res.x).float()
    f1_gamma = rotate_f1(X_opt)
    return f1_gamma, L2_ESO3, f1_gammaIco[Ind], EIco