def energy_func_para_split_modSO3(f_f1, Coe_x, RX, f1, f2, Basis_1forms, a, b, c, d, Tstps):
    """Return the discrete path energy used by the split scheme (modulo SO(3)).

    The path is built in the space of differentials (``f_to_df`` outputs):

    * the first point is ``f_to_df(f_f1)``;
    * the last point is ``f_to_df`` of ``f2`` rotated by ``torch_expm(RX)``;
    * the interior points are the linear interpolation between ``f_to_df(f1)``
      and that last point, shifted by ``Basis_1forms`` weighted with ``Coe_x``.

    Args:
        f_f1: surface used for the first point of the path.
        Coe_x: coefficient matrix contracted with ``Basis_1forms``; its second
            axis indexes the ``Tstps - 2`` interior time points.
        RX: argument of ``torch_expm`` that produces the matrix applied to ``f2``.
        f1, f2: the original boundary surfaces.
        Basis_1forms: basis contracted with ``Coe_x`` to perturb the interior.
        a, b, c, d: coefficients forwarded unchanged to ``riemann_metric``.
        Tstps: number of time points of the discrete path.

    Returns:
        A scalar tensor: the sum of ``riemann_metric`` over the ``Tstps - 1``
        segments divided by ``Tstps - 1``.
    """
    n_seg = Tstps - 1

    # Rotate f2 and move both endpoints to their differentials.
    rotation = torch_expm(RX)
    rotated_f2 = torch.einsum("ij,jmn->imn", [rotation, f2])
    d_start = f_to_df(f1)
    d_end = f_to_df(rotated_f2)

    # Straight line d_start + (d_end - d_start) * t / (Tstps - 1).
    steps = torch.arange(Tstps, dtype=torch.float32)
    straight = d_start + torch.einsum("ijmn,t->tijmn", [d_end - d_start, steps]) / n_seg

    # Replace the first point and perturb the interior with the coefficients.
    path = torch.zeros(straight.size())
    path[0] = f_to_df(f_f1)
    path[-1] = straight[-1]
    path[1:n_seg] = straight[1:n_seg] + torch.einsum(
        "lijmn,lt->tijmn", [Basis_1forms, Coe_x])

    # Forward finite differences, scaled by the number of segments.
    velocity = (path[1:Tstps] - path[0:n_seg]) * n_seg

    seg_energy = torch.zeros(n_seg)
    for k in range(n_seg):
        seg_energy[k] = riemann_metric(path[k], velocity[k], velocity[k], a, b, c, d)
    return torch.sum(seg_energy) / n_seg
def opt(X):
    """Objective and gradient for the rotation search on ``f_to_q`` outputs.

    Relies on ``idty``, ``f1`` and ``f2`` from the surrounding scope.

    Args:
        X: numpy array handed in by the optimiser; it is converted to a float
            tensor and passed to ``torch_expm``.

    Returns:
        ``(value, gradient)`` as float64 numpy objects, where ``value`` is
        ``integral_over_s2`` of the pointwise squared difference between
        ``f_to_q(f2)`` and ``f_to_q`` of the re-parametrised ``f1``.
    """
    x_var = torch.from_numpy(X).float().requires_grad_()

    # Apply the matrix to the identity map and convert to spherical coordinates;
    # the 1e-7 offsets are kept exactly as in every other call site.
    moved = torch.einsum("ij,jmn->imn", [torch_expm(x_var), idty])
    gamma = torch.stack(
        Cartesian_to_spherical(moved[0] + 1e-7, moved[1], (1 - 1e-7) * moved[2]))

    q1_gamma = f_to_q(compose_gamma(f1, gamma))
    q2 = f_to_q(f2)
    residual = q2 - q1_gamma

    value = integral_over_s2(torch.einsum("imn,imn->mn", [residual, residual]))
    value.backward()
    return np.double(value.detach().numpy()), np.double(x_var.grad.detach().numpy())
