def _setup_s2mm_cuda_kernel(nbatch, nspec, nfeature_in, nfeature_out):
    """Build and compile the CUDA kernel for the spectral "s2 mm" product.

    The kernel runs a grid-stride loop in which each thread owns one output
    cell ``(s, i, f_out)`` (spectral index, batch index, output feature) and
    accumulates, over all input features ``f_in``, the complex product
    ``x * conj(y)`` (see the "x times y conjugate" lines below).  All arrays
    are interleaved complex: flat offset ``* 2 + 0`` is the real part and
    ``* 2 + 1`` the imaginary part.

    CUDA-side macros:
      * COMPUTE_LMN inverts a flattened spectral index ``s`` into ``(l, m, n)``
        where ``s = L + (l+m)*(2l+1) + (l+n)`` and ``L = l*(4l^2-1)/3``
        (the cube-root guess is corrected by at most one step).
      * EXTRACT unflattens the thread's linear ``index`` into three indices.
      * CONTRACT1 / CONTRACT2 build flat offsets into arrays indexed by a
        single ``(l, m)`` spectral slot resp. a full ``(l, m, n)`` slot.

    :param nbatch: batch size baked into the kernel source
    :param nspec: total number of (l, m, n) spectral entries
        (presumably ``nl * (4*nl**2 - 1) // 3`` — TODO confirm with callers)
    :param nfeature_in: number of input features
    :param nfeature_out: number of output features
    :return: compiled kernel handle from ``cuda_utils.compile_kernel``
    """
    kernel = Template('''
#define COMPUTE_LMN(s) \
    int l = powf(3.0/4.0 * s, 1.0/3.0) - 0.5; \
    int L = l * (4 * l * l - 1) / 3; \
    int rest = s - L; \
    if (rest >= (2 * l + 1) * (2 * l + 1)) { \
        ++l; \
        L = l * (4 * l * l - 1) / 3; \
        rest = s - L; \
    } \
    int m = rest / (2 * l + 1) - l; \
    int n = rest % (2 * l + 1) - l;

#define EXTRACT(i1, i2, n2, i3, n3) \
    int i1 = index; \
    int i3 = i1 % (n3); i1 /= n3; \
    int i2 = i1 % (n2); i1 /= n2;

#define CONTRACT1(s1, i2, n2, i3, n3) \
    ( ( (l * l + (l + (s1))) * (n2) + (i2) ) * (n3) + (i3) )

#define CONTRACT2(s1, s2, i2, n2, i3, n3) \
    ( ( (L + (l + (s1)) * (2 * l + 1) + (l + (s2))) * (n2) + (i2) ) * (n3) + (i3) )

extern "C"
__global__ void main_(const float* in_x, const float* in_y, float* out)
{
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < ${nspec} * ${nbatch} * ${nfeature_out}; index += blockDim.x * gridDim.x) {
        EXTRACT(s, i, ${nbatch}, f_out, ${nfeature_out})

        // compute s -> (l,m,n)
        COMPUTE_LMN(s)

        float out_re = 0.0;
        float out_im = 0.0;

        for (int f_in = 0; f_in < ${nfeature_in}; ++f_in) {
            float x_re = in_x[CONTRACT1(m, i, ${nbatch}, f_in, ${nfeature_in} ) * 2 + 0];
            float x_im = in_x[CONTRACT1(m, i, ${nbatch}, f_in, ${nfeature_in} ) * 2 + 1];
            float y_re = in_y[CONTRACT1(n, f_in, ${nfeature_in}, f_out, ${nfeature_out}) * 2 + 0];
            float y_im = in_y[CONTRACT1(n, f_in, ${nfeature_in}, f_out, ${nfeature_out}) * 2 + 1];

            // x times y conjugate
            out_re += x_re * y_re + x_im * y_im;
            out_im += x_im * y_re - x_re * y_im;
        }

        out[index * 2 + 0] = out_re;
        out[index * 2 + 1] = out_im;
    }
}
''').substitute({'nbatch': nbatch,
                 'nspec': nspec,
                 'nfeature_in': nfeature_in,
                 'nfeature_out': nfeature_out})

    return cuda_utils.compile_kernel(kernel, b's2mm.cu', 'main_')
