def _s2_fft(x, for_grad, b_in, b_out):
    '''
    Run the S2 FFT CUDA kernel on a batch of signals.

    A 1-D FFT is taken along the alpha axis first; the CUDA kernel then
    receives that result together with the table returned by `_setup_wigner`
    and fills the spectral output buffer.

    :param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    :param for_grad: the wigner table is requested with `weighted=not for_grad`
    :param b_in: half the number of beta / alpha samples of `x`
    :param b_out: the output holds b_out**2 spectral rows
    :return: [l * m, batch, complex] (b_out**2, nbatch, 2)
    '''
    nspec = b_out**2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad,
                           device_type=x.device.type, device_index=x.device.index)
    cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch)

    # The legacy callable `torch.fft(x, 1)` no longer exists (torch.fft is a
    # module); this is the equivalent complex FFT over the alpha axis.
    # view_as_complex needs a unit stride on the complex axis, hence contiguous().
    x = torch.view_as_real(torch.fft.fft(torch.view_as_complex(x.contiguous())))  # [batch, beta, m, complex]

    # Keep the contiguous buffers referenced by locals for the whole launch.
    x = x.contiguous()
    wigner = wigner.contiguous()

    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    output = x.new_empty((nspec, nbatch, 2))
    cuda_kernel(block=(1024, 1, 1),
                grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                stream=stream)
    # [l * m, batch, complex]
    return output
