Ejemplo n.º 1
0
def riemann_metric(df, w, v, a, b, c, d):  # df,w,v: 3*2*m*n
    """Weighted four-term inner product of ``w`` and ``v`` at the surface ``df``.

    The pointwise integrand is
    ``(a*part1 + b*part2 + c*part3 + d*part4) * sqrt(det(df^T df))``
    and it is integrated with ``integral_over_s2``.

    Args:
        df: differential of the surface, laid out 3*2*m*n (per the shape
            note on the signature).
        w, v: the two tangent fields, same 3*2*m*n layout as ``df``.
        a, b, c, d: weights of the change-in-metric, change-in-volume-density,
            change-in-normal-direction and additional terms respectively.

    Returns:
        The result of ``integral_over_s2`` applied to the m*n integrand.
    """
    A = torch.einsum("simn,sjmn->ijmn", [df, df])  # A = df^T df: 2*2*m*n
    detA = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0] + 1e-7  # det(df^T df): m*n

    # Allocate invA like A (same dtype and device). A bare torch.zeros()
    # is always created on the CPU with the default dtype, so float64 or
    # CUDA inputs would fail in the einsums below.
    invA = torch.zeros_like(A)
    # Only rows 1 .. m-2 are inverted; the first and last rows keep a zero
    # inverse. NOTE(review): presumably these rows are the poles, where
    # df^T df is singular -- confirm against the sampling grid.
    rows = slice(1, df.size(-2) - 1)
    invA[:, :, rows] = multi_dim_inv_for_22mn(A[:, :, rows], detA[rows])

    # df (df^T df)^{-1} df^T: 3*3*m*n
    df_dfplus = torch.einsum("is...,st...,jt...->ij...", [df, invA, df])
    B1 = torch.einsum("ismn,stmn,ktmn,kimn->mn", [w, invA, v, df_dfplus])
    B2 = torch.einsum("ismn,tsmn,tkmn,klmn,jlmn,jimn->mn",
                      [invA, df, w, invA, df, v])
    B3 = torch.einsum("ismn,stmn,itmn->mn", [w, invA, df]) * torch.einsum(
        "ismn,stmn,itmn->mn", [v, invA, df])
    # 1: measures the change in metric
    part1 = (B1 + B2 - B3) / 2
    # 2: measures the change in volume density
    part2 = B3 / 2
    # 3: measures the change in normal direction
    part3 = torch.einsum("ismn,stmn,itmn->mn", [w, invA, v]) - B1
    # 4: the additional fourth term (its geometric meaning is not stated here)
    part4 = (B1 - B2) / 2

    # sqrt(det(df^T df)) is the area element of the surface parameterization.
    inner_prod = (a * part1 + b * part2 + c * part3 + d * part4) * detA.sqrt()
    return integral_over_s2(inner_prod)
Ejemplo n.º 2
0
def E_L2_SRNFS(X, f1, f2, idty, Basis_vecFields):
    """Squared L2 distance between the SRNF of f2 and that of f1 o gamma.

    gamma is ``idty`` perturbed by the combination of ``Basis_vecFields``
    with coefficients ``X``, then renormalised pointwise to unit length
    (i.e. projected back onto S2).
    """
    # Perturb the identity along the basis fields and project onto S2.
    warped = idty + torch.einsum("i, ijkl->jkl", [X, Basis_vecFields])
    lengths = torch.einsum("ijk, ijk->jk", [warped, warped]).sqrt()
    warped = warped / lengths

    # Spherical-coordinate form of the warp. The tiny offsets are presumably
    # there to keep the conversion away from its singular points.
    x_c, y_c, z_c = warped[0], warped[1], warped[2]
    warped_sph = torch.stack(
        Cartesian_to_spherical(x_c + 1e-7, y_c, (1 - 1e-7) * z_c))

    q_moved = f_to_q(compose_gamma(f1, warped_sph))
    residual = f_to_q(f2) - q_moved
    pointwise = torch.einsum("imn,imn->mn", [residual, residual])
    return integral_over_s2(pointwise)
Ejemplo n.º 3
0
    def opt(X):
        """Objective for the rotation search: (energy, gradient w.r.t. X).

        Both values are returned as float64 NumPy objects, presumably for an
        optimizer called with ``jac=True``.
        """
        params = torch.from_numpy(X).float()
        params.requires_grad_()

        # Rotate the identity grid by expm(params), then go to spherical coords.
        rotated = torch.einsum("ij,jmn->imn", [torch_expm(params), idty])
        rotated_sph = torch.stack(
            Cartesian_to_spherical(rotated[0] + 1e-7, rotated[1],
                                   (1 - 1e-7) * rotated[2]))

        q_moved = f_to_q(compose_gamma(f1, rotated_sph))
        q_target = f_to_q(f2)
        residual = q_target - q_moved

        energy = integral_over_s2(
            torch.einsum("imn,imn->mn", [residual, residual]))
        energy.backward()
        value = np.double(energy.data.numpy())
        gradient = np.double(params.grad.data.numpy())
        return value, gradient