def _setup_s2mm_grady_cuda_kernel(nbatch, nspec, nl, nfeature_in, nfeature_out):
    """Build and compile the CUDA kernel computing grad_y for the "s2 mm" op.

    Each thread owns one ``(s, f_in, f_out)`` cell, where ``s`` ranges over
    ``nl * nl`` flattened ``(l, m)`` pairs (``s = l*l + l + m``, inverted by
    COMPUTE_LM), and accumulates over the batch index ``i`` and over
    ``k = -l..l`` the complex product ``conj(grad_z) * x`` (see the
    "conjugate grad_z times x" lines).  Arrays are interleaved complex
    (``* 2 + 0`` real, ``* 2 + 1`` imaginary); CONTRACT1 addresses the
    ``(l, m)``-indexed ``x`` array and CONTRACT2 the full ``(l, m, n)``-indexed
    ``grad_z`` array, using ``L = (4l^2 - 1) * l / 3`` from COMPUTE_LM.

    :param nbatch: batch size baked into the kernel source
    :param nspec: substituted into the template but ``${nspec}`` never appears
        in this kernel source — kept presumably for signature symmetry with
        the sibling setup functions (NOTE(review): confirm it is intentional)
    :param nl: number of l-degrees; output has ``nl*nl`` (l, m) slots
    :param nfeature_in: number of input features
    :param nfeature_out: number of output features
    :return: compiled kernel handle from ``cuda_utils.compile_kernel``
    """
    kernel = Template('''
#define COMPUTE_LM(s) \
    int l = powf(s, 0.5); \
    int L = (4 * l * l - 1) * l / 3; \
    int m = s - l * l - l;

#define EXTRACT(i1, i2, n2, i3, n3) \
    int i1 = index; \
    int i3 = i1 % (n3); i1 /= n3; \
    int i2 = i1 % (n2); i1 /= n2;

#define CONTRACT1(s1, i2, n2, i3, n3) \
    ( ( (l * l + (l + (s1))) * (n2) + (i2) ) * (n3) + (i3) )

#define CONTRACT2(s1, s2, i2, n2, i3, n3) \
    ( ( (L + (l + (s1)) * (2 * l + 1) + (l + (s2))) * (n2) + (i2) ) * (n3) + (i3) )

extern "C"
__global__ void main_(const float* grad_z, const float* x, float* grad_y)
{
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (${nl} * ${nl}) * ${nfeature_in} * ${nfeature_out}; index += blockDim.x * gridDim.x) {
        EXTRACT(s, f_in, ${nfeature_in}, f_out, ${nfeature_out})

        // compute s -> (l,m)
        COMPUTE_LM(s)

        float out_re = 0.0;
        float out_im = 0.0;

        for (int i = 0; i < ${nbatch}; ++i) {
            for (int k = -l; k <= l; ++k) {
                float grad_z_re = grad_z[CONTRACT2(k, m, i, ${nbatch}, f_out, ${nfeature_out}) * 2 + 0];
                float grad_z_im = grad_z[CONTRACT2(k, m, i, ${nbatch}, f_out, ${nfeature_out}) * 2 + 1];

                float x_re = x[CONTRACT1(k, i, ${nbatch}, f_in, ${nfeature_in} ) * 2 + 0];
                float x_im = x[CONTRACT1(k, i, ${nbatch}, f_in, ${nfeature_in} ) * 2 + 1];

                // conjugate grad_z times x
                out_re += grad_z_re * x_re + grad_z_im * x_im;
                out_im += grad_z_re * x_im - grad_z_im * x_re;
            }
        }

        grad_y[index * 2 + 0] = out_re;
        grad_y[index * 2 + 1] = out_im;
    }
}
''').substitute({'nbatch': nbatch,
                 'nspec': nspec,
                 'nl': nl,
                 'nfeature_in': nfeature_in,
                 'nfeature_out': nfeature_out})

    return cuda_utils.compile_kernel(kernel, b's2mm_grady.cu', 'main_')
