def energy_func_para_split_modSO3(f_f1, Coe_x, RX, f1, f2, Basis_1forms, a, b, c, d, Tstps):
    """Discrete path energy from ``f_f1`` to ``f2`` rotated by ``torch_expm(RX)``.

    The path lives in the space produced by ``f_to_df``.  A straight-line
    reference path runs from ``f_to_df(f1)`` to ``f_to_df(R f2)`` with
    ``R = torch_expm(RX)``.  The evaluated path:
      * replaces its first point by ``f_to_df(f_f1)``,
      * keeps the reference end point,
      * adds ``sum_l Basis_1forms[l] * Coe_x[l, t]`` to every interior point.

    Args:
        f_f1: surface whose differential is the first point of the path
            (presumably f1 after reparametrization -- confirm with callers).
        Coe_x: coefficients indexed (l, t); t runs over the Tstps - 2
            interior time steps (it is added into ``path[1:Tstps - 1]``).
        RX: argument of ``torch_expm``; the result multiplies the leading
            axis of ``f2`` (einsum "ij,jmn->imn").
        f1, f2: surfaces defining the straight-line reference path.
        Basis_1forms: perturbation basis indexed (l, i, j, m, n).
        a, b, c, d: constants forwarded unchanged to ``riemann_metric``.
        Tstps: number of time samples; ``Tstps - 1`` is used as a divisor.

    Returns:
        Sum over the Tstps - 1 segments of
        ``riemann_metric(point, velocity, velocity, a, b, c, d)`` divided by
        Tstps - 1.  Velocities are forward differences scaled by Tstps - 1.
    """
    n_seg = Tstps - 1

    # Rotate f2 by R = expm(RX).
    rot_f2 = torch.einsum("ij,jmn->imn", [torch_expm(RX), f2])
    start = f_to_df(f1)
    end = f_to_df(rot_f2)

    # Reference path start + (end - start) * t / (Tstps - 1), t = 0..Tstps-1.
    ticks = torch.arange(Tstps, out=torch.FloatTensor())
    straight = start + torch.einsum("ijmn,t->tijmn", [end - start, ticks]) / n_seg

    # New default-dtype buffer holding the perturbed path.
    path = torch.zeros(straight.size())
    path[0] = f_to_df(f_f1)
    path[-1] = straight[-1]
    path[1:n_seg] = straight[1:n_seg] + torch.einsum(
        "lijmn,lt->tijmn", [Basis_1forms, Coe_x])

    velocity = (path[1:] - path[:-1]) * n_seg
    seg_energy = torch.zeros(n_seg)
    for k in range(n_seg):
        seg_energy[k] = riemann_metric(path[k], velocity[k], velocity[k], a, b, c, d)
    return torch.sum(seg_energy) / n_seg