def opt(X):
    """Objective and gradient for the rotation search on linear-path length.

    Relies on ``idty``, ``f1``, ``f2``, ``Tstps`` and the coefficients
    ``a, b, c, d`` from the surrounding scope.

    Args:
        X: numpy array handed in by the optimiser; it is converted to a float
            tensor and passed to ``torch_expm``.

    Returns:
        ``(value, gradient)`` as float64 numpy objects, where ``value`` is
        ``length_dfunc_Imm`` of the linear path between the differential of
        the re-parametrised ``f1`` and the differential of ``f2``.
    """
    x_var = torch.from_numpy(X).float().requires_grad_()

    # Apply the matrix to the identity map and convert to spherical coordinates.
    moved = torch.einsum("ij,jmn->imn", [torch_expm(x_var), idty])
    gamma = torch.stack(
        Cartesian_to_spherical(moved[0] + 1e-7, moved[1], (1 - 1e-7) * moved[2]))

    d_start = f_to_df(compose_gamma(f1, gamma))
    d_end = f_to_df(f2)

    # d_start + (d_end - d_start) * t / (Tstps - 1)
    steps = torch.arange(Tstps, dtype=torch.float32)
    straight = d_start + torch.einsum("ijmn,t->tijmn", [d_end - d_start, steps]) / (Tstps - 1)

    value = length_dfunc_Imm(straight, a, b, c, d)
    value.backward()
    return np.double(value.detach().numpy()), np.double(x_var.grad.detach().numpy())
def energy_fun_combined_modSO3(f, X, Coe_x, RX, f1, f2, Tstps, Basis_vecFields, Basis_1forms,
                               a, b, c, d):
    """Return the discrete path energy used by the combined scheme (modulo SO(3)).

    Compared with the split energy, the first point of the path is the
    differential of ``f`` composed with a re-parametrisation ``gamma`` that is
    itself built from ``X`` and ``Basis_vecFields``.

    Args:
        f: surface that gets re-parametrised for the first path point.
        X: coefficients contracted with ``Basis_vecFields`` to perturb the
            identity map returned by ``get_idty_S2``.
        Coe_x: coefficient matrix contracted with ``Basis_1forms``; its second
            axis indexes the interior time points.
        RX: argument of ``torch_expm`` producing the matrix applied to ``f2``.
        f1, f2: the original boundary surfaces.
        Tstps: number of time points of the discrete path.
        Basis_vecFields: basis used to perturb the identity map; its last two
            axes give the grid size passed to ``get_idty_S2``.
        Basis_1forms: basis contracted with ``Coe_x`` to perturb the interior.
        a, b, c, d: coefficients forwarded unchanged to ``riemann_metric``.

    Returns:
        A scalar tensor: the sum of ``riemann_metric`` over all segments
        divided by the number of segments.
    """
    # Rotated target and the straight line between the two differentials.
    rotation = torch_expm(RX)
    rotated_f2 = torch.einsum("ij,jmn->imn", [rotation, f2])
    d_start = f_to_df(f1)
    d_end = f_to_df(rotated_f2)

    steps = torch.arange(Tstps, dtype=torch.float32)
    straight = d_start + torch.einsum("ijmn,t->tijmn", [d_end - d_start, steps]) / (Tstps - 1)
    n_pts = straight.size(0)
    last = n_pts - 1

    # Perturb the identity map, then normalise pointwise (project to S2).
    sphere_id = get_idty_S2(*Basis_vecFields.shape[-2:])
    gamma = sphere_id + torch.einsum("i, ijkl->jkl", [X, Basis_vecFields])
    gamma = gamma / torch.einsum("ijk, ijk->jk", [gamma, gamma]).sqrt()

    # Spherical-coordinate representation of gamma, then f composed with it.
    gamma_sph = torch.stack(
        Cartesian_to_spherical(gamma[0] + 1e-7, gamma[1], (1 - 1e-7) * gamma[2]))
    f_gamma = compose_gamma(f, gamma_sph)

    # Assemble the perturbed path.
    path = torch.zeros(straight.size())
    path[0] = f_to_df(f_gamma)
    path[last] = straight[-1]
    path[1:last] = straight[1:last] + torch.einsum(
        "lijmn,lt->tijmn", [Basis_1forms, Coe_x])

    # Forward finite differences, scaled by the number of segments.
    velocity = (path[1:n_pts] - path[0:last]) * last

    seg_energy = torch.zeros(last)
    for k in range(last):
        seg_energy[k] = riemann_metric(path[k], velocity[k], velocity[k], a, b, c, d)
    return torch.sum(seg_energy) / last