def _s2_ifft(x, for_grad, b_in, b_out):
    '''
    Run the S2 inverse-FFT CUDA kernel, then undo the FFT along alpha.

    The CUDA kernel fills a [batch, beta, m, complex] buffer from the spectral
    input and the table returned by `_setup_wigner`; a 1-D inverse FFT over
    the m axis, rescaled by the axis length, gives the final signal.

    :param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    :param for_grad: the wigner table is requested with `weighted=for_grad`
    :param b_in: `x` holds b_in**2 spectral rows
    :param b_out: half the number of beta / alpha samples of the output
    :return: [batch, beta, alpha, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad,
                           device_type=x.device.type,
                           device_index=x.device.index)  # [beta, l * m] (2 * b_out - 1, nspec)
    cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch)

    # The kernel only sees raw pointers, so the inputs must be densely laid
    # out; the locals keep any contiguous copies alive during the launch.
    x = x.contiguous()
    wigner = wigner.contiguous()

    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
    cuda_kernel(block=(1024, 1, 1),
                grid=(cuda_utils.get_blocks(nbatch * (2 * b_out)**2, 1024), 1, 1),
                args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                stream=stream)
    # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)

    # The legacy callable `torch.ifft(output, 1)` no longer exists; this is the
    # equivalent complex inverse FFT. Multiplying by the axis length cancels
    # the 1/N factor that the inverse FFT applies.
    n_alpha = output.size(-2)
    output = torch.view_as_real(torch.fft.ifft(torch.view_as_complex(output))) * n_alpha  # [batch, beta, alpha, complex]
    return output
def backward(self, gradz):  # pylint: disable=W
    '''
    Launch the s2mm gradient kernels for the two saved tensors.

    Sizes are read from the saved tensors: `x` is indexed as
    (l * m, batch, feature_in, ...) and `y` supplies feature_out on axis 2.
    A gradient is only computed for an input whose `needs_input_grad` flag
    is set; otherwise `None` is returned in its slot.

    :param gradz: gradient of the output; its buffer is passed to both kernels
        (presumably [l * m * n, batch, feature_out, complex] -- confirm against
        the forward pass)
    :return: (gradx, grady) with shapes (nl**2, nbatch, nfeature_in, 2) and
        (nl**2, nfeature_in, nfeature_out, 2), or None where not required
    '''
    x, y = self.saved_tensors
    nl = round(x.size(0)**0.5)
    nbatch = x.size(1)
    nfeature_in = x.size(2)
    nfeature_out = y.size(2)
    nspec = (4 * nl**2 - 1) * nl // 3

    gradx_cuda_kernel = _setup_s2mm_gradx_cuda_kernel(nbatch=nbatch, nspec=nspec, nl=nl,
                                                      nfeature_in=nfeature_in,
                                                      nfeature_out=nfeature_out)
    grady_cuda_kernel = _setup_s2mm_grady_cuda_kernel(nbatch=nbatch, nspec=nspec, nl=nl,
                                                      nfeature_in=nfeature_in,
                                                      nfeature_out=nfeature_out)

    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)

    # One thread count for both the block size and the grid computation, so
    # the launch always spans every output element even if the constant is
    # not 1024.
    num_threads = cuda_utils.CUDA_NUM_THREADS

    # Both kernels read gradz through a raw pointer; make it dense once.
    gradz = gradz.contiguous()

    gradx = grady = None

    if self.needs_input_grad[0]:
        y_dense = y.contiguous()
        gradx = gradz.new_empty((nl**2, nbatch, nfeature_in, 2))
        gradx_cuda_kernel(block=(num_threads, 1, 1),
                          grid=(cuda_utils.get_blocks(nl**2 * nbatch * nfeature_in, num_threads), 1, 1),
                          args=[gradz.data_ptr(), y_dense.data_ptr(), gradx.data_ptr()],
                          stream=stream)

    if self.needs_input_grad[1]:
        x_dense = x.contiguous()
        grady = gradz.new_empty((nl**2, nfeature_in, nfeature_out, 2))
        grady_cuda_kernel(block=(num_threads, 1, 1),
                          grid=(cuda_utils.get_blocks(nl**2 * nfeature_in * nfeature_out, num_threads), 1, 1),
                          args=[gradz.data_ptr(), x_dense.data_ptr(), grady.data_ptr()],
                          stream=stream)

    return gradx, grady
def s2_mm(x, y):
    '''
    Launch the s2mm CUDA kernel on two complex float32 CUDA tensors.

    Both inputs must be CUDA float32 tensors whose last axis has size 2
    and whose first axis has the same perfect-square size nl**2.

    :param x: [l * m,     batch,      feature_in,  complex]
    :param y: [l * m,     feature_in, feature_out, complex]
    :return:  [l * m * n, batch,      feature_out, complex]
        with (4 * nl**2 - 1) * nl // 3 rows on the first axis
    '''
    assert x.is_cuda and x.dtype == torch.float32
    assert y.is_cuda and y.dtype == torch.float32
    assert y.size(3) == 2
    assert x.size(3) == 2
    nbatch = x.size(1)
    nfeature_in = x.size(2)
    nfeature_out = y.size(2)
    assert y.size(1) == nfeature_in
    assert y.size(0) == x.size(0)
    nl = round(x.size(0)**0.5)
    nspec = (4 * nl**2 - 1) * nl // 3
    assert x.size(0) == nl ** 2
    assert y.size(0) == nl ** 2

    cuda_kernel = _setup_s2mm_cuda_kernel(nbatch=nbatch, nspec=nspec,
                                          nfeature_in=nfeature_in,
                                          nfeature_out=nfeature_out)

    # One thread count for both the block size and the grid computation, so
    # the launch always spans every output element even if the constant is
    # not 1024.
    num_threads = cuda_utils.CUDA_NUM_THREADS

    # The kernel reads raw pointers; the locals keep any contiguous copies
    # alive for the duration of the launch.
    x_dense = x.contiguous()
    y_dense = y.contiguous()

    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    output = x.new_empty((nspec, nbatch, nfeature_out, 2))
    cuda_kernel(block=(num_threads, 1, 1),
                grid=(cuda_utils.get_blocks(nspec * nbatch * nfeature_out, num_threads), 1, 1),
                args=[x_dense.data_ptr(), y_dense.data_ptr(), output.data_ptr()],
                stream=stream)
    # [l * m * n, batch, feature_out, complex]
    return output
def s2_fft(x, for_grad=False, b_out=None):
    '''
    Transform a batch of signals to b_out**2 spectral rows.

    Any leading axes of `x` are flattened into one batch axis for the
    computation and restored on the result. CUDA float32 inputs go through
    the compiled kernel; everything else uses the einsum loop.

    :param x: [..., beta, alpha, complex], beta and alpha both of even size 2 * b_in
    :param for_grad: the wigner table is requested with `weighted=not for_grad`
    :param b_out: number of degrees kept; defaults to b_in and must not exceed it
    :return: [l * m, ..., complex]
    '''
    assert x.size(-1) == 2
    b_in = x.size(-2) // 2
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    if b_out is None:
        b_out = b_in
    assert b_out <= b_in

    leading_shape = x.size()[:-3]
    n_grid = 2 * b_in
    x = x.view(-1, n_grid, n_grid, 2)  # [batch, beta, alpha, complex]

    # From here on:
    #   x:      [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    #   output: [l * m, batch, complex]       (b_out**2, nbatch, 2)
    nspec = b_out**2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)
    wigner = wigner.view(n_grid, -1)  # [beta, l * m] (2 * b_in, nspec)

    transformed = torch.fft.fft(torch.view_as_complex(x))
    x = torch.view_as_real(transformed)  # [batch, beta, m, complex]

    output = x.new_empty((nspec, nbatch, 2))
    use_cuda_kernel = x.is_cuda and x.dtype == torch.float32
    if not use_cuda_kernel:
        for degree in range(b_out):
            first_row = degree**2
            rows = slice(first_row, first_row + 2 * degree + 1)
            if degree > 0:
                # Orders -degree..-1 sit at the end of the FFT axis and
                # 0..degree at the start.
                xx = torch.cat((x[:, :, -degree:], x[:, :, :degree + 1]), dim=2)
            else:
                xx = x[:, :, :1]
            output[rows] = torch.einsum("bm,zbmc->mzc", (wigner[:, rows], xx))
    else:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch,
                                               device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        kernel_args = [x.contiguous().data_ptr(),
                       wigner.contiguous().data_ptr(),
                       output.data_ptr()]
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                    args=kernel_args,
                    stream=stream)
        # [l * m, batch, complex]

    # [l * m, ..., complex] (nspec, ..., 2)
    return output.view(-1, *leading_shape, 2)
def s2_ifft(x, for_grad=False, b_out=None):
    '''
    Transform b_in**2 spectral rows back to a signal on a 2 * b_out grid.

    Any axes between the first and the last are flattened into one batch
    axis for the computation and restored on the result. CUDA float32
    inputs go through the compiled kernel; everything else uses the einsum
    loop. A final inverse FFT over the last signal axis, rescaled by the
    axis length, gives the output.

    :param x: [l * m, ..., complex]; the first axis must have a perfect-square size
    :param for_grad: the wigner table is requested with `weighted=for_grad`
    :param b_out: defaults to b_in and must be at least b_in
    :return: [..., beta, alpha, complex] (..., 2 * b_out, 2 * b_out, 2)
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round(nspec**0.5)
    assert nspec == b_in**2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m, batch, complex] (nspec, nbatch, 2)

    # From here on:
    #   x:      [l * m, batch, complex]       (b_in**2, nbatch, 2)
    #   output: [batch, beta, alpha, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch,
                                                device=x.device.index)
        # The kernel only sees raw pointers, so both inputs must be densely
        # laid out (as s2_fft already ensures); the locals keep any copies
        # alive during the launch.
        x_dense = x.contiguous()
        wigner_dense = wigner.contiguous()
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out)**2, 1024), 1, 1),
                    args=[x_dense.data_ptr(), wigner_dense.data_ptr(), output.data_ptr()],
                    stream=stream)
        # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            s = slice(l**2, l**2 + 2 * l + 1)
            out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
            # Orders 0..l go to the start of the m axis, -l..-1 to its end.
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                output[:, :, -l:] += out[:, :, :l]

    # Multiplying by the axis length cancels the 1/N factor of the inverse FFT.
    n_alpha = output.size(-2)
    output = torch.view_as_real(torch.fft.ifft(torch.view_as_complex(output))) * n_alpha  # [batch, beta, alpha, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
    return output