Ejemplo n.º 4
0
def rescale_surf(f):
    """Return ``f`` divided by sqrt of the S2 integral of its normal length.

    The normal field is the cross product of the two partial derivatives
    ``df[:, 0]`` and ``df[:, 1]`` from ``f_to_df``. Assuming ``f_to_df`` is
    linear in ``f``, the rescaled surface has that integral equal to 1.
    """
    jac = f_to_df(f)
    normal = torch.cross(jac[:, 0], jac[:, 1], dim=0)
    normal_len = torch.einsum("imn,imn->mn", [normal, normal]).sqrt()
    scale = integral_over_s2(normal_len).sqrt()
    return f / scale
Ejemplo n.º 5
0
def initialize_over_paraSO3_SRNF(f1, f2, idty,
                                 ico_path='Bases/skewIcosahedral.mat'):
    """Find the rotation of ``f1`` that best matches ``f2`` in SRNF L2 energy.

    Two stages:
      1. Exhaustive search over the generators stored under key ``'X'`` in
         ``ico_path`` (the file name suggests skew-symmetric generators of
         the icosahedral group).
      2. BFGS refinement of the best generator X, minimising
         ``integral_over_s2(|q2 - q(f1 o gamma)|^2)`` where gamma is
         ``expm(X)`` applied to ``idty``.

    Args:
        f1: surface to be rotated.
        f2: target surface.
        idty: identity map of the sphere grid (contracted as "jmn" with the
            rotation matrix).
        ico_path: .mat file holding the candidate generators. The default is
            the original relative path, resolved against the working
            directory.

    Returns:
        (f1_gamma, L2_ESO3, f1_gammaIco_best, EIco):
        - f1_gamma: f1 composed with the refined rotation.
        - L2_ESO3: the energy recorded after each BFGS iteration.
        - f1_gammaIco_best: f1 composed with the best discrete candidate.
        - EIco: the energies of all discrete candidates.
    """
    XIco = torch.from_numpy(sio.loadmat(ico_path)['X']).float()
    n_cand = XIco.size(0)

    def _rotated_sph(X_rot):
        # Rotate the identity grid by expm(X_rot) and convert to spherical
        # coordinates; the tiny offsets presumably keep the conversion away
        # from its singular points.
        R = torch.einsum("ij,jmn->imn", [torch_expm(X_rot), idty])
        return torch.stack(
            Cartesian_to_spherical(R[0] + 1e-7, R[1], (1 - 1e-7) * R[2]))

    # The target SRNF is constant for the whole search. It is computed once
    # and detached so repeated backward passes never traverse its graph.
    q2 = f_to_q(f2).detach()

    # Stage 1: evaluate every discrete candidate.
    EIco = torch.zeros(n_cand)
    f1_gammaIco = torch.zeros(n_cand, *f1.size())
    for i in range(n_cand):
        f1_gammaIco[i] = compose_gamma(f1, _rotated_sph(XIco[i]))
        diff = q2 - f_to_q(f1_gammaIco[i])
        EIco[i] = integral_over_s2(torch.einsum("imn,imn->mn", [diff, diff]))

    # First index of the smallest energy (same tie rule as np.argmin).
    Ind = int(torch.argmin(EIco))
    # scipy.optimize.minimize works on flat 1-D vectors. Keep the native
    # shape of the generator so torch_expm sees the same layout as XIco[i].
    x_shape = tuple(XIco[Ind].shape)

    L2_ESO3 = []

    def opt(x):
        # Objective for BFGS: energy and its gradient, both as float64.
        X = torch.tensor(np.reshape(x, x_shape), dtype=torch.float32,
                         requires_grad=True)
        diff = q2 - f_to_q(compose_gamma(f1, _rotated_sph(X)))
        y = integral_over_s2(torch.einsum("imn,imn->mn", [diff, diff]))
        y.backward()
        return np.double(y.item()), np.double(X.grad.reshape(-1).numpy())

    def printx(x):
        L2_ESO3.append(opt(x)[0])

    res = optimize.minimize(opt,
                            XIco[Ind].reshape(-1).numpy(),
                            method='BFGS',
                            jac=True,
                            callback=printx,
                            options={
                                'gtol': 1e-02,
                                'disp': False
                            })

    # Stage 2 result: apply the refined rotation.
    X_opt = torch.from_numpy(np.reshape(res.x, x_shape)).float()
    f1_gamma = compose_gamma(f1, _rotated_sph(X_opt))
    return f1_gamma, L2_ESO3, f1_gammaIco[Ind], EIco