def get_optCoe_shapes_split(f_f1, Coe_x, RX, f1, f2, Basis_vecFields, Basis_Sph, Basis_1forms,
                            a, b, c, d, Tstps, Max_ite, side, *, multires=False):
    """Optimise the path coefficients and rotation, then re-register the first surface.

    The packed vector ``[Coe_x.flatten(), RX]`` is minimised with BFGS against
    ``energy_func_para_split_modSO3``. Afterwards the second point of the
    optimal path is rebuilt and ``compute_optimal_reg_1st`` is used to update
    the first boundary surface.

    Args:
        f_f1: current first surface of the path.
        Coe_x: initial coefficient matrix; its first axis has
            ``Basis_Sph.size(0)`` entries.
        RX: initial length-3 tensor passed to ``torch_expm``.
        f1, f2: the original boundary surfaces.
        Basis_vecFields: basis forwarded to ``compute_optimal_reg_1st``; its
            last two axes give the grid size passed to ``get_idty_S2``.
        Basis_Sph: basis used to rebuild the second path point from the
            optimal coefficients.
        Basis_1forms: basis used inside the energy.
        a, b, c, d: metric coefficients forwarded to the energy and the
            registration routine.
        Tstps: number of time points of the discrete path.
        Max_ite: ``maxiter`` option for BFGS.
        side: currently unused; only the first boundary surface is updated.
        multires: when truthy, the optimisation runs on inputs reduced to a
            25 x 49 grid with a basis loaded from
            ``Data/basis_exact_1forms_deg25_25_49.mat``.

    Returns:
        ``(f_f1_new, Coe_x_opt, RX_opt, Energy)`` where ``Energy`` is the list
        of objective values recorded by the optimiser callback.
    """
    Num_basis = Basis_Sph.size(0)
    CoeRX = torch.cat((Coe_x.flatten(), RX))
    Energy = []

    def make_objective(start, src, tgt, basis):
        """Build the (value, gradient) callable for one set of inputs."""
        def objective(packed):
            coe = torch.from_numpy(packed[:-3]).float().view(Num_basis, -1).requires_grad_()
            rx = torch.from_numpy(packed[-3:]).float().requires_grad_()
            y = energy_func_para_split_modSO3(start, coe, rx, src, tgt, basis,
                                              a, b, c, d, Tstps)
            y.backward()
            grad = torch.cat((coe.grad.flatten(), rx.grad))
            return np.double(y.data.numpy()), np.double(grad.data.numpy())
        return objective

    if multires:
        # Low-resolution basis; cast to float32 like every other basis so it
        # can be contracted with the float32 coefficients.
        mat_basis = sio.loadmat('Data/basis_exact_1forms_deg25_25_49.mat')
        Basis_1forms_low = torch.tensor(mat_basis['Basis'])[:Num_basis].float()
        m, n = 25, 49
        objective = make_objective(reduce_resolution(f_f1, m, n),
                                   reduce_resolution(f1, m, n),
                                   reduce_resolution(f2, m, n),
                                   Basis_1forms_low)
    else:
        objective = make_objective(f_f1, f1, f2, Basis_1forms)

    def record_energy(x):
        # Callback: store the objective value at each accepted iterate.
        Energy.append(objective(x)[0])

    res = optimize.minimize(objective, CoeRX, method='BFGS', jac=True,
                            callback=record_energy,
                            options={'gtol': 1e-02, 'disp': False, 'maxiter': Max_ite})
    Coe_x_opt = torch.from_numpy(res.x[:-3]).float().view(Num_basis, -1)
    RX_opt = torch.from_numpy(res.x[-3:]).float()

    # Rotate f2 with the OPTIMISED rotation: Coe_x_opt was optimised jointly
    # with RX_opt, so the second path point must be rebuilt from the same pair
    # (the input RX is stale at this point).
    R = torch_expm(RX_opt)
    Rf2 = torch.einsum("ij,jmn->imn", [R, f2])

    c_f = torch.zeros(2, 3, *f1.shape[-2:])
    c_f[0] = f_f1
    # Second point of the discrete path: linear interpolation at t = 1 plus
    # the first column of the optimal perturbation.
    lin_second = f1 + (Rf2 - f1) / (Tstps - 1)
    c_f[1] = lin_second + torch.einsum("limn,l->imn", [Basis_Sph, Coe_x_opt[:, 0]])

    idty = get_idty_S2(*Basis_vecFields.shape[-2:])
    c_f[0] = compute_optimal_reg_1st(c_f, idty, Basis_vecFields, a, b, c, d)
    return c_f[0], Coe_x_opt, RX_opt, Energy