def _s2_fft(x, for_grad, b_in, b_out):
    '''
    Transform a flattened batch of signals to b_out**2 spectral rows.

    A 1-D FFT is taken along the alpha axis first. CUDA float32 inputs are
    then handed to the compiled kernel; everything else goes through the
    einsum loop.

    :param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    :param for_grad: the wigner table is requested with `weighted=not for_grad`
    :param b_in: half the number of beta / alpha samples of `x`
    :param b_out: the output holds b_out**2 spectral rows
    :return: [l * m, batch, complex] (b_out**2, nbatch, 2)
    '''
    nspec = b_out**2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad,
                           device_type=x.device.type, device_index=x.device.index)
    wigner = wigner.view(2 * b_in, -1)  # [beta, l * m] (2 * b_in, nspec)

    # The legacy callable `torch.fft(x, 1)` no longer exists (torch.fft is a
    # module); this is the equivalent complex FFT over the alpha axis.
    # view_as_complex needs a unit stride on the complex axis, hence contiguous().
    x = torch.view_as_real(torch.fft.fft(torch.view_as_complex(x.contiguous())))  # [batch, beta, m, complex]

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        device = torch.cuda.current_device()
        cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch, device=device)
        # Keep the contiguous buffers referenced by locals for the whole launch.
        x_dense = x.contiguous()
        wigner_dense = wigner.contiguous()
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                    args=[x_dense.data_ptr(), wigner_dense.data_ptr(), output.data_ptr()],
                    stream=stream)
        # [l * m, batch, complex]
    else:
        for l in range(b_out):
            s = slice(l**2, l**2 + 2 * l + 1)
            if l > 0:
                # Orders -l..-1 sit at the end of the FFT axis and 0..l at the start.
                xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]), dim=2)
            else:
                xx = x[:, :, :1]
            output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))

    return output
