コード例 #1
0
def so3_mm(x, y):
    '''
    Block-wise complex product of two SO(3) spectra.

    The spectra are stored as consecutive (2l+1) x (2l+1) blocks, one per
    degree l. For each degree, the block of ``x`` is combined with the
    matching block of ``y`` through ``complex_mm(..., conj_y=True)``. The
    product contracts over ``feature_in`` and the ``n`` index of ``x``.

    :param x: [l * m * n,   batch,    feature_in,  complex]
    :param y: [l * m * n, feature_in, feature_out, complex]
    :return:  [l * m * n,   batch,    feature_out, complex]
    '''
    from s2cnn.utils.complex import complex_mm
    import math

    assert y.size(3) == 2
    assert x.size(3) == 2
    n_batch = x.size(1)
    n_in = x.size(2)
    n_out = y.size(2)
    assert y.size(1) == n_in
    n_spec = x.size(0)
    assert y.size(0) == n_spec

    # The spectrum holds sum_{l < n_deg} (2l+1)^2 = n_deg * (4 n_deg^2 - 1) / 3
    # coefficients; invert that count to recover the number of degrees.
    n_deg = math.ceil((3 / 4 * n_spec)**(1 / 3))
    assert n_spec == n_deg * (4 * n_deg**2 - 1) // 3

    blocks = []
    offset = 0
    for degree in range(n_deg):
        dim = 2 * degree + 1
        stop = offset + dim * dim

        # [m, n, batch, feature_in, complex] -> [batch, m, feature_in, n, complex]
        lhs = x[offset:stop].view(dim, dim, n_batch, n_in, 2)
        lhs = lhs.permute(2, 0, 3, 1, 4).contiguous()
        lhs = lhs.view(n_batch * dim, n_in * dim, 2)  # [batch * m, feature_in * n, complex]

        # [m, n, feature_in, feature_out, complex] -> [feature_in, n, m, feature_out, complex]
        rhs = y[offset:stop].view(dim, dim, n_in, n_out, 2)
        rhs = rhs.permute(2, 1, 0, 3, 4).contiguous()
        rhs = rhs.view(n_in * dim, dim * n_out, 2)  # [feature_in * n, m * feature_out, complex]

        # [batch * m_x, m_y * feature_out, complex] with m_x -> m, m_y -> n
        prod = complex_mm(lhs, rhs, conj_y=True)
        # [batch, m * n, feature_out, complex] -> [m * n, batch, feature_out, complex]
        blocks.append(prod.view(n_batch, dim * dim, n_out, 2).transpose(0, 1))

        offset = stop

    return torch.cat(blocks, 0)  # [l * m * n, batch, feature_out, complex]
コード例 #2
0
def s2_mm(x, y):
    '''
    Block-wise complex product of two S2 spectra, producing an SO(3) spectrum.

    The inputs are stored as consecutive blocks of 2l+1 rows, one per degree
    l. For each degree, the rows of ``x`` are combined with the matching rows
    of ``y`` through ``complex_mm(..., conj_y=True)``, contracting over
    ``feature_in``. The two row indices become the (m, n) pair of the
    (2l+1) x (2l+1) output block.

    :param x: [l * m,     batch,      feature_in,  complex]
    :param y: [l * m,     feature_in, feature_out, complex]
    :return:  [l * m * n, batch,      feature_out, complex]
    '''
    from s2cnn.utils.complex import complex_mm

    assert y.size(3) == 2
    assert x.size(3) == 2
    nbatch = x.size(1)
    nfeature_in = x.size(2)
    nfeature_out = y.size(2)
    assert y.size(1) == nfeature_in
    nspec = x.size(0)
    assert y.size(0) == nspec

    # An S2 spectrum has sum_{l < nl} (2l+1) = nl**2 coefficients. Reject any
    # other length up front (as so3_mm does for its spectrum size). Without
    # this check the loop below would either silently ignore trailing
    # coefficients (e.g. nspec=5 -> nl=2 covers only 4 rows) or fail inside
    # an opaque ``view`` (e.g. nspec=7).
    nl = round(nspec**0.5)
    assert nspec == nl**2

    if x.is_cuda:
        return _cuda_S2_mm()(x, y)

    Fz_list = []
    begin = 0
    for l in range(nl):
        L = 2 * l + 1
        size = L

        Fx = x[begin:begin + size]  # [m, batch,      feature_in,  complex]
        Fy = y[begin:begin + size]  # [m, feature_in, feature_out, complex]

        Fx = Fx.view(L * nbatch, nfeature_in,
                     2)  # [m * batch, feature_in, complex]

        Fy = Fy.transpose(0, 1)  # [feature_in, m, feature_out, complex]
        Fy = Fy.contiguous()
        Fy = Fy.view(nfeature_in, L * nfeature_out,
                     2)  # [feature_in, m * feature_out, complex]

        Fz = complex_mm(
            Fx, Fy, conj_y=True
        )  # [m_x * batch, m_y * feature_out, complex] m_x -> m, m_y -> n
        Fz = Fz.view(L, nbatch, L, nfeature_out,
                     2)  # [m, batch, n, feature_out, complex]
        Fz = Fz.transpose(1, 2)  # [m, n, batch, feature_out, complex]
        Fz = Fz.contiguous()
        Fz = Fz.view(L * L, nbatch, nfeature_out,
                     2)  # [m * n, batch, feature_out, complex]

        Fz_list.append(Fz)

        begin += size

    z = torch.cat(Fz_list, 0)  # [l * m * n, batch, feature_out, complex]
    return z
コード例 #3
0
ファイル: so3_rotation.py プロジェクト: xiangliu886/s2cnn
def so3_rotation(x, alpha, beta, gamma):
    '''
    Rotate a real signal on SO(3) by the rotation given by (alpha, beta, gamma).

    The steps are:
    1. Transform the input with ``SO3_fft_real``.
    2. Combine each degree-l block of (2l+1) x (2l+1) coefficients with the
       matching matrix from ``_setup_so3_rotation``, using
       ``complex_mm(U, F, conj_x=True)``.
    3. Transform back with ``SO3_ifft_real``.
    4. Reshape the result to the input's shape.

    :param x: [..., beta, alpha, gamma] (..., 2b, 2b, 2b)
    :return: tensor with the same shape as ``x``
    '''
    in_shape = x.size()
    bandwidth = in_shape[-1] // 2

    rotations = _setup_so3_rotation(bandwidth,
                                    alpha,
                                    beta,
                                    gamma,
                                    device_type=x.device.type,
                                    device_index=x.device.index)

    # forward transform: [l * m * n, ..., complex]
    spectrum = SO3_fft_real()(x)

    rotated = []
    offset = 0
    for degree in range(bandwidth):
        dim = 2 * degree + 1
        count = dim * dim

        coeffs = spectrum[offset:offset + count].view(dim, -1, 2)  # [m, n * batch, complex]
        rot = rotations[degree].view(dim, dim, 2)  # [m, n, complex]

        # [m, n * batch, complex] -> [m * n, batch, complex]
        rotated.append(complex_mm(rot, coeffs, conj_x=True).view(count, -1, 2))
        offset += count

    rotated_spectrum = torch.cat(rotated, 0)  # [l * m * n, batch, complex]
    out = SO3_ifft_real()(rotated_spectrum)
    return out.contiguous().view(*in_shape)