def initialize_over_paraSO3(f1, f2, idty, a, b, c, d):
    """Initialise the rotational re-parametrisation of ``f1`` towards ``f2``.

    Sixty candidates loaded from ``Bases/skewIcosahedral.mat`` are scored by
    the ``length_dfunc_Imm`` length of the two-point linear path between the
    differentials of the re-parametrised ``f1`` and of ``f2``. The best
    candidate is then refined with BFGS.

    Args:
        f1, f2: the two surfaces.
        idty: identity map that the candidate matrices are applied to.
        a, b, c, d: coefficients forwarded to ``length_dfunc_Imm``.

    Returns:
        ``(f1_gamma, L, f1_best_candidate, length_linear)``: the refined
        re-parametrised ``f1``, the objective values recorded by the optimiser
        callback, the re-parametrised ``f1`` for the best discrete candidate,
        and the tensor of all sixty candidate lengths.
    """
    Tstps = 2
    n_candidates = 60

    # Candidate generators (the elements of the icosahedral group).
    XIco = torch.from_numpy(sio.loadmat('Bases/skewIcosahedral.mat')['X']).float()

    df2 = f_to_df(f2)
    Time_points = torch.arange(Tstps, dtype=torch.float32)

    def reparametrize(generator):
        """Compose f1 with the map obtained by applying torch_expm(generator) to idty."""
        moved = torch.einsum("ij,jmn->imn", [torch_expm(generator), idty])
        angles = torch.stack(
            Cartesian_to_spherical(moved[0] + 1e-7, moved[1], (1 - 1e-7) * moved[2]))
        return compose_gamma(f1, angles)

    def straight_path(d_start):
        """Linear path d_start + (df2 - d_start) * t / (Tstps - 1)."""
        return d_start + torch.einsum(
            "ijmn,t->tijmn", [df2 - d_start, Time_points]) / (Tstps - 1)

    linear_path = torch.zeros(n_candidates, Tstps, 3, 2, *f1.shape[-2:])
    f1_gammaIco = torch.zeros(n_candidates, *f1.size())
    df1_gammaIco = torch.zeros(n_candidates, 3, 2, *f1.shape[-2:])
    length_linear = torch.zeros(n_candidates)
    for k in range(n_candidates):
        f1_gammaIco[k] = reparametrize(XIco[k])
        df1_gammaIco[k] = f_to_df(f1_gammaIco[k])
        linear_path[k] = straight_path(df1_gammaIco[k])
        length_linear[k] = length_dfunc_Imm(linear_path[k], a, b, c, d)

    # Index of the shortest candidate path; it seeds the continuous search.
    Ind = np.argmin(length_linear)
    X = XIco[Ind]

    L = []

    def opt(X):
        x_var = torch.from_numpy(X).float().requires_grad_()
        y = length_dfunc_Imm(straight_path(f_to_df(reparametrize(x_var))), a, b, c, d)
        y.backward()
        return np.double(y.data.numpy()), np.double(x_var.grad.data.numpy())

    def printx(x):
        # Callback: record the objective value at each accepted iterate.
        L.append(opt(x)[0])

    res = optimize.minimize(opt, X, method='BFGS', jac=True, callback=printx,
                            options={'gtol': 1e-02, 'disp': False})

    X_opt = torch.from_numpy(res.x).float()
    f1_gamma = reparametrize(X_opt)
    return f1_gamma, L, f1_gammaIco[Ind], length_linear