def _s2_ifft(x, for_grad, b_in, b_out):
    '''
    Transform b_in**2 spectral rows back to a signal on a 2 * b_out grid.

    CUDA float32 inputs go through the compiled kernel; everything else
    uses the einsum loop. Both produce a [batch, beta, m, complex] buffer
    that is then inverse-FFT'd over the m axis and rescaled by the axis
    length.

    :param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    :param for_grad: the wigner table is requested with `weighted=for_grad`
    :param b_in: `x` holds b_in**2 spectral rows
    :param b_out: half the number of beta / alpha samples of the output
    :return: [batch, beta, alpha, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad,
                           device_type=x.device.type, device_index=x.device.index)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        device = torch.cuda.current_device()
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=device)
        # The kernel only sees raw pointers, so both inputs must be densely
        # laid out; the locals keep any copies alive during the launch.
        x_dense = x.contiguous()
        wigner_dense = wigner.contiguous()
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out)**2, 1024), 1, 1),
                    args=[x_dense.data_ptr(), wigner_dense.data_ptr(), output.data_ptr()],
                    stream=stream)
        # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            s = slice(l**2, l**2 + 2 * l + 1)
            out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
            # Orders 0..l go to the start of the m axis, -l..-1 to its end.
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                output[:, :, -l:] += out[:, :, :l]

    # The legacy callable `torch.ifft(output, 1)` no longer exists; this is the
    # equivalent complex inverse FFT. Multiplying by the axis length cancels
    # the 1/N factor that the inverse FFT applies.
    n_alpha = output.size(-2)
    output = torch.view_as_real(torch.fft.ifft(torch.view_as_complex(output))) * n_alpha  # [batch, beta, alpha, complex]
    return output
def _setup_so3ifft_cuda_kernel(b_in, b_out, nbatch, real_output, device=0):
    '''
    Compile the so3ifft CUDA kernel and return a launcher for it.

    The sizes are baked into the kernel source as preprocessor constants, so
    the returned function is only valid for this (b_in, b_out, nbatch).

    :param b_in: value of the B_IN constant; NSPEC is b_in * (4 * b_in**2 - 1) // 3
    :param b_out: value of the B_OUT constant; also fixes the launch grid
    :param nbatch: value of the NBATCH constant
    :param real_output: when true, REAL_OUT is defined: the kernel skips half of
        the (m, n) pairs and writes each result a second time, conjugated,
        at (-m, -n)
    :param device: not used by this function
    :return: fun(x, wigner, output), which zeroes `output` and launches the
        kernel on the CUDA stream that is current at call time
    '''
    kernel = '''
