def backward(self, gradz):  # pylint: disable=W
    # gradz: [l * m * n, batch, feature_out, complex]
    # x: [l * m, batch, feature_in, complex]; y: [l * m, feature_in, feature_out, complex]
    x, y = self.saved_tensors

    nl = round(x.size(0) ** 0.5)
    nbatch = x.size(1)
    nfeature_in = x.size(2)
    nfeature_out = y.size(2)
    nspec = (4 * nl ** 2 - 1) * nl // 3

    gradx_cuda_kernel = _setup_s2mm_gradx_cuda_kernel(nbatch=nbatch, nspec=nspec, nl=nl,
                                                      nfeature_in=nfeature_in, nfeature_out=nfeature_out)
    grady_cuda_kernel = _setup_s2mm_grady_cuda_kernel(nbatch=nbatch, nspec=nspec, nl=nl,
                                                      nfeature_in=nfeature_in, nfeature_out=nfeature_out)

    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)

    gradx = grady = None

    if self.needs_input_grad[0]:
        gradx = gradz.new_empty((nl ** 2, nbatch, nfeature_in, 2))
        gradx_cuda_kernel(block=(cuda_utils.CUDA_NUM_THREADS, 1, 1),
                          grid=(cuda_utils.get_blocks(nl ** 2 * nbatch * nfeature_in, 1024), 1, 1),
                          args=[gradz.contiguous().data_ptr(), y.contiguous().data_ptr(), gradx.data_ptr()],
                          stream=stream)

    if self.needs_input_grad[1]:
        grady = gradz.new_empty((nl ** 2, nfeature_in, nfeature_out, 2))
        grady_cuda_kernel(block=(cuda_utils.CUDA_NUM_THREADS, 1, 1),
                          grid=(cuda_utils.get_blocks(nl ** 2 * nfeature_in * nfeature_out, 1024), 1, 1),
                          args=[gradz.contiguous().data_ptr(), x.contiguous().data_ptr(), grady.data_ptr()],
                          stream=stream)

    return gradx, grady

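# Note on the nspec closed form used above: the [l * m * n, ...] output layout
# stores one (2l+1) x (2l+1) block per degree l < nl, so
# nspec = sum_{l<nl} (2l+1)**2 = nl * (4 * nl**2 - 1) / 3.
# A quick standalone sanity check of that identity:
assert all(sum((2 * l + 1) ** 2 for l in range(nl)) == (4 * nl ** 2 - 1) * nl // 3
           for nl in range(1, 16))
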
def _s2_ifft(x, for_grad, b_in, b_out):
    '''
    :param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    :return: [batch, beta, alpha, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad,
                           device_type=x.device.type, device_index=x.device.index)  # [beta, l * m] (2 * b_out, nspec)

    cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch)
    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
    cuda_kernel(block=(1024, 1, 1),
                grid=(cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024), 1, 1),
                args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                stream=stream)  # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)

    # torch.ifft divides by the transform length; multiply it back to get the
    # unnormalized synthesis over alpha.
    output = torch.ifft(output, 1) * output.size(-2)  # [batch, beta, alpha, complex]

    return output

def _s2_fft(x, for_grad, b_in, b_out):
    '''
    :param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    :return: [l * m, batch, complex] (b_out**2, nbatch, 2)
    '''
    nspec = b_out ** 2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad,
                           device_type=x.device.type, device_index=x.device.index)  # [beta, l * m] (2 * b_in, nspec)

    cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch)

    x = torch.fft(x, 1)  # [batch, beta, m, complex]

    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    output = x.new_empty((nspec, nbatch, 2))
    cuda_kernel(block=(1024, 1, 1),
                grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                args=[x.contiguous().data_ptr(), wigner.contiguous().data_ptr(), output.data_ptr()],
                stream=stream)  # [l * m, batch, complex]

    return output

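# Sketch (hedged) of the launch arithmetic used by the kernels in this file:
# one thread per output element, in blocks of 1024 threads, so
# cuda_utils.get_blocks is presumably a round-up division like this
# hypothetical reference:
def _get_blocks_ref(n, threads_per_block=1024):
    return (n + threads_per_block - 1) // threads_per_block

assert _get_blocks_ref(1024) == 1 and _get_blocks_ref(1025) == 2
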
def s2_mm(x, y):
    '''
    :param x: [l * m,     batch,      feature_in,  complex]
    :param y: [l * m,     feature_in, feature_out, complex]
    :return:  [l * m * n, batch,      feature_out, complex]
    '''
    assert x.is_cuda and x.dtype == torch.float32
    assert y.is_cuda and y.dtype == torch.float32
    assert x.size(3) == 2
    assert y.size(3) == 2
    nbatch = x.size(1)
    nfeature_in = x.size(2)
    nfeature_out = y.size(2)
    assert y.size(1) == nfeature_in
    assert y.size(0) == x.size(0)
    nl = round(x.size(0) ** 0.5)
    nspec = (4 * nl ** 2 - 1) * nl // 3
    assert x.size(0) == nl ** 2
    assert y.size(0) == nl ** 2

    cuda_kernel = _setup_s2mm_cuda_kernel(nbatch=nbatch, nspec=nspec,
                                          nfeature_in=nfeature_in, nfeature_out=nfeature_out)

    stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
    output = x.new_empty((nspec, nbatch, nfeature_out, 2))
    cuda_kernel(block=(cuda_utils.CUDA_NUM_THREADS, 1, 1),
                grid=(cuda_utils.get_blocks(nspec * nbatch * nfeature_out, 1024), 1, 1),
                args=[x.contiguous().data_ptr(), y.contiguous().data_ptr(), output.data_ptr()],
                stream=stream)  # [l * m * n, batch, feature_out, complex]

    return output

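# Example (hedged): expected shapes through s2_mm, assuming a CUDA device.
# Per degree l the complex product of an x block and a y block yields a
# (2l+1) x (2l+1) block per (batch, feature_out), hence the nspec output rows.
if torch.cuda.is_available():
    nl, nbatch, nfin, nfout = 4, 2, 3, 5
    x = torch.randn(nl ** 2, nbatch, nfin, 2, device='cuda')
    y = torch.randn(nl ** 2, nfin, nfout, 2, device='cuda')
    z = s2_mm(x, y)
    assert z.shape == ((4 * nl ** 2 - 1) * nl // 3, nbatch, nfout, 2)
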
def s2_fft(x, for_grad=False, b_out=None):
    '''
    :param x: [..., beta, alpha, complex]
    :return: [l * m, ..., complex]
    '''
    assert x.size(-1) == 2
    b_in = x.size(-2) // 2
    assert x.size(-2) == 2 * b_in
    assert x.size(-3) == 2 * b_in
    if b_out is None:
        b_out = b_in
    assert b_out <= b_in
    batch_size = x.size()[:-3]

    x = x.view(-1, 2 * b_in, 2 * b_in, 2)  # [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)

    nspec = b_out ** 2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad, device=x.device)
    wigner = wigner.view(2 * b_in, -1)  # [beta, l * m] (2 * b_in, nspec)

    x = torch.view_as_real(torch.fft.fft(torch.view_as_complex(x)))  # [batch, beta, m, complex]

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch, device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                    args=[x.contiguous().data_ptr(), wigner.contiguous().data_ptr(), output.data_ptr()],
                    stream=stream)  # [l * m, batch, complex]
    else:
        for l in range(b_out):
            s = slice(l ** 2, l ** 2 + 2 * l + 1)
            # gather the 2l+1 orders m = -l..l out of FFT frequency ordering
            xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]), dim=2) if l > 0 else x[:, :, :1]
            output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))

    output = output.view(-1, *batch_size, 2)  # [l * m, ..., complex] (nspec, ..., 2)
    return output

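# Example (hedged): s2_fft on CPU (the einsum branch), mapping a signal on a
# 2b x 2b Driscoll-Healy-style grid to b**2 spectral coefficients.
b = 8
sig = torch.randn(3, 2 * b, 2 * b, 2)  # [batch, beta, alpha, complex]
spec = s2_fft(sig)                     # [l * m, batch, complex]
assert spec.shape == (b ** 2, 3, 2)
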
def s2_ifft(x, for_grad=False, b_out=None):
    '''
    :param x: [l * m, ..., complex]
    :return: [..., beta, alpha, complex]
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round(nspec ** 0.5)
    assert nspec == b_in ** 2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m, batch, complex] (nspec, nbatch, 2)

    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024), 1, 1),
                    args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                    stream=stream)  # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            s = slice(l ** 2, l ** 2 + 2 * l + 1)
            out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
            # scatter the 2l+1 orders m = -l..l back into FFT frequency ordering
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                output[:, :, -l:] += out[:, :, :l]

    # torch.fft.ifft divides by the transform length; multiply it back to get
    # the unnormalized synthesis over alpha.
    output = torch.view_as_real(torch.fft.ifft(torch.view_as_complex(output))) * output.size(-2)  # [batch, beta, alpha, complex]

    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
    return output

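# Example (hedged): s2_ifft expands nspec = b**2 coefficients back to a
# 2b x 2b grid, and, assuming the weighted wigner matrices implement an exact
# quadrature for band-limited signals (not something this file guarantees),
# the roundtrip should recover the input up to float32 rounding.
b = 8
spec = torch.randn(b ** 2, 1, 2)
grid = s2_ifft(spec)  # [batch, beta, alpha, complex]
assert grid.shape == (1, 2 * b, 2 * b, 2)
assert torch.allclose(s2_fft(grid), spec, atol=1e-4)
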
def _s2_fft(x, for_grad, b_in, b_out):
    '''
    :param x: [batch, beta, alpha, complex] (nbatch, 2 * b_in, 2 * b_in, 2)
    :return: [l * m, batch, complex] (b_out**2, nbatch, 2)
    '''
    nspec = b_out ** 2
    nbatch = x.size(0)

    wigner = _setup_wigner(b_in, nl=b_out, weighted=not for_grad,
                           device_type=x.device.type, device_index=x.device.index)
    wigner = wigner.view(2 * b_in, -1)  # [beta, l * m] (2 * b_in, nspec)

    x = torch.fft(x, 1)  # [batch, beta, m, complex]

    output = x.new_empty((nspec, nbatch, 2))
    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        device = torch.cuda.current_device()
        cuda_kernel = _setup_s2fft_cuda_kernel(b=b_in, nspec=nspec, nbatch=nbatch, device=device)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nspec * nbatch, 1024), 1, 1),
                    args=[x.contiguous().data_ptr(), wigner.contiguous().data_ptr(), output.data_ptr()],
                    stream=stream)  # [l * m, batch, complex]
    else:
        for l in range(b_out):
            s = slice(l ** 2, l ** 2 + 2 * l + 1)
            # gather the 2l+1 orders m = -l..l out of FFT frequency ordering
            xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]), dim=2) if l > 0 else x[:, :, :1]
            output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))

    return output

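# Illustration: why the CPU branch above gathers x[:, :, -l:] and x[:, :, :l + 1].
# The FFT orders frequencies as [0, 1, ..., N/2 - 1, -N/2, ..., -1] along alpha,
# so for degree l the orders m = -l..l sit in the last l and first l + 1 bins:
N, l = 8, 2
fft_order = list(range(N // 2)) + list(range(-N // 2, 0))  # [0, 1, 2, 3, -4, -3, -2, -1]
gathered = fft_order[-l:] + fft_order[:l + 1]
assert gathered == [-2, -1, 0, 1, 2]  # m = -l..l, in the order the einsum expects
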
def _s2_ifft(x, for_grad, b_in, b_out):
    '''
    :param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    :return: [batch, beta, alpha, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad,
                           device_type=x.device.type, device_index=x.device.index)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        device = torch.cuda.current_device()
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out, nl=b_in, nbatch=nbatch, device=device)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out) ** 2, 1024), 1, 1),
                    args=[x.data_ptr(), wigner.data_ptr(), output.data_ptr()],
                    stream=stream)  # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            s = slice(l ** 2, l ** 2 + 2 * l + 1)
            out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
            # scatter the 2l+1 orders m = -l..l back into FFT frequency ordering
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                output[:, :, -l:] += out[:, :, :l]

    # torch.ifft divides by the transform length; multiply it back to get the
    # unnormalized synthesis over alpha.
    output = torch.ifft(output, 1) * output.size(-2)  # [batch, beta, alpha, complex]

    return output

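# Sketch (hedged): the einsum fallback doubles as a reference implementation
# for the CUDA kernel; on a machine with a GPU (and the pre-1.8 torch.fft API
# this variant targets), the two branches of _s2_ifft should agree up to
# float32 rounding.
if torch.cuda.is_available():
    b = 6
    spec = torch.randn(b ** 2, 2, 2)
    out_cpu = _s2_ifft(spec, for_grad=False, b_in=b, b_out=b)
    out_gpu = _s2_ifft(spec.cuda(), for_grad=False, b_in=b, b_out=b).cpu()
    assert torch.allclose(out_cpu, out_gpu, atol=1e-5)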