def _setup_s2ifft_cuda_kernel(b, nl, nbatch):
    """Build and compile the CUDA kernel for the inverse S^2 transform step.

    Each thread (grid-stride loop) owns one output cell indexed by
    ``(i, beta, m)`` with ``beta, m`` in ``[0, 2b)``; it first remaps the FFT
    frequency ``m`` into a signed order ``mm`` (DFT convention: the upper half
    wraps to negative frequencies), then sums over degrees
    ``l = |mm| .. nl-1`` the complex input coefficient at spectral slot
    ``s = l*l + l + mm`` weighted by the real matrix ``wig[beta][s]``
    (presumably sampled Wigner/Legendre values — TODO confirm with the caller
    that fills ``wig``).  Input and output are interleaved complex.

    :param b: bandwidth; spatial grid has ``2b`` values of beta and of m
    :param nl: number of degrees l summed over (``nspec`` becomes ``nl**2``)
    :param nbatch: batch size baked into the kernel source
    :return: compiled kernel handle from ``cuda_utils.compile_kernel``
    """
    kernel = Template('''
extern "C"
__global__ void main_(const float* in, const float* wig, float* out)
{
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < ${nbatch} * 2 * ${b} * 2 * ${b}; index += blockDim.x * gridDim.x) {
        int i = index / (2 * ${b} * 2 * ${b}); // batch index
        int beta = (index / (2 * ${b})) % (2 * ${b});
        int m = index % (2 * ${b});

        // from 0,1,2, 3, 4 or 0,1,2, 3, 4, 5
        // to 0,1,2,-2,-1 or 0,1,2,-3,-2,-1
        int mm = m <= (2 * ${b} - 1) / 2 ? m : m - 2 * ${b};

        float out_re = 0.0;
        float out_im = 0.0;

        for (int l = abs(mm); l < ${nl}; ++l) {
            int s = l * l + (l + mm);

            float in_re = in[(s * ${nbatch} + i) * 2 + 0];
            float in_im = in[(s * ${nbatch} + i) * 2 + 1];
            float w = wig[beta * ${nspec} + s];

            out_re += in_re * w;
            out_im += in_im * w;
        }

        out[index * 2 + 0] = out_re;
        out[index * 2 + 1] = out_im;
    }
}
''').substitute({'b': b, 'nbatch': nbatch, 'nl': nl, 'nspec': nl**2})

    return cuda_utils.compile_kernel(kernel, b's2ifft.cu', 'main_')
def _setup_s2fft_cuda_kernel(b, nspec, nbatch):
    """Build and compile the CUDA kernel for the forward S^2 transform step.

    Each thread (grid-stride loop) owns one output cell ``(s, i)``
    (spectral index, batch index).  COMPUTE_LM inverts ``s = l*l + l + m``
    into ``(l, m)``; the thread then sums the complex input over all ``2b``
    beta samples, reading frequency column ``MOD(m, 2b)`` (negative orders
    wrap around, matching the DFT layout) and weighting by the real matrix
    ``wig[beta][s]``.  Input and output are interleaved complex
    (``* 2 + 0`` real, ``* 2 + 1`` imaginary).

    :param b: bandwidth; the input grid has ``2b x 2b`` (beta, m) samples
    :param nspec: number of (l, m) spectral slots
        (presumably ``nl**2`` — TODO confirm with callers)
    :param nbatch: batch size baked into the kernel source
    :return: compiled kernel handle from ``cuda_utils.compile_kernel``
    """
    kernel = Template('''
#define COMPUTE_LM(s) \
    int l = powf(s, 0.5); \
    int m = (s - l * l) - l;

#define MOD(i, n) (((i) + (n)) % (n))

extern "C"
__global__ void main_(const float* in, const float* wig, float* out)
{
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < ${nspec} * ${nbatch}; index += blockDim.x * gridDim.x) {
        int i = index % ${nbatch}; // batch index
        int s = index / ${nbatch}; // spectral index

        // compute s -> (l,m)
        COMPUTE_LM(s)

        float out_re = 0.0;
        float out_im = 0.0;

        for (int beta = 0; beta < 2 * ${b}; ++beta) {
            float in_re = in[((i * 2 * ${b} + beta) * 2 * ${b} + MOD(m, 2 * ${b})) * 2 + 0];
            float in_im = in[((i * 2 * ${b} + beta) * 2 * ${b} + MOD(m, 2 * ${b})) * 2 + 1];
            float w = wig[beta * ${nspec} + s];

            out_re += w * in_re;
            out_im += w * in_im;
        }

        out[index * 2 + 0] = out_re;
        out[index * 2 + 1] = out_im;
    }
}
''').substitute({'b': b, 'nbatch': nbatch, 'nspec': nspec})

    return cuda_utils.compile_kernel(kernel, b's2fft.cu', 'main_')