#define B_IN {}
#define B_OUT {}
#define NSPEC {}
#define NBATCH {}
'''.format(b_in, b_out, b_in * (4 * b_in ** 2 - 1) // 3, nbatch)

    if real_output:
        kernel += '''
#define REAL_OUT
'''

    kernel += '''
#define MOD(i, n) (((i) + (n)) % (n))
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

extern "C"
__global__ void main_(const float* in, const float* wig, float* out)
{
    int m = (blockIdx.z / (2 * B_OUT - 1)) - (B_OUT - 1);
    int n = (blockIdx.z % (2 * B_OUT - 1)) - (B_OUT - 1);

#ifdef REAL_OUT
    if (n < 0 || (n == 0 && m < 0)) {
        return; // note: this return does not depend on threadIdx
    }
#endif

    int l_min = MAX(abs(m), abs(n));

    int batch = blockIdx.y * 32 + threadIdx.y;

    float sum_re = 0.0;
    float sum_im = 0.0;

    for (int tile = 0; tile < CEIL_DIV(B_IN - l_min, 32); ++tile) {
        __shared__ float tileA[2][32][32];
        __shared__ float tileB[32][32+1];

        int l = l_min + tile * 32 + threadIdx.x;
        int lmn = (4 * l*l - 1) * l / 3 + (l+m) * (2 * l + 1) + (l+n);
        int i = (lmn * NBATCH + batch) * 2;
        tileA[0][threadIdx.y][threadIdx.x] = l < B_IN && batch < NBATCH ? in[i + 0] : 0.0;
        tileA[1][threadIdx.y][threadIdx.x] = l < B_IN && batch < NBATCH ? in[i + 1] : 0.0;

        int beta = blockIdx.x * 32 + threadIdx.y;
        tileB[threadIdx.x][threadIdx.y] = l < B_IN && beta < 2*B_OUT ? wig[beta * NSPEC + lmn] : 0.0;

        __syncthreads();

        for (int l = 0; l < 32; ++l) {
            sum_re += tileA[0][threadIdx.y][l] * tileB[l][threadIdx.x];
            sum_im += tileA[1][threadIdx.y][l] * tileB[l][threadIdx.x];
        }

        __syncthreads();
    }

    int beta = blockIdx.x * 32 + threadIdx.x;

    if (beta < 2*B_OUT && batch < NBATCH) {
        int i = (((batch * 2*B_OUT + beta) * 2*B_OUT + MOD(m, 2*B_OUT)) * 2*B_OUT + MOD(n, 2*B_OUT)) * 2;
        out[i + 0] = sum_re;
        out[i + 1] = sum_im;

#ifdef REAL_OUT
        i = (((batch * 2*B_OUT + beta) * 2*B_OUT + MOD(-m, 2*B_OUT)) * 2*B_OUT + MOD(-n, 2*B_OUT)) * 2;
        out[i + 0] = sum_re;
        out[i + 1] = -sum_im;
#endif
    }
}
'''
    import s2cnn.utils.cuda as cuda_utils
    kernel = cuda_utils.compile_kernel(kernel, 'so3ifft.cu', 'main_')

    grid = (math.ceil(2 * b_out / 32), math.ceil(nbatch / 32), (2 * b_out - 1) ** 2)

    def fun(x, wigner, output):
        # The kernel addresses all three buffers through raw pointers.
        assert output.is_contiguous()
        x = x.contiguous()
        wigner = wigner.contiguous()
        # The grid only covers |m|, |n| < B_OUT, so the remaining entries
        # must already be zero.
        output[:] = 0
        # Resolve the stream at launch time so the kernel is enqueued on
        # whatever stream is current for the caller, not the one that was
        # current when the kernel was compiled.
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        kernel(block=(32, 32, 1),
               grid=grid,
               args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
               stream=stream)

    return fun
def _setup_so3fft_cuda_kernel(b_in, b_out, nbatch, real_input, device=0):
    '''
    Compile the so3fft CUDA kernel and return a launcher for it.

    The sizes are baked into the kernel source as preprocessor constants, so
    the returned function is only valid for this (b_in, b_out, nbatch).

    :param b_in: value of the B_IN constant
    :param b_out: value of the B_OUT constant; NSPEC is
        b_out * (4 * b_out**2 - 1) // 3; also fixes the launch grid
    :param nbatch: value of the NBATCH constant
    :param real_input: when true, REAL_IN is defined: the kernel indexes `in`
        with a last signal axis of size B_IN + 1, skips half of the (m, n)
        pairs and also writes each result, conjugated and sign-adjusted,
        at (-m, -n)
    :param device: not used by this function
    :return: fun(x, wigner, output), which launches the kernel on the CUDA
        stream that is current at call time; `output` must be contiguous
    '''
    kernel = '''