def compute_optCoe_Imm(Coe_x, f1, f2, a, b, c, d, Basis_1forms, Tpts, Max_ite):
    """Optimize path-perturbation coefficients with SciPy's BFGS.

    Builds the straight-line path between ``f_to_df(f1)`` and ``f_to_df(f2)``
    sampled at ``Tpts`` time points.  It then minimizes
    ``energy_func_Imm(coefficients, path, Basis_1forms, a, b, c, d)`` over the
    coefficients, using PyTorch autograd for the gradient.

    Args:
        Coe_x: initial coefficients; flattened before being handed to SciPy.
        f1, f2: end-point surfaces.
        a, b, c, d: constants forwarded unchanged to ``energy_func_Imm``.
        Basis_1forms: basis whose leading size gives the number of
            coefficient rows.
        Tpts: number of time samples on the path (``Tpts - 1`` is a divisor).
        Max_ite: BFGS ``maxiter``.

    Returns:
        (coe_x_opt, AllEnergy):
          * coe_x_opt: the optimized coefficients as a float32 tensor of
            shape (Num_basis, -1).
          * AllEnergy: the energy recorded at each accepted BFGS iterate.
    """
    Num_basis = Basis_1forms.size(0)
    df1 = f_to_df(f1)
    df2 = f_to_df(f2)
    Time_points = torch.arange(Tpts, dtype=torch.float32)
    # Straight-line path df1 + (df2 - df1) * t / (Tpts - 1).
    curve0 = df1 + torch.einsum("ijmn,t->tijmn", [df2 - df1, Time_points]) / (Tpts - 1)

    # Most recent objective evaluation.  The iteration callback reuses it
    # instead of repeating a full forward + backward pass at the same point.
    last = {"x": None, "f": None}

    def Opts_Imm(x):
        # SciPy passes a flat vector; energy_func_Imm takes a
        # (Num_basis, -1) float32 tensor.
        coe = torch.from_numpy(x).float().view(Num_basis, -1)
        coe.requires_grad_()
        y = energy_func_Imm(coe, curve0, Basis_1forms, a, b, c, d)
        y.backward()
        energy = np.double(y.data.numpy())
        last["x"], last["f"] = np.array(x, copy=True), energy
        return energy, np.double(coe.grad.data.numpy().flatten())

    AllEnergy = []

    def printx(xk):
        # Log the energy at each accepted iterate.  Re-evaluate only when
        # the optimizer's last objective call was at a different point.
        if last["x"] is not None and np.array_equal(xk, last["x"]):
            AllEnergy.append(last["f"])
        else:
            AllEnergy.append(Opts_Imm(xk)[0])

    res = optimize.minimize(Opts_Imm, Coe_x.flatten(), method='BFGS', jac=True,
                            callback=printx,
                            options={'gtol': 1e-02, 'disp': False, 'maxiter': Max_ite})
    coe_x_opt = torch.from_numpy(res.x).float().view(Num_basis, -1)
    return coe_x_opt, AllEnergy
def energy_reg(X, c_f, idty, Basis_vecFields, a, b, c, d):
    """Metric energy of the difference between c_f[1] and a warped c_f[0].

    The warp is ``idty + sum_i X[i] * Basis_vecFields[i]``, normalized
    pointwise along its leading axis (projection onto S2).  It is converted
    with ``Cartesian_to_spherical`` and composed with the first surface.

    Args:
        X: coefficients, one per basis vector field.
        c_f: pair-like container; c_f[0] is warped, c_f[1] is the target.
        idty: base grid the warp perturbs (presumably the identity map of
            S2 in Cartesian coordinates -- confirm).
        Basis_vecFields: vector-field basis indexed (i, j, k, l).
        a, b, c, d: constants forwarded unchanged to ``riemann_metric``.

    Returns:
        ``riemann_metric(df, diff, diff, a, b, c, d)``, where df is the
        differential of the warped surface and diff is the differential of
        ``c_f[1] - warped``.
    """
    src, tgt = c_f[0], c_f[1]

    warp = idty + torch.einsum("i, ijkl->jkl", [X, Basis_vecFields])
    pointwise_norm = torch.einsum("ijk, ijk->jk", [warp, warp]).sqrt()
    warp = warp / pointwise_norm  # project to S2

    # Small offsets keep the spherical conversion away from its singular values.
    sph = torch.stack(Cartesian_to_spherical(warp[0] + 1e-7, warp[1], (1 - 1e-7) * warp[2]))
    src_warped = compose_gamma(src, sph)

    base_df = f_to_df(src_warped)
    residual_df = f_to_df(tgt - src_warped)
    return riemann_metric(base_df, residual_df, residual_df, a, b, c, d)