def compute_geodesic_unparaModSO3(f1, f2, *, MaxDegVecFS2, MaxDegHarmSurf, Cmetric=(), Tpts,
                                  method='split', maxiter=(10, 30), **kwargs):
    """Compute a discrete geodesic between ``f1`` and ``f2`` modulo re-parametrisation and SO(3).

    Args:
        f1, f2: boundary surfaces; the last two axes of ``f1`` select which
            basis files are loaded from ``Bases/``.
        MaxDegVecFS2: maximal degree used to truncate the vector-field basis
            (``(MaxDegVecFS2 + 1) ** 2 - 1`` entries are kept per component).
        MaxDegHarmSurf: maximal degree used to truncate the surface / 1-form
            bases (``((MaxDegHarmSurf + 1) ** 2 - 1) * 3`` entries are kept).
        Cmetric: the four metric coefficients ``(a, b, c, d)``.
        Tpts: number of time points of the returned path.
        method: ``'split'`` or ``'combined'``.
        maxiter: ``(outer iterations, inner BFGS iterations)``; the inner
            limit is raised to 100 on the last outer iteration.
        **kwargs: ``multires`` (truthy) enables the multi-resolution schedule
            of the split method: 5 time points until the last outer iteration
            and the low-resolution inner optimisation during the first half.

    Returns:
        ``(geo_f, energies)``: a ``(Tpts, 3, m, n)`` tensor holding the path,
        and the flat list of energies recorded by all inner optimisations.

    Raises:
        ValueError: if ``method`` is unknown or ``Cmetric`` does not unpack
            into exactly four coefficients.
    """
    # Fail fast, before any file is read. Previously an unknown method only
    # surfaced as an UnboundLocalError at the final return.
    if method not in ('split', 'combined'):
        raise ValueError("method must be 'split' or 'combined', got {!r}".format(method))
    try:
        a, b, c, d = Cmetric
    except ValueError as err:
        raise ValueError("Cmetric must contain exactly four coefficients (a, b, c, d)") from err

    # Bases for surfaces and exact 1-forms.
    mat_basis = sio.loadmat('Bases/basis_exact_1forms_deg25_{0}_{1}.mat'.format(*f1.shape[-2:]))
    Num_basis = ((MaxDegHarmSurf + 1) ** 2 - 1) * 3
    Basis_Sph = torch.from_numpy(mat_basis['Basis_Sph'])[: Num_basis].float()
    Basis_1forms = torch.from_numpy(mat_basis['Basis'])[: Num_basis].float()

    # Basis for tangent vector fields on S2; the two components are stacked.
    mat_vecF = sio.loadmat('Bases/basis_vecFieldsS2_deg25_{0}_{1}.mat'.format(*f1.shape[-2:]))
    N_basis_vec = (MaxDegVecFS2 + 1) ** 2 - 1
    Basis0_vec = torch.from_numpy(mat_vecF['Basis'])[: N_basis_vec].float()
    Basis_vecFields = torch.cat((Basis0_vec[:, 0], Basis0_vec[:, 1]))

    EnergyAll = []
    N_ite, Max_ite_in = maxiter
    RX = torch.zeros(3)
    Time_points = torch.arange(Tpts, out=torch.FloatTensor())

    if method == 'split':
        f_f1 = f1
        side = 0
        multires = bool(kwargs.get('multires'))
        # Multi-resolution in time: start with 5 time points.
        Tpts0 = 5 if multires else Tpts
        Coe_x = torch.zeros(Basis_Sph.size(0), Tpts0 - 2)

        for i in range(N_ite):
            if i == N_ite - 1:
                Max_ite_in = 100
                if Tpts0 != Tpts:
                    # Switch to the full time resolution for the final pass.
                    Tpts0 = Tpts
                    Coe_x = up_sample(Coe_x, Tpts - 2)
            if i > N_ite // 2:
                multires = False

            f_f1, Coe_x, RX, Energy = get_optCoe_shapes_split(
                f_f1, Coe_x, RX, f1, f2, Basis_vecFields, Basis_Sph, Basis_1forms,
                a, b, c, d, Tpts0, Max_ite_in, side, multires=multires)
            EnergyAll.append(Energy)

        # Rotate f2 and perturb the linear path with the optimal coefficients.
        R = torch_expm(RX)
        Rf2 = torch.einsum("ij,jmn->imn", [R, f2])
        lin_f = f1 + torch.einsum("imn,t->timn", [Rf2 - f1, Time_points]) / (Tpts - 1)

        geo_f = torch.zeros(Tpts, 3, *f1.shape[-2:])
        geo_f[0], geo_f[-1] = f_f1, Rf2
        geo_f[1: Tpts - 1] = lin_f[1: Tpts - 1] + torch.einsum(
            "limn,lt->timn", [Basis_Sph, Coe_x])

    else:  # method == 'combined'
        f = f1
        Coe_x = torch.zeros(Basis_Sph.size(0), Tpts - 2)
        idty = get_idty_S2(*Basis_vecFields.shape[-2:])

        for i in range(N_ite):
            if i == N_ite - 1:
                Max_ite_in = 100

            X_new, Coe_x, RX, Energy = get_optCoe_shapes_combined(
                f, Coe_x, RX, f1, f2, Basis_vecFields, Basis_1forms,
                a, b, c, d, Tpts, Max_ite_in)

            # Update f: perturb the identity map and project back to S2.
            gamma = idty + torch.einsum("i, ijkl->jkl", [X_new, Basis_vecFields])
            gamma = gamma / torch.einsum("ijk, ijk->jk", [gamma, gamma]).sqrt()
            # Same 1e-7 guards as energy_fun_combined_modSO3, so the update
            # uses the same spherical coordinates the energy was optimised with.
            gammaSph = torch.stack(
                Cartesian_to_spherical(gamma[0] + 1e-7, gamma[1], (1 - 1e-7) * gamma[2]))
            f = compose_gamma(f, gammaSph)
            EnergyAll.append(Energy)

        # Rotate f2 and perturb the linear path with the optimal coefficients.
        R = torch_expm(RX)
        Rf2 = torch.einsum("ij,jmn->imn", [R, f2])
        lin_f = f1 + torch.einsum("imn,t->timn", [Rf2 - f1, Time_points]) / (Tpts - 1)

        geo_f = torch.zeros(Tpts, 3, *f1.shape[-2:])
        geo_f[0], geo_f[Tpts - 1] = f, lin_f[-1]
        geo_f[1: Tpts - 1] = lin_f[1: Tpts - 1] + torch.einsum(
            "limn,lt->timn", [Basis_Sph, Coe_x])

    EnergyAll0 = [item for sublist in EnergyAll for item in sublist]
    return geo_f, EnergyAll0