#define B_IN {}
#define B_OUT {}
#define NSPEC {}
#define NBATCH {}
'''.format(b_in, b_out, b_out * (4 * b_out ** 2 - 1) // 3, nbatch)

    if real_input:
        kernel += '''
#define REAL_IN
'''

    kernel += '''
#define MOD(i, n) (((i) + (n)) % (n))
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

extern "C"
__global__ void main_(const float* in, const float* wig, float* out)
{
    // blockIdx = (l, batch, mn)
    // blockDim = (32, 32, 1)
    // threadIdx = (sub l, sub batch, 0)
    // gridDim = (b / 32, nbatch / 32, (2b-1)**2)
    int m = (blockIdx.z / (2 * B_OUT - 1)) - (B_OUT - 1);
    int n = (blockIdx.z % (2 * B_OUT - 1)) - (B_OUT - 1);

    int l_min = MAX(abs(m), abs(n));

    if (blockIdx.x * 32 + 31 < l_min) { // for blocks fully out of l-range
        return; // note: this return does not depend on threadIdx
    }

#ifdef REAL_IN
    if (n < 0 || (n == 0 && m < 0)) {
        return; // note: this return does not depend on threadIdx
    }
#endif

    int batch = blockIdx.y * 32 + threadIdx.y;
    int l = blockIdx.x * 32 + threadIdx.x;

    int lmn = (4 * l*l - 1) * l / 3 + (l+m) * (2 * l + 1) + (l+n);

    float sum_re = 0.0;
    float sum_im = 0.0;

    for (int tile = 0; tile < CEIL_DIV(2 * B_IN, 32); ++tile) {
        __shared__ float tileA[32][32][2];
        __shared__ float tileB[32][32];

        int beta = tile * 32 + threadIdx.x;
#ifdef REAL_IN
        // `in` shape is (NBATCH, 2 * B_IN, 2 * B_IN, B_IN + 1, 2)
        // http://www.fftw.org/fftw3_doc/Multi_002dDimensional-DFTs-of-Real-Data.html
        int i = (((batch * 2*B_IN + beta) * 2*B_IN + MOD(m, 2*B_IN)) * (B_IN + 1) + n) * 2;
#else
        int i = (((batch * 2*B_IN + beta) * 2*B_IN + MOD(m, 2*B_IN)) * 2*B_IN + MOD(n, 2*B_IN)) * 2;
#endif
        tileA[threadIdx.y][threadIdx.x][0] = beta < 2*B_IN && batch < NBATCH ? in[i + 0] : 0.0;
        tileA[threadIdx.y][threadIdx.x][1] = beta < 2*B_IN && batch < NBATCH ? in[i + 1] : 0.0;

        beta = tile * 32 + threadIdx.y;
        tileB[threadIdx.y][threadIdx.x] = beta < 2*B_IN && l_min <= l && l < B_OUT ? wig[beta * NSPEC + lmn] : 0.0;

        __syncthreads();

        for (int beta = 0; beta < 32; ++beta) {
            sum_re += tileA[threadIdx.y][beta][0] * tileB[beta][threadIdx.x];
            sum_im += tileA[threadIdx.y][beta][1] * tileB[beta][threadIdx.x];
        }

        __syncthreads();
    }

    // About this if: some blocks are used to compute but not to save the results
    if (l_min <= l && l < B_OUT && batch < NBATCH) {
        out[(lmn * NBATCH + batch) * 2 + 0] = sum_re;
        out[(lmn * NBATCH + batch) * 2 + 1] = sum_im;

#ifdef REAL_IN
        lmn = (4 * l*l - 1) * l / 3 + (l-m) * (2 * l + 1) + (l-n);
        float fudge = (m - n) % 2 == 0 ? 1.0 : -1.0;
        out[(lmn * NBATCH + batch) * 2 + 0] = fudge * sum_re;
        out[(lmn * NBATCH + batch) * 2 + 1] = -fudge * sum_im;
#endif
    }
}
'''
    import s2cnn.utils.cuda as cuda_utils
    kernel = cuda_utils.compile_kernel(kernel, 'so3fft.cu', 'main_')

    grid = (math.ceil(b_out / 32), math.ceil(nbatch / 32), (2 * b_out - 1) ** 2)

    def fun(x, wigner, output):
        # The kernel addresses all three buffers through raw pointers.
        assert output.is_contiguous()
        # Locals keep any contiguous copies alive for the whole launch.
        x = x.contiguous()
        wigner = wigner.contiguous()
        # Resolve the stream at launch time so the kernel is enqueued on
        # whatever stream is current for the caller, not the one that was
        # current when the kernel was compiled.
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        kernel(block=(32, 32, 1),
               grid=grid,
               args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
               stream=stream)

    return fun
def _setup_so3mm_cuda_kernel(nl, ni, nj, nk, conj_x=False, conj_y=False,
                             trans_x_spec=False, trans_x_feature=False,
                             trans_y_spec=False, trans_y_feature=False,
                             trans_out_feature=False, device=0):
    '''
    Compile the so3_mm CUDA kernel and return a launcher for it.

    The returned function computes

        out[l*m*n, i, j] = sum_k sum_p x[l*m*p, i, k] y[l*p*n, k, j]

    where out, x, y are complex valued. The feature sizes and the memory
    layout selected by the flags are baked into the kernel source as
    preprocessor macros.

    :param nl: number of degrees l; one grid z-slice is launched per degree
    :param ni: size of the feature index i (NI)
    :param nj: size of the feature index j (NJ)
    :param nk: size of the contracted feature index k (NK)
    :param conj_x: if True, x is conjugated
    :param conj_y: if True, y is conjugated
    :param trans_x_spec: if True, m and p are permuted in x[...]
    :param trans_y_spec: if True, p and n are permuted in y[...]
    :param trans_x_feature: if True, i and k are permuted in x[...]
    :param trans_y_feature: if True, k and j are permuted in y[...]
    :param trans_out_feature: if True, i and j are permuted in out[...]
    :param device: not used by this function
    :return: fun(x, y, output), which launches the kernel on the CUDA stream
        that is current at call time; `output` must be contiguous
    '''
    kernel = '''