def _setup_so3ifft_cuda_kernel(b_in, b_out, nbatch, real_output):
    """Build the tiled CUDA kernel for the inverse SO(3) transform and return
    a launcher closure.

    Kernel layout: one grid z-slice per ``(m, n)`` order pair (``(2*b_out-1)**2``
    slices); within a slice, 32x32 thread blocks compute, for each
    ``(batch, beta)`` output cell, the sum over degrees ``l >= max(|m|, |n|)``
    of the complex coefficient at flattened slot
    ``lmn = (4l^2-1)l/3 + (l+m)(2l+1) + (l+n)`` weighted by ``wig[beta][lmn]``.
    The l-sum is tiled through ``__shared__`` staging buffers with
    ``__syncthreads()`` barriers (tileB is padded to 32+1 columns,
    presumably to avoid shared-memory bank conflicts — unverified here).
    When ``real_output`` is set, symmetric ``(m, n)`` slices are skipped and
    the conjugate entry at ``(-m, -n)`` is filled in instead.

    The returned closure zeroes ``output`` before launching because the
    REAL_OUT path leaves some (m, n) slices unwritten by the kernel.

    :param b_in: input bandwidth (NSPEC = ``b_in*(4*b_in**2-1)//3`` slots)
    :param b_out: output bandwidth (spatial grid is ``(2*b_out)**3``)
    :param nbatch: batch size baked into the kernel source
    :param real_output: if True, compile with REAL_OUT (conjugate symmetry)
    :return: ``fun(x, wigner, output)`` launching on the current torch stream
    """
    kernel = '''
#define B_IN {}
#define B_OUT {}
#define NSPEC {}
#define NBATCH {}
'''.format(b_in, b_out, b_in * (4 * b_in**2 - 1) // 3, nbatch)

    if real_output:
        kernel += '''
#define REAL_OUT
'''

    kernel += '''
#define MOD(i, n) (((i) + (n)) % (n))
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

extern "C"
__global__ void main_(const float* in, const float* wig, float* out)
{
    int m = (blockIdx.z / (2 * B_OUT - 1)) - (B_OUT - 1);
    int n = (blockIdx.z % (2 * B_OUT - 1)) - (B_OUT - 1);
#ifdef REAL_OUT
    if (n < 0 || (n == 0 && m < 0)) {
        return; // note: this return does not depend on threadIdx
    }
#endif
    int l_min = MAX(abs(m), abs(n));

    int batch = blockIdx.y * 32 + threadIdx.y;

    float sum_re = 0.0;
    float sum_im = 0.0;

    for (int tile = 0; tile < CEIL_DIV(B_IN - l_min, 32); ++tile) {
        __shared__ float tileA[2][32][32];
        __shared__ float tileB[32][32+1];

        int l = l_min + tile * 32 + threadIdx.x;
        int lmn = (4 * l*l - 1) * l / 3 + (l+m) * (2 * l + 1) + (l+n);
        int i = (lmn * NBATCH + batch) * 2;
        tileA[0][threadIdx.y][threadIdx.x] = l < B_IN && batch < NBATCH ? in[i + 0] : 0.0;
        tileA[1][threadIdx.y][threadIdx.x] = l < B_IN && batch < NBATCH ? in[i + 1] : 0.0;

        int beta = blockIdx.x * 32 + threadIdx.y;
        tileB[threadIdx.x][threadIdx.y] = l < B_IN && beta < 2*B_OUT ? wig[beta * NSPEC + lmn] : 0.0;

        __syncthreads();

        for (int l = 0; l < 32; ++l) {
            sum_re += tileA[0][threadIdx.y][l] * tileB[l][threadIdx.x];
            sum_im += tileA[1][threadIdx.y][l] * tileB[l][threadIdx.x];
        }

        __syncthreads();
    }

    int beta = blockIdx.x * 32 + threadIdx.x;
    if (beta < 2*B_OUT && batch < NBATCH) {
        int i = (((batch * 2*B_OUT + beta) * 2*B_OUT + MOD(m, 2*B_OUT)) * 2*B_OUT + MOD(n, 2*B_OUT)) * 2;
        out[i + 0] = sum_re;
        out[i + 1] = sum_im;
#ifdef REAL_OUT
        i = (((batch * 2*B_OUT + beta) * 2*B_OUT + MOD(-m, 2*B_OUT)) * 2*B_OUT + MOD(-n, 2*B_OUT)) * 2;
        out[i + 0] = sum_re;
        out[i + 1] = -sum_im;
#endif
    }
}
'''
    kernel = cuda_utils.compile_kernel(kernel, b'so3ifft.cu', 'main_')
    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)

    def fun(x, wigner, output):
        # Zero first: with REAL_OUT the kernel skips half the (m, n) slices.
        output[:] = 0
        kernel(block=(32, 32, 1),
               grid=(math.ceil(2 * b_out / 32), math.ceil(nbatch / 32), (2 * b_out - 1)**2),
               args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
               stream=stream)

    return fun
