Example #1
def _s2_fft(x, for_grad, b_in, b_out):
    '''
    :param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    :return: [l * m, batch, complex] (b_out**2, nbatch, 2)
    '''
    nspec = b_out**2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in,
                           nl=b_out,
                           weighted=not for_grad,
                           device_type=x.device.type,
                           device_index=x.device.index)
    cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch)

    x = torch.fft(x, 1)  # [batch, beta, m, complex]

    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    output = x.new_empty((nspec, nbatch, 2))
    cuda_kernel(block=(1024, 1, 1),
                grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                args=[
                    x.contiguous().data_ptr(),
                    wigner.contiguous().data_ptr(),
                    output.data_ptr()
                ],
                stream=stream)
    # [l * m, batch, complex]

    return output
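This legacy version relies on the pre-1.8 `torch.fft(x, 1)` call and has no CPU path. For reference, here is a pure-PyTorch sketch of the contraction the CUDA kernel performs, mirroring the CPU fallback that appears in Example #5 (the helper name is mine, not the library's):

import torch

def _s2_fft_reference(x, wigner, b_out):
    # x: [batch, beta, m, complex], already FFT'd along alpha
    # wigner: [beta, l * m] (2 * b_in, b_out**2)
    output = x.new_empty((b_out**2, x.size(0), 2))
    for l in range(b_out):
        s = slice(l**2, l**2 + 2 * l + 1)
        # frequencies m = -l..l; negative m wrap to the end of the axis
        xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]),
                       dim=2) if l > 0 else x[:, :, :1]
        output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))
    return output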
Example #2
def _s2_ifft(x, for_grad, b_in, b_out):
    '''
    :param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    :return: [batch, beta, alpha, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(
        b_out,
        nl=b_in,
        weighted=for_grad,
        device_type=x.device.type,
        device_index=x.device.index)  # [beta, l * m] (2 * b_out - 1, nspec)
    cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch)

    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
    cuda_kernel(block=(1024, 1, 1),
                grid=(cuda_utils.get_blocks(nbatch * (2 * b_out)**2,
                                            1024), 1, 1),
                args=[x.data_ptr(),
                      wigner.data_ptr(),
                      output.data_ptr()],
                stream=stream)
    # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)

    output = torch.ifft(output, 1) * output.size(-2)  # [batch, beta, alpha, complex]

    return output
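A minimal round-trip sketch for this pair of helpers; it assumes a pre-1.8 PyTorch (where `torch.fft`/`torch.ifft` are functions) and a CUDA device, since these versions have no CPU fallback:

import torch

b = 8
x = torch.randn(4, 2 * b, 2 * b, 2, device="cuda")   # [batch, beta, alpha, complex]
spec = _s2_fft(x, for_grad=False, b_in=b, b_out=b)   # [l * m, batch, complex]
y = _s2_ifft(spec, for_grad=False, b_in=b, b_out=b)  # [batch, beta, alpha, complex]
assert y.shape == (4, 2 * b, 2 * b, 2)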
Example #3
    def backward(self, gradz):  # pylint: disable=W
        x, y = self.saved_tensors
        nl = round(x.size(0)**0.5)
        nbatch = x.size(1)
        nfeature_in = x.size(2)
        nfeature_out = y.size(2)
        nspec = (4 * nl**2 - 1) * nl // 3

        gradx_cuda_kernel = _setup_s2mm_gradx_cuda_kernel(nbatch=nbatch, nspec=nspec, nl=nl, nfeature_in=nfeature_in, nfeature_out=nfeature_out)
        grady_cuda_kernel = _setup_s2mm_grady_cuda_kernel(nbatch=nbatch, nspec=nspec, nl=nl, nfeature_in=nfeature_in, nfeature_out=nfeature_out)

        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)

        gradx = grady = None

        if self.needs_input_grad[0]:
            gradx = gradz.new_empty((nl**2, nbatch, nfeature_in, 2))
            gradx_cuda_kernel(block=(cuda_utils.CUDA_NUM_THREADS, 1, 1),
                              grid=(cuda_utils.get_blocks(nl**2 * nbatch * nfeature_in, 1024), 1, 1),
                              args=[gradz.contiguous().data_ptr(), y.contiguous().data_ptr(), gradx.data_ptr()],
                              stream=stream)

        if self.needs_input_grad[1]:
            grady = gradz.new_empty((nl**2, nfeature_in, nfeature_out, 2))
            grady_cuda_kernel(block=(cuda_utils.CUDA_NUM_THREADS, 1, 1),
                              grid=(cuda_utils.get_blocks(nl**2 * nfeature_in * nfeature_out, 1024), 1, 1),
                              args=[gradz.contiguous().data_ptr(), x.contiguous().data_ptr(), grady.data_ptr()],
                              stream=stream)

        return gradx, grady
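This `backward` uses the old-style (non-static) autograd protocol, so it belongs on a `torch.autograd.Function` subclass whose `forward` saves both inputs. A hedged sketch of that pairing; the class name and the call to `s2_mm` (Example #4) are assumptions on my part:

import torch

class S2MM(torch.autograd.Function):  # hypothetical name
    def forward(self, x, y):  # pylint: disable=W
        # save inputs so backward can retrieve them via self.saved_tensors
        self.save_for_backward(x, y)
        return s2_mm(x, y)

    # backward(self, gradz) as defined above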
Example #4
def s2_mm(x, y):
    '''
    :param x: [l * m,     batch,      feature_in,  complex]
    :param y: [l * m,     feature_in, feature_out, complex]
    :return:  [l * m * n, batch,      feature_out, complex]
    '''
    assert x.is_cuda and x.dtype == torch.float32
    assert y.is_cuda and y.dtype == torch.float32
    assert y.size(3) == 2
    assert x.size(3) == 2
    nbatch = x.size(1)
    nfeature_in = x.size(2)
    nfeature_out = y.size(2)
    assert y.size(1) == nfeature_in
    assert y.size(0) == x.size(0)
    nl = round(x.size(0)**0.5)
    nspec = (4 * nl**2 - 1) * nl // 3
    assert x.size(0) == nl ** 2
    assert y.size(0) == nl ** 2

    cuda_kernel = _setup_s2mm_cuda_kernel(nbatch=nbatch, nspec=nspec, nfeature_in=nfeature_in, nfeature_out=nfeature_out)

    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    output = x.new_empty((nspec, nbatch, nfeature_out, 2))
    cuda_kernel(block=(cuda_utils.CUDA_NUM_THREADS, 1, 1),
                grid=(cuda_utils.get_blocks(nspec * nbatch * nfeature_out, 1024), 1, 1),
                args=[x.contiguous().data_ptr(), y.contiguous().data_ptr(), output.data_ptr()],
                stream=stream)
    # [l * m * n, batch, feature_out, complex]

    return output
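The `nspec` closed form `(4 * nl**2 - 1) * nl // 3` is the total size of all (2l+1) x (2l+1) spectral blocks for l < nl, i.e. the sum of (2l+1)**2. A quick sanity check:

nl = 3
assert (4 * nl**2 - 1) * nl // 3 == sum((2 * l + 1)**2 for l in range(nl))  # 1 + 9 + 25 = 35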
Example #5
def s2_fft(x, for_grad=False, b_out=None):
    '''
    :param x: [..., beta, alpha, complex]
    :return:  [l * m, ..., complex]
    '''
    assert x.size(-1) == 2
    b_in = x.size(-2) // 2
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    if b_out is None:
        b_out = b_in
    assert b_out <= b_in
    batch_size = x.size()[:-3]

    x = x.view(-1, 2 * b_in, 2 * b_in, 2)  # [batch, beta, alpha, complex]
    # here x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    # result: [l * m, batch, complex] (b_out**2, nbatch, 2)
    nspec = b_out**2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in,
                           nl=b_out,
                           weighted=not for_grad,
                           device=x.device)
    wigner = wigner.view(2 * b_in, -1)  # [beta, l * m] (2 * b_in, nspec)

    x = torch.view_as_real(torch.fft.fft(
        torch.view_as_complex(x)))  # [batch, beta, m, complex]

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in,
                                               nspec=nspec,
                                               nbatch=nbatch,
                                               device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                    args=[
                        x.contiguous().data_ptr(),
                        wigner.contiguous().data_ptr(),
                        output.data_ptr()
                    ],
                    stream=stream)
        # [l * m, batch, complex]
    else:
        for l in range(b_out):
            s = slice(l**2, l**2 + 2 * l + 1)
            xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]),
                           dim=2) if l > 0 else x[:, :, :1]
            output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))

    output = output.view(-1, *batch_size, 2)  # [l * m, ..., complex] (nspec, ..., 2)
    return output
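A minimal usage sketch of this version on the CPU fallback path (no CUDA required), assuming the surrounding module provides `_setup_wigner`:

import torch

b = 4
x = torch.randn(2, 3, 2 * b, 2 * b, 2)  # [..., beta, alpha, complex]
spec = s2_fft(x)                        # [l * m, ..., complex]
assert spec.shape == (b**2, 2, 3, 2)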
Example #6
def s2_ifft(x, for_grad=False, b_out=None):
    '''
    :param x: [l * m, ..., complex]
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round(nspec**0.5)
    assert nspec == b_in**2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m, batch, complex] (nspec, nbatch, 2)
    # here x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    # result: [batch, beta, alpha, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out,
                                                nl=b_in,
                                                nbatch=nbatch,
                                                device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out)**2,
                                                1024), 1, 1),
                    args=[x.data_ptr(),
                          wigner.data_ptr(),
                          output.data_ptr()],
                    stream=stream)
        # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            s = slice(l**2, l**2 + 2 * l + 1)
            out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                output[:, :, -l:] += out[:, :, :l]

    output = torch.view_as_real(
        torch.fft.ifft(torch.view_as_complex(output))) * output.size(-2)
    # [batch, beta, alpha, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
    return output
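And the inverse on the same CPU path; this sketch only checks shapes, since value-level round-tripping depends on the Wigner weighting:

import torch

b = 4
spec = torch.randn(b**2, 5, 2)  # [l * m, batch, complex]
out = s2_ifft(spec)             # [batch, beta, alpha, complex]
assert out.shape == (5, 2 * b, 2 * b, 2)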
Example #7
def _s2_fft(x, for_grad, b_in, b_out):
    '''
    :param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    :return: [l * m, batch, complex] (b_out**2, nbatch, 2)
    '''
    nspec = b_out**2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in,
                           nl=b_out,
                           weighted=not for_grad,
                           device_type=x.device.type,
                           device_index=x.device.index)
    wigner = wigner.view(2 * b_in, -1)  # [beta, l * m] (2 * b_in, nspec)

    x = torch.fft(x, 1)  # [batch, beta, m, complex]

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        device = torch.cuda.current_device()
        cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in,
                                               nspec=nspec,
                                               nbatch=nbatch,
                                               device=device)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                    args=[
                        x.contiguous().data_ptr(),
                        wigner.contiguous().data_ptr(),
                        output.data_ptr()
                    ],
                    stream=stream)
        # [l * m, batch, complex]
    else:
        for l in range(b_out):
            s = slice(l**2, l**2 + 2 * l + 1)
            xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]),
                           dim=2) if l > 0 else x[:, :, :1]
            output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))

    return output
Example #8
def _s2_ifft(x, for_grad, b_in, b_out):
    '''
    :param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    :return: [batch, beta, alpha, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out,
                           nl=b_in,
                           weighted=for_grad,
                           device_type=x.device.type,
                           device_index=x.device.index)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        device = torch.cuda.current_device()
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out,
                                                nl=b_in,
                                                nbatch=nbatch,
                                                device=device)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out)**2,
                                                1024), 1, 1),
                    args=[x.data_ptr(),
                          wigner.data_ptr(),
                          output.data_ptr()],
                    stream=stream)
        # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            s = slice(l**2, l**2 + 2 * l + 1)
            out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                output[:, :, -l:] += out[:, :, :l]

    output = torch.ifft(output, 1) * output.size(-2)  # [batch, beta, alpha, complex]

    return output
Example #9
def _setup_so3ifft_cuda_kernel(b_in, b_out, nbatch, real_output, device=0):
    kernel = '''
#define B_IN {}
#define B_OUT {}
#define NSPEC {}
#define NBATCH {}
'''.format(b_in, b_out, b_in * (4 * b_in ** 2 - 1) // 3, nbatch)

    if real_output:
        kernel += '''
#define REAL_OUT
'''

    kernel += '''
#define MOD(i, n) (((i) + (n)) % (n))
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

extern "C"
__global__ void main_(const float* in, const float* wig, float* out)
{
    int m = (blockIdx.z / (2 * B_OUT - 1)) - (B_OUT - 1);
    int n = (blockIdx.z % (2 * B_OUT - 1)) - (B_OUT - 1);

#ifdef REAL_OUT
    if (n < 0 || (n == 0 && m < 0)) {
        return; // note: this return does not depend on threadIdx
    }
#endif

    int l_min = MAX(abs(m), abs(n));

    int batch = blockIdx.y * 32 + threadIdx.y;

    float sum_re = 0.0;
    float sum_im = 0.0;

    for (int tile = 0; tile < CEIL_DIV(B_IN - l_min, 32); ++tile) {
        __shared__ float tileA[2][32][32];
        __shared__ float tileB[32][32+1];

        int l = l_min + tile * 32 + threadIdx.x;
        int lmn = (4 * l*l - 1) * l / 3 + (l+m) * (2 * l + 1) + (l+n);
        int i = (lmn * NBATCH + batch) * 2;
        tileA[0][threadIdx.y][threadIdx.x] = l < B_IN && batch < NBATCH ? in[i + 0] : 0.0;
        tileA[1][threadIdx.y][threadIdx.x] = l < B_IN && batch < NBATCH ? in[i + 1] : 0.0;

        int beta = blockIdx.x * 32 + threadIdx.y;
        tileB[threadIdx.x][threadIdx.y] = l < B_IN && beta < 2*B_OUT ? wig[beta * NSPEC + lmn] : 0.0;

        __syncthreads();

        for (int l = 0; l < 32; ++l) {
            sum_re += tileA[0][threadIdx.y][l] * tileB[l][threadIdx.x];
            sum_im += tileA[1][threadIdx.y][l] * tileB[l][threadIdx.x];
        }

        __syncthreads();
    }

    int beta = blockIdx.x * 32 + threadIdx.x;

    if (beta < 2*B_OUT && batch < NBATCH) {
        int i = (((batch * 2*B_OUT + beta) * 2*B_OUT + MOD(m, 2*B_OUT)) * 2*B_OUT + MOD(n, 2*B_OUT)) * 2;
        out[i + 0] = sum_re;
        out[i + 1] = sum_im;

#ifdef REAL_OUT
        i = (((batch * 2*B_OUT + beta) * 2*B_OUT + MOD(-m, 2*B_OUT)) * 2*B_OUT + MOD(-n, 2*B_OUT)) * 2;
        out[i + 0] = sum_re;
        out[i + 1] = -sum_im;
#endif
    }
}
'''
    import s2cnn.utils.cuda as cuda_utils
    kernel = cuda_utils.compile_kernel(kernel, 'so3ifft.cu', 'main_')
    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)

    def fun(x, wigner, output):
        output[:] = 0
        kernel(block=(32, 32, 1),
               grid=(math.ceil(2 * b_out / 32), math.ceil(nbatch / 32), (2 * b_out - 1) ** 2),
               args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
               stream=stream)

    return fun
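The grid's z dimension enumerates all (m, n) pairs; here is a small host-side sketch of the decoding done at the top of the kernel (the helper name is mine):

def decode_mn(block_z, b_out):
    # inverse of the linearization behind gridDim.z = (2 * b_out - 1)**2
    m = block_z // (2 * b_out - 1) - (b_out - 1)
    n = block_z % (2 * b_out - 1) - (b_out - 1)
    return m, n

assert decode_mn(0, b_out=2) == (-1, -1)
assert decode_mn((2 * 2 - 1)**2 - 1, b_out=2) == (1, 1)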
Example #10
def _setup_so3fft_cuda_kernel(b_in, b_out, nbatch, real_input, device=0):
    kernel = '''
#define B_IN {}
#define B_OUT {}
#define NSPEC {}
#define NBATCH {}
'''.format(b_in, b_out, b_out * (4 * b_out ** 2 - 1) // 3, nbatch)

    if real_input:
        kernel += '''
#define REAL_IN
'''

    kernel += '''
#define MOD(i, n) (((i) + (n)) % (n))
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

extern "C"
__global__ void main_(const float* in, const float* wig, float* out)
{
    // blockIdx = (l, batch, mn)
    // blockDim = (32, 32, 1)
    // threadIdx = (sub l, sub batch, 0)
    // gridDim = (b / 32, nbatch / 32, (2b-1)**2)
    int m = (blockIdx.z / (2 * B_OUT - 1)) - (B_OUT - 1);
    int n = (blockIdx.z % (2 * B_OUT - 1)) - (B_OUT - 1);

    int l_min = MAX(abs(m), abs(n));

    if (blockIdx.x * 32 + 31 < l_min) {
        // for blocks fully out of l-range
        return; // note: this return does not depend on threadIdx
    }

#ifdef REAL_IN
    if (n < 0 || (n == 0 && m < 0)) {
        return; // note: this return does not depend on threadIdx
    }
#endif

    int batch = blockIdx.y * 32 + threadIdx.y;
    int l = blockIdx.x * 32 + threadIdx.x;

    int lmn = (4 * l*l - 1) * l / 3 + (l+m) * (2 * l + 1) + (l+n);

    float sum_re = 0.0;
    float sum_im = 0.0;

    for (int tile = 0; tile < CEIL_DIV(2 * B_IN, 32); ++tile) {
        __shared__ float tileA[32][32][2];
        __shared__ float tileB[32][32];

        int beta = tile * 32 + threadIdx.x;
#ifdef REAL_IN
        // `in` shape is (NBATCH, 2 * B_IN, 2 * B_IN, B_IN + 1, 2)
        // http://www.fftw.org/fftw3_doc/Multi_002dDimensional-DFTs-of-Real-Data.html
        int i = (((batch * 2*B_IN + beta) * 2*B_IN + MOD(m, 2*B_IN)) * (B_IN + 1) + n) * 2;
#else
        int i = (((batch * 2*B_IN + beta) * 2*B_IN + MOD(m, 2*B_IN)) * 2*B_IN + MOD(n, 2*B_IN)) * 2;
#endif
        tileA[threadIdx.y][threadIdx.x][0] = beta < 2*B_IN && batch < NBATCH ? in[i + 0] : 0.0;
        tileA[threadIdx.y][threadIdx.x][1] = beta < 2*B_IN && batch < NBATCH ? in[i + 1] : 0.0;

        beta = tile * 32 + threadIdx.y;
        tileB[threadIdx.y][threadIdx.x] = beta < 2*B_IN && l_min <= l && l < B_OUT ? wig[beta * NSPEC + lmn] : 0.0;

        __syncthreads();

        for (int beta = 0; beta < 32; ++beta) {
            sum_re += tileA[threadIdx.y][beta][0] * tileB[beta][threadIdx.x];
            sum_im += tileA[threadIdx.y][beta][1] * tileB[beta][threadIdx.x];
        }

        __syncthreads();
    }

    // About this if: some blocks are used to compute but not to save the results
    if (l_min <= l && l < B_OUT && batch < NBATCH) {
        out[(lmn * NBATCH + batch) * 2 + 0] = sum_re;
        out[(lmn * NBATCH + batch) * 2 + 1] = sum_im;

#ifdef REAL_IN
        lmn = (4 * l*l - 1) * l / 3 + (l-m) * (2 * l + 1) + (l-n);
        float fudge = (m - n) % 2 == 0 ? 1.0 : -1.0;
        out[(lmn * NBATCH + batch) * 2 + 0] = fudge * sum_re;
        out[(lmn * NBATCH + batch) * 2 + 1] = -fudge * sum_im;
#endif
    }
}
'''
    import s2cnn.utils.cuda as cuda_utils
    kernel = cuda_utils.compile_kernel(kernel, 'so3fft.cu', 'main_')
    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)

    def fun(x, wigner, output):
        assert output.is_contiguous()
        kernel(block=(32, 32, 1),
               grid=(math.ceil(b_out / 32), math.ceil(nbatch / 32), (2 * b_out - 1) ** 2),
               args=[x.contiguous().data_ptr(), wigner.contiguous().data_ptr(), output.data_ptr()],
               stream=stream)

    return fun
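Both SO(3) kernels flatten the spectrum the same way: degrees below l occupy (4 * l**2 - 1) * l // 3 entries, and (m, n) then sits at row l + m, column l + n of the (2l+1) x (2l+1) block. A host-side sketch of that index (the helper name is mine):

def lmn_index(l, m, n):
    assert -l <= m <= l and -l <= n <= l
    return (4 * l * l - 1) * l // 3 + (l + m) * (2 * l + 1) + (l + n)

assert lmn_index(0, 0, 0) == 0
assert lmn_index(1, -1, -1) == 1  # first entry of the l = 1 block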
Example #11
def _setup_so3mm_cuda_kernel(nl,
                             ni,
                             nj,
                             nk,
                             conj_x=False,
                             conj_y=False,
                             trans_x_spec=False,
                             trans_x_feature=False,
                             trans_y_spec=False,
                             trans_y_feature=False,
                             trans_out_feature=False,
                             device=0):
    '''
    return a function that computes
        out[l*m*n, i, j] = sum_k sum_p x[l*m*p, i, k] y[l*p*n, k, j]
    where out, x, y are complex valued

    if conj_x is set to True, x is conjugated
    if conj_y is set to True, y is conjugated
    if trans_x_spec is set to True, m and p are permuted in x[...]
    if trans_y_spec is set to True, p and n are permuted in y[...]
    if trans_x_feature is set to True, i and k are permuted in x[...]
    if trans_y_feature is set to True, k and j are permuted in y[...]
    if trans_out_feature is set to True, i and j are permuted in out[...]
    '''

    kernel = '''
#define NI {}
#define NJ {}
#define NK {}
'''.format(ni, nj, nk)

    if not trans_x_spec and not trans_x_feature:
        kernel += '#define INDEX_X (((L0 + m * L + p) * NI + i) * NK + k)\n'
    if not trans_x_spec and trans_x_feature:
        kernel += '#define INDEX_X (((L0 + m * L + p) * NK + k) * NI + i)\n'
    if trans_x_spec and not trans_x_feature:
        kernel += '#define INDEX_X (((L0 + p * L + m) * NI + i) * NK + k)\n'
    if trans_x_spec and trans_x_feature:
        kernel += '#define INDEX_X (((L0 + p * L + m) * NK + k) * NI + i)\n'

    if not trans_y_spec and not trans_y_feature:
        kernel += '#define INDEX_Y (((L0 + p * L + n) * NK + k) * NJ + j)\n'
    if not trans_y_spec and trans_y_feature:
        kernel += '#define INDEX_Y (((L0 + p * L + n) * NJ + j) * NK + k)\n'
    if trans_y_spec and not trans_y_feature:
        kernel += '#define INDEX_Y (((L0 + n * L + p) * NK + k) * NJ + j)\n'
    if trans_y_spec and trans_y_feature:
        kernel += '#define INDEX_Y (((L0 + n * L + p) * NJ + j) * NK + k)\n'

    if not trans_out_feature:
        kernel += '#define INDEX_OUT (((L0 + m * L + n) * NI + i) * NJ + j)\n'
    if trans_out_feature:
        kernel += '#define INDEX_OUT (((L0 + m * L + n) * NJ + j) * NI + i)\n'

    kernel += '''
#define CONJ_X {}
#define CONJ_Y {}
'''.format("x_im = -x_im;" if conj_x else ";",
           "y_im = -y_im;" if conj_y else ";")

    kernel += '''
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

extern "C"
__global__ void main_(const float* in_x, const float* in_y, float* out)
{
    // start of thread independent code
    int l = blockIdx.z;
    int L = 2 * l + 1;
    int L0 = (4 * l*l - 1) * l / 3;

    if (blockIdx.y * 32 >= L * NI || blockIdx.x * 32 >= L * NJ) {
        return;
    }

    int ntile = CEIL_DIV(L * NK, 32);
    // end of thread independent code

    int mi = blockIdx.y * 32 + threadIdx.y;
    int m = mi / NI;
    int i = mi % NI;
    int nj = blockIdx.x * 32 + threadIdx.x;
    int n = nj / NJ;
    int j = nj % NJ;

    float sum_re = 0.0;
    float sum_im = 0.0;

    for (int tile = 0; tile < ntile; ++tile) {
        __shared__ float tileX[2][32][32];
        __shared__ float tileY[2][32][32];

        int pk = tile * 32 + threadIdx.x;
        int p = pk / NK;
        int k = pk % NK;
        int index = INDEX_X * 2;
        tileX[0][threadIdx.y][threadIdx.x] = m < L && p < L ? in_x[index + 0] : 0.0;
        tileX[1][threadIdx.y][threadIdx.x] = m < L && p < L ? in_x[index + 1] : 0.0;

        pk = tile * 32 + threadIdx.y;
        p = pk / NK;
        k = pk % NK;
        index = INDEX_Y * 2;
        tileY[0][threadIdx.y][threadIdx.x] = p < L && n < L ? in_y[index + 0] : 0.0;
        tileY[1][threadIdx.y][threadIdx.x] = p < L && n < L ? in_y[index + 1] : 0.0;

        __syncthreads();

        for (int any = 0; any < 32; ++any) {
            float x_re = tileX[0][threadIdx.y][any];
            float x_im = tileX[1][threadIdx.y][any];
            float y_re = tileY[0][any][threadIdx.x];
            float y_im = tileY[1][any][threadIdx.x];

            CONJ_X
            CONJ_Y

            sum_re += x_re * y_re - x_im * y_im;
            sum_im += x_re * y_im + x_im * y_re;
        }

        __syncthreads();
    }

    if (m < L && n < L) {
        int index = INDEX_OUT * 2;
        out[index + 0] = sum_re;
        out[index + 1] = sum_im;
    }
}
'''
    import s2cnn.utils.cuda as cuda_utils
    kernel = cuda_utils.compile_kernel(kernel, 'so3_mm.cu', 'main_')
    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)

    def fun(x, y, output):
        assert output.is_contiguous()
        kernel(block=(32, 32, 1),
               grid=(math.ceil(
                   (2 * nl - 1) * nj / 32), math.ceil(
                       (2 * nl - 1) * ni / 32), nl),
               args=[
                   x.contiguous().data_ptr(),
                   y.contiguous().data_ptr(),
                   output.data_ptr()
               ],
               stream=stream)

    return fun
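With all conjugation and transposition flags off, the kernel computes a block-diagonal complex matrix product per degree l. A hedged pure-PyTorch reference of that contraction (the helper is mine and ignores the conj_*/trans_* options):

import torch

def so3_mm_reference(x, y, nl):
    # x: [l * m * p, i, k, 2], y: [l * p * n, k, j, 2]; complex as a trailing axis of size 2
    out = []
    for l in range(nl):
        L, L0 = 2 * l + 1, (4 * l * l - 1) * l // 3
        xb = x[L0:L0 + L * L].contiguous().view(L, L, x.size(1), x.size(2), 2)
        yb = y[L0:L0 + L * L].contiguous().view(L, L, y.size(1), y.size(2), 2)
        xl, yl = torch.view_as_complex(xb), torch.view_as_complex(yb)
        # out[m, n, i, j] = sum over p, k of xl[m, p, i, k] * yl[p, n, k, j]
        zl = torch.einsum("mpik,pnkj->mnij", xl, yl)
        out.append(torch.view_as_real(zl).reshape(L * L, x.size(1), y.size(2), 2))
    return torch.cat(out, dim=0)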