def opt(X):
    """Objective and gradient of a rotation search, in SciPy ``jac=True`` form.

    NOTE(review): this top-level function reads ``idty``, ``f1``, ``f2``,
    ``Tstps``, ``a``, ``b``, ``c``, ``d`` as module globals, none of which are
    defined in the visible source.  It appears to duplicate the closure of the
    same name inside ``initialize_over_paraSO3`` -- confirm whether it is used.

    Args:
        X: NumPy array passed to ``torch_expm`` after conversion to float32.

    Returns:
        (length, grad): the straight-path length as a NumPy float64, and the
        gradient w.r.t. X as a float64 array of X's shape.
    """
    X_t = torch.from_numpy(X).float().requires_grad_()
    rotated = torch.einsum("ij,jmn->imn", [torch_expm(X_t), idty])
    sph = torch.stack(Cartesian_to_spherical(rotated[0] + 1e-7, rotated[1], (1 - 1e-7) * rotated[2]))
    df_start = f_to_df(compose_gamma(f1, sph))
    df_end = f_to_df(f2)

    # Straight path df_start + (df_end - df_start) * t / (Tstps - 1).
    steps = torch.arange(Tstps, out=torch.FloatTensor())
    path = df_start + torch.einsum("ijmn,t->tijmn", [df_end - df_start, steps]) / (Tstps - 1)

    length = length_dfunc_Imm(path, a, b, c, d)
    length.backward()
    return np.double(length.data.numpy()), np.double(X_t.grad.data.numpy())
def length_func_surf_Imm(curve_f, a, b, c, d):
    """Discrete length of a time-sampled path of surfaces.

    Each slice ``curve_f[t]`` is mapped through ``f_to_df`` into a buffer of
    shape (T, 3, 2, m, n), with (m, n) taken from curve_f's last two dims.
    Each segment length is ``sqrt(riemann_metric(p, v, v, a, b, c, d))``,
    where v is a forward difference scaled by T - 1.

    Args:
        curve_f: surfaces indexed by time along dim 0 (T = curve_f.size(0)).
        a, b, c, d: constants forwarded unchanged to ``riemann_metric``.

    Returns:
        Sum of segment lengths divided by T - 1.
    """
    n_times = curve_f.size(0)
    n_seg = n_times - 1

    dpath = torch.zeros(n_times, 3, 2, *curve_f.shape[-2:])
    for t, surf in enumerate(curve_f):
        dpath[t] = f_to_df(surf)

    velocity = (dpath[1:] - dpath[:-1]) * n_seg
    seg_len = torch.zeros(n_seg)
    for t in range(n_seg):
        seg_len[t] = riemann_metric(dpath[t], velocity[t], velocity[t], a, b, c, d).sqrt()
    return torch.sum(seg_len) / n_seg
def energy_fun_combined_modSO3(f, X, Coe_x, RX, f1, f2, Tstps, Basis_vecFields, Basis_1forms, a, b, c, d):
    """Path energy combining a warp of ``f``, a rotation of ``f2`` and path perturbations.

    A straight reference path runs from ``f_to_df(f1)`` to
    ``f_to_df(torch_expm(RX) f2)`` over Tstps samples.  The evaluated path:
      * starts at ``f_to_df(f o gamma)``, where gamma is
        ``get_idty_S2(...) + sum_i X[i] * Basis_vecFields[i]`` normalized
        pointwise onto S2 and converted to spherical coordinates,
      * keeps the reference end point,
      * perturbs interior points by ``sum_l Basis_1forms[l] * Coe_x[l, t]``.

    Args:
        f: surface that is warped to form the first point.
        X: warp coefficients, one per basis vector field.
        Coe_x: interior-point coefficients indexed (l, t).
        RX: argument of ``torch_expm``; the result multiplies f2's leading axis.
        f1, f2: the original boundary surfaces.
        Tstps: number of time samples.
        Basis_vecFields, Basis_1forms: warp and perturbation bases.
        a, b, c, d: constants forwarded unchanged to ``riemann_metric``.

    Returns:
        Sum over segments of ``riemann_metric(p, v, v, a, b, c, d)`` divided
        by the number of segments.  v is a forward difference scaled by the
        number of segments.
    """
    rot_f2 = torch.einsum("ij,jmn->imn", [torch_expm(RX), f2])
    start_df, end_df = f_to_df(f1), f_to_df(rot_f2)
    ticks = torch.arange(Tstps, out=torch.FloatTensor())
    straight = start_df + torch.einsum("ijmn,t->tijmn", [end_df - start_df, ticks]) / (Tstps - 1)
    n_seg = straight.size(0) - 1

    # Warp of the sphere built from the identity map plus basis vector fields.
    base = get_idty_S2(*Basis_vecFields.shape[-2:])
    warp = base + torch.einsum("i, ijkl->jkl", [X, Basis_vecFields])
    warp = warp / torch.einsum("ijk, ijk->jk", [warp, warp]).sqrt()  # project to S2
    sph = torch.stack(Cartesian_to_spherical(warp[0] + 1e-7, warp[1], (1 - 1e-7) * warp[2]))
    f_warped = compose_gamma(f, sph)

    path = torch.zeros(straight.size())
    path[0] = f_to_df(f_warped)
    path[n_seg] = straight[-1]
    path[1:n_seg] = straight[1:n_seg] + torch.einsum(
        "lijmn,lt->tijmn", [Basis_1forms, Coe_x])

    velocity = (path[1:] - path[:-1]) * n_seg
    seg_energy = torch.zeros(n_seg)
    for k in range(n_seg):
        seg_energy[k] = riemann_metric(path[k], velocity[k], velocity[k], a, b, c, d)
    return torch.sum(seg_energy) / n_seg