def _setup_so3fft_cuda_kernel(b_in, b_out, nbatch, real_input):
    """Build the tiled CUDA kernel for the forward SO(3) transform and return
    a launcher closure.

    Kernel layout (documented in the source below): one grid z-slice per
    ``(m, n)`` order pair; threadIdx.x indexes degree ``l`` and threadIdx.y
    the batch.  Each thread accumulates over all ``2*b_in`` beta samples the
    complex input at frequency columns ``(MOD(m, 2b), MOD(n, 2b))`` weighted
    by ``wig[beta][lmn]``, tiling the beta-sum through ``__shared__`` buffers
    with ``__syncthreads()`` barriers.  Blocks entirely below
    ``l_min = max(|m|, |n|)`` exit early.  With REAL_IN the input is the
    half-spectrum FFTW real-data layout ``(..., B_IN + 1, 2)`` (see the FFTW
    link in the source); symmetric ``(m, n)`` slices are skipped and the
    mirrored coefficient at ``(-m, -n)`` is written with a ``(-1)^(m-n)``
    fudge factor and conjugation instead.

    Unlike the so3ifft launcher, ``fun`` does not zero ``output`` first —
    presumably every lmn slot is written by some block (TODO confirm).

    :param b_in: input bandwidth (spatial grid is ``(2*b_in)``-sampled)
    :param b_out: output bandwidth (NSPEC = ``b_out*(4*b_out**2-1)//3`` slots)
    :param nbatch: batch size baked into the kernel source
    :param real_input: if True, compile with REAL_IN (half-spectrum input)
    :return: ``fun(x, wigner, output)`` launching on the current torch stream
    """
    kernel = '''
#define B_IN {}
#define B_OUT {}
#define NSPEC {}
#define NBATCH {}
'''.format(b_in, b_out, b_out * (4 * b_out**2 - 1) // 3, nbatch)

    if real_input:
        kernel += '''
#define REAL_IN
'''

    kernel += '''
#define MOD(i, n) (((i) + (n)) % (n))
#define MAX(x, y) ((x) < (y) ? (y) : (x))
#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

extern "C"
__global__ void main_(const float* in, const float* wig, float* out)
{
    // blockIdx = (l, batch, mn)
    // blockDim = (32, 32, 1)
    // threadIdx = (sub l, sub batch, 0)
    // gridDim = (b / 32, nbatch / 32, (2b-1)**2)
    int m = (blockIdx.z / (2 * B_OUT - 1)) - (B_OUT - 1);
    int n = (blockIdx.z % (2 * B_OUT - 1)) - (B_OUT - 1);
    int l_min = MAX(abs(m), abs(n));
    if (blockIdx.x * 32 + 31 < l_min) { // for blocks fully out of l-range
        return; // note: this return does not depend on threadIdx
    }
#ifdef REAL_IN
    if (n < 0 || (n == 0 && m < 0)) {
        return; // note: this return does not depend on threadIdx
    }
#endif

    int batch = blockIdx.y * 32 + threadIdx.y;
    int l = blockIdx.x * 32 + threadIdx.x;
    int lmn = (4 * l*l - 1) * l / 3 + (l+m) * (2 * l + 1) + (l+n);

    float sum_re = 0.0;
    float sum_im = 0.0;

    for (int tile = 0; tile < CEIL_DIV(2 * B_IN, 32); ++tile) {
        __shared__ float tileA[32][32][2];
        __shared__ float tileB[32][32];

        int beta = tile * 32 + threadIdx.x;
#ifdef REAL_IN
        // `in` shape is (NBATCH, 2 * B_IN, 2 * B_IN, B_IN + 1, 2)
        // http://www.fftw.org/fftw3_doc/Multi_002dDimensional-DFTs-of-Real-Data.html
        int i = (((batch * 2*B_IN + beta) * 2*B_IN + MOD(m, 2*B_IN)) * (B_IN + 1) + n) * 2;
#else
        int i = (((batch * 2*B_IN + beta) * 2*B_IN + MOD(m, 2*B_IN)) * 2*B_IN + MOD(n, 2*B_IN)) * 2;
#endif
        tileA[threadIdx.y][threadIdx.x][0] = beta < 2*B_IN && batch < NBATCH ? in[i + 0] : 0.0;
        tileA[threadIdx.y][threadIdx.x][1] = beta < 2*B_IN && batch < NBATCH ? in[i + 1] : 0.0;

        beta = tile * 32 + threadIdx.y;
        tileB[threadIdx.y][threadIdx.x] = beta < 2*B_IN && l_min <= l && l < B_OUT ? wig[beta * NSPEC + lmn] : 0.0;

        __syncthreads();

        for (int beta = 0; beta < 32; ++beta) {
            sum_re += tileA[threadIdx.y][beta][0] * tileB[beta][threadIdx.x];
            sum_im += tileA[threadIdx.y][beta][1] * tileB[beta][threadIdx.x];
        }

        __syncthreads();
    }

    // About this if: some blocks are used to compute but not to save the results
    if (l_min <= l && l < B_OUT && batch < NBATCH) {
        out[(lmn * NBATCH + batch) * 2 + 0] = sum_re;
        out[(lmn * NBATCH + batch) * 2 + 1] = sum_im;
#ifdef REAL_IN
        lmn = (4 * l*l - 1) * l / 3 + (l-m) * (2 * l + 1) + (l-n);
        float fudge = (m - n) % 2 == 0 ? 1.0 : -1.0;
        out[(lmn * NBATCH + batch) * 2 + 0] = fudge * sum_re;
        out[(lmn * NBATCH + batch) * 2 + 1] = -fudge * sum_im;
#endif
    }
}
'''
    kernel = cuda_utils.compile_kernel(kernel, b'so3fft.cu', 'main_')
    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)

    def fun(x, wigner, output):
        kernel(block=(32, 32, 1),
               grid=(math.ceil(b_out / 32), math.ceil(nbatch / 32), (2 * b_out - 1)**2),
               args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
               stream=stream)

    return fun