#define NI {}
#define NJ {}
#define NK {}
'''.format(ni, nj, nk)

    # Flat-index macros, keyed by (spectral transpose, feature transpose).
    index_x = {
        (False, False): '#define INDEX_X (((L0 + m * L + p) * NI + i) * NK + k)\n',
        (False, True): '#define INDEX_X (((L0 + m * L + p) * NK + k) * NI + i)\n',
        (True, False): '#define INDEX_X (((L0 + p * L + m) * NI + i) * NK + k)\n',
        (True, True): '#define INDEX_X (((L0 + p * L + m) * NK + k) * NI + i)\n',
    }
    index_y = {
        (False, False): '#define INDEX_Y (((L0 + p * L + n) * NK + k) * NJ + j)\n',
        (False, True): '#define INDEX_Y (((L0 + p * L + n) * NJ + j) * NK + k)\n',
        (True, False): '#define INDEX_Y (((L0 + n * L + p) * NK + k) * NJ + j)\n',
        (True, True): '#define INDEX_Y (((L0 + n * L + p) * NJ + j) * NK + k)\n',
    }
    kernel += index_x[(bool(trans_x_spec), bool(trans_x_feature))]
    kernel += index_y[(bool(trans_y_spec), bool(trans_y_feature))]

    if trans_out_feature:
        kernel += '#define INDEX_OUT (((L0 + m * L + n) * NJ + j) * NI + i)\n'
    else:
        kernel += '#define INDEX_OUT (((L0 + m * L + n) * NI + i) * NJ + j)\n'

    # CONJ_X / CONJ_Y expand to a sign flip of the imaginary part, or to an
    # empty statement when no conjugation is requested.
    kernel += '''