def initialize_overSO3(f1, f2):
    """Choose an initial rotation of ``f2`` that roughly aligns it with ``f1``.

    For each surface, the field returned by ``oneForm_to_n`` (presumably
    normals -- confirm) is reshaped to (3, N) and centered per row.  The left
    singular vectors U1, U2 of the two fields give ``R = U1 D U2^{-1}``, where
    D = diag(1, 1, sign(det(U1 U2^{-1}))).  R and its compositions with 180
    degree rotations about z, x and y are applied to f2.  The candidate with
    the smallest Frobenius distance to f1 is returned.

    Returns:
        (best_rotated_f2, L): the selected rotated copy of f2, and a float64
        NumPy array of the four candidate distances.
    """
    def _centered_field(surf):
        # (3, N) field from oneForm_to_n with each row's mean subtracted.
        field = oneForm_to_n(f_to_df(surf)).view(3, -1)
        count = field.size(1)
        return field - torch.sum(field, 1).view(3, -1).expand(-1, count) / count

    field1 = _centered_field(f1)
    field2 = _centered_field(f2)

    U1 = torch.svd(field1.view(3, -1))[0]
    U2 = torch.svd(field2.view(3, -1))[0]
    U2_inv = torch.inverse(U2)
    sign_fix = torch.eye(3)
    sign_fix[2, 2] = torch.sign(torch.det(U1 @ U2_inv))
    R = U1 @ sign_fix @ U2_inv

    half_turns = (
        torch.tensor([[-1, 0., 0], [0, -1, 0], [0, 0, 1]]),  # about z
        torch.tensor([[1, 0., 0], [0, -1, 0], [0, 0, -1]]),  # about x
        torch.tensor([[-1, 0., 0], [0, 1, 0], [0, 0, -1]]),  # about y
    )
    candidates = torch.stack([R] + [turn @ R for turn in half_turns])
    rotated_f2 = torch.einsum("lij,jmn->limn", [candidates, f2])

    L = np.array([torch.norm((f1 - g).flatten()).numpy() for g in rotated_f2],
                 dtype=np.float64)
    best = np.argmin(L)
    return rotated_f2[best], L