def initialize_over_paraSO3_SRNF(f1, f2, idty):
    """Initialise the rotational re-parametrisation of ``f1`` using ``f_to_q`` outputs.

    Sixty candidates loaded from ``Bases/skewIcosahedral.mat`` are scored by
    ``integral_over_s2`` of the pointwise squared difference between
    ``f_to_q(f2)`` and ``f_to_q`` of the re-parametrised ``f1``. The best
    candidate is then refined with BFGS.

    Args:
        f1, f2: the two surfaces.
        idty: identity map that the candidate matrices are applied to.

    Returns:
        ``(f1_gamma, L2_ESO3, f1_best_candidate, EIco)``: the refined
        re-parametrised ``f1``, the objective values recorded by the optimiser
        callback, the re-parametrised ``f1`` for the best discrete candidate,
        and the tensor of all sixty candidate scores.
    """
    n_candidates = 60

    # Candidate generators (the elements of the icosahedral group).
    XIco = torch.from_numpy(sio.loadmat('Bases/skewIcosahedral.mat')['X']).float()

    q2 = f_to_q(f2)

    def reparametrize(generator):
        """Compose f1 with the map obtained by applying torch_expm(generator) to idty."""
        moved = torch.einsum("ij,jmn->imn", [torch_expm(generator), idty])
        angles = torch.stack(
            Cartesian_to_spherical(moved[0] + 1e-7, moved[1], (1 - 1e-7) * moved[2]))
        return compose_gamma(f1, angles)

    def squared_distance(q1):
        """integral_over_s2 of the pointwise squared difference q2 - q1."""
        return integral_over_s2(torch.einsum("imn,imn->mn", [q2 - q1, q2 - q1]))

    EIco = torch.zeros(n_candidates)
    f1_gammaIco = torch.zeros(n_candidates, *f1.size())
    q1_gammaIco = torch.zeros(n_candidates, *f1.size())
    for k in range(n_candidates):
        f1_gammaIco[k] = reparametrize(XIco[k])
        q1_gammaIco[k] = f_to_q(f1_gammaIco[k])
        EIco[k] = squared_distance(q1_gammaIco[k])

    # Index of the best candidate; it seeds the continuous search.
    Ind = np.argmin(EIco)
    X = XIco[Ind]

    L2_ESO3 = []

    def opt(X):
        x_var = torch.from_numpy(X).float().requires_grad_()
        y = squared_distance(f_to_q(reparametrize(x_var)))
        y.backward()
        return np.double(y.data.numpy()), np.double(x_var.grad.data.numpy())

    def printx(x):
        # Callback: record the objective value at each accepted iterate.
        L2_ESO3.append(opt(x)[0])

    res = optimize.minimize(opt, X, method='BFGS', jac=True, callback=printx,
                            options={'gtol': 1e-02, 'disp': False})

    X_opt = torch.from_numpy(res.x).float()
    f1_gamma = reparametrize(X_opt)
    return f1_gamma, L2_ESO3, f1_gammaIco[Ind], EIco