#define CONJ_X {}
#define CONJ_Y {}
'''.format("x_im = -x_im;" if conj_x else ";", "y_im = -y_im;" if conj_y else ";")

    kernel += '''
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

extern "C"
__global__ void main_(const float* in_x, const float* in_y, float* out)
{
    // start of thread independant code
    int l = blockIdx.z;
    int L = 2 * l + 1;
    int L0 = (4 * l*l - 1) * l / 3;

    if (blockIdx.y * 32 >= L * NI || blockIdx.x * 32 >= L * NJ) {
        return;
    }

    int ntile = CEIL_DIV(L * NK, 32);
    // end of thread independant code

    int mi = blockIdx.y * 32 + threadIdx.y;
    int m = mi / NI;
    int i = mi % NI;
    int nj = blockIdx.x * 32 + threadIdx.x;
    int n = nj / NJ;
    int j = nj % NJ;

    float sum_re = 0.0;
    float sum_im = 0.0;

    for (int tile = 0; tile < ntile; ++tile) {
        __shared__ float tileX[2][32][32];
        __shared__ float tileY[2][32][32];

        int pk = tile * 32 + threadIdx.x;
        int p = pk / NK;
        int k = pk % NK;
        int index = INDEX_X * 2;
        tileX[0][threadIdx.y][threadIdx.x] = m < L && p < L ? in_x[index + 0] : 0.0;
        tileX[1][threadIdx.y][threadIdx.x] = m < L && p < L ? in_x[index + 1] : 0.0;

        pk = tile * 32 + threadIdx.y;
        p = pk / NK;
        k = pk % NK;
        index = INDEX_Y * 2;
        tileY[0][threadIdx.y][threadIdx.x] = p < L && n < L ? in_y[index + 0] : 0.0;
        tileY[1][threadIdx.y][threadIdx.x] = p < L && n < L ? in_y[index + 1] : 0.0;

        __syncthreads();

        for (int any = 0; any < 32; ++any) {
            float x_re = tileX[0][threadIdx.y][any];
            float x_im = tileX[1][threadIdx.y][any];
            float y_re = tileY[0][any][threadIdx.x];
            float y_im = tileY[1][any][threadIdx.x];

            CONJ_X
            CONJ_Y

            sum_re += x_re * y_re - x_im * y_im;
            sum_im += x_re * y_im + x_im * y_re;
        }

        __syncthreads();
    }

    if (m < L && n < L) {
        int index = INDEX_OUT * 2;
        out[index + 0] = sum_re;
        out[index + 1] = sum_im;
    }
}
'''
    import s2cnn.utils.cuda as cuda_utils
    kernel = cuda_utils.compile_kernel(kernel, b'so3_mm.cu', 'main_')

    # x / y grid extents are sized for the largest degree (L = 2 * nl - 1).
    grid = (math.ceil((2 * nl - 1) * nj / 32), math.ceil((2 * nl - 1) * ni / 32), nl)

    def fun(x, y, output):
        # The kernel addresses all three buffers through raw pointers.
        assert output.is_contiguous()
        # Locals keep any contiguous copies alive for the whole launch.
        x = x.contiguous()
        y = y.contiguous()
        # Resolve the stream at launch time so the kernel is enqueued on
        # whatever stream is current for the caller, not the one that was
        # current when the kernel was compiled.
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        kernel(block=(32, 32, 1),
               grid=grid,
               args=[x.data_ptr(), y.data_ptr(), output.data_ptr()],
               stream=stream)

    return fun