def initialize_over_paraSO3(f1, f2, idty, a, b, c, d):
    """Initialize a rotational reparametrization of ``f1`` toward ``f2``.

    Stage 1 (coarse): load 60 candidates from 'Bases/skewIcosahedral.mat'
    (key 'X').  For each candidate X, compose f1 with the spherical
    coordinates of ``torch_expm(X)`` applied to ``idty``.  Then score it by
    ``length_dfunc_Imm`` of the 2-sample straight path from the composed
    surface's differential to ``f_to_df(f2)``.

    Stage 2 (refine): run BFGS on the same objective, starting from the best
    candidate.

    Args:
        f1, f2: surfaces; f1 is reparametrized.
        idty: grid the rotation acts on (presumably the identity map of S2 in
            Cartesian coordinates, indexed (3, m, n) -- confirm).
        a, b, c, d: constants forwarded unchanged to ``length_dfunc_Imm``.

    Returns:
        (f1_gamma, L, f1_best_discrete, length_linear):
          * f1_gamma: f1 composed with the refined rotation.
          * L: objective values recorded at each BFGS iterate.
          * f1_best_discrete: f1 composed with the best discrete candidate.
          * length_linear: stage-1 lengths for all candidates.
    """
    Tstps = 2
    n_ico = 60  # number of rotations in the icosahedral group
    # load the elements in the icosahedral group
    XIco_mat = sio.loadmat('Bases/skewIcosahedral.mat')
    XIco = torch.from_numpy(XIco_mat['X']).float()
    Time_points = torch.arange(Tstps, dtype=torch.float32)

    def _compose_rotated(X):
        # f1 composed with the rotation torch_expm(X) applied to idty.
        R = torch.einsum("ij,jmn->imn", [torch_expm(X), idty])
        gamma = torch.stack(Cartesian_to_spherical(R[0] + 1e-7, R[1], (1 - 1e-7) * R[2]))
        return compose_gamma(f1, gamma)

    def _path_length(df_start, df_end):
        # Length of the straight path df_start + (df_end - df_start) * t / (Tstps - 1).
        path = df_start + torch.einsum(
            "ijmn,t->tijmn", [df_end - df_start, Time_points]) / (Tstps - 1)
        return length_dfunc_Imm(path, a, b, c, d)

    # Stage 1: score every discrete candidate.  Lengths are computed on the
    # fly rather than materializing all candidate paths at once.
    df2 = f_to_df(f2)
    f1_gammaIco = torch.zeros(n_ico, *f1.size())
    length_linear = torch.zeros(n_ico)
    for i in range(n_ico):
        f1_gammaIco[i] = _compose_rotated(XIco[i])
        length_linear[i] = _path_length(f_to_df(f1_gammaIco[i]), df2)
    Ind = int(torch.argmin(length_linear))

    # Stage 2: BFGS refinement.  scipy.optimize.minimize expects a 1-D x0
    # (recent SciPy raises for multi-dimensional x0, and BFGS passes the
    # objective a flat vector).  So optimize the flattened candidate and
    # restore its original shape wherever it is used.
    X0 = XIco[Ind]
    x_shape = X0.shape
    L = []
    last = {"x": None, "f": None}  # most recent objective evaluation

    def opt(x):
        X_flat = torch.from_numpy(x).float().requires_grad_()
        # f2's differential is recomputed per call, as before, so no autograd
        # graph is shared across backward passes.
        y = _path_length(f_to_df(_compose_rotated(X_flat.view(x_shape))), f_to_df(f2))
        y.backward()
        energy = np.double(y.data.numpy())
        last["x"], last["f"] = np.array(x, copy=True), energy
        return energy, np.double(X_flat.grad.data.numpy())

    def printx(xk):
        # Log the objective at each iterate; reuse the cached value when the
        # optimizer's last evaluation was at this point.
        if last["x"] is not None and np.array_equal(xk, last["x"]):
            L.append(last["f"])
        else:
            L.append(opt(xk)[0])

    res = optimize.minimize(opt, X0.flatten().numpy().astype(np.float64), method='BFGS',
                            jac=True, callback=printx,
                            options={'gtol': 1e-02, 'disp': False})
    X_opt = torch.from_numpy(res.x).float().view(x_shape)
    f1_gamma = _compose_rotated(X_opt)
    return f1_gamma, L, f1_gammaIco[Ind], length_linear
def rescale_surf(f):
    """Divide ``f`` by the square root of the S2 integral of its normal length.

    The local normal is the cross product of ``df[:, 0]`` and ``df[:, 1]``,
    where ``df = f_to_df(f)``.  Its pointwise Euclidean length is integrated
    with ``integral_over_s2``.  If f_to_df is linear in f (not visible here --
    confirm), that integral scales quadratically with f.  The result then has
    unit integral (presumably unit surface area).
    """
    tangents = f_to_df(f)
    normal = torch.cross(tangents[:, 0], tangents[:, 1], dim=0)
    normal_len = torch.einsum("imn,imn->mn", [normal, normal]).sqrt()
    scale = integral_over_s2(normal_len).sqrt()
    return f / scale