def nufft_adjoint(input, coord, oshape, oversamp=1.25, width=4.0, n=128, device='cuda'):
    """Adjoint NUFFT: grid samples, inverse FFT, crop, rescale and apodize.

    NOTE(review): this module defines ``nufft_adjoint`` several times; only
    the last definition is bound at import time, so this version is shadowed.

    Args:
        input: non-uniform samples, passed unchanged to ``interp.gridding``.
        coord: sample coordinates; ``coord.shape[-1]`` is used as ``ndim``.
        oshape: requested output shape (converted to a list).
        oversamp: grid oversampling factor.
        width: interpolation kernel width.
        n: first argument to ``_get_kaiser_bessel_kernel`` (presumably the
            kernel table size -- confirm).
        device: torch device handed to every helper.

    Returns:
        The apodized result after ``permute(0, 2, 3, 1)``. Assumes the gridded
        tensor is 4-D with a size-2 dim 1 (presumably real/imag) -- TODO confirm.
    """
    ndim = coord.shape[-1]
    # beta is shared by the kernel and the apodization step.
    beta = numpy.pi * (((width / oversamp) * (oversamp - 0.5)) ** 2 - 0.8) ** 0.5
    oshape = list(oshape)
    os_shape = _get_oversamp_shape(oshape, ndim, oversamp)

    # Gridding onto the oversampled grid.
    coord = _scale_coord(coord, oshape, oversamp, device)
    kernel = _get_kaiser_bessel_kernel(n, width, beta, coord.dtype, device)
    output = interp.gridding(input, os_shape, width, kernel, coord, device)

    # IFFT: transforms.ifft2 is fed a tensor with dim 1 moved last,
    # then the layout is restored.
    output = output.permute(0, 2, 3, 1)
    output = transforms.ifft2(output)
    output = output.permute(0, 3, 1, 2)

    # Crop back to the requested shape.
    output = util.resize(output, oshape, device=device)
    # Parentheses spell out the existing precedence: prod(os) / sqrt(prod(o)).
    # Out-of-place multiply (as in the other definitions) so a tensor that
    # util.resize may return as a view is not mutated in place.
    scale = util.prod(os_shape[-ndim:]) / (util.prod(oshape[-ndim:]) ** 0.5)
    output = output * scale

    # Apodize
    output = _apodize(output, ndim, oversamp, width, beta, device)
    return output.permute(0, 2, 3, 1)
def nufft_adjoint(input, coord, out_shape, oversamp=1.25, width=4.0, n=128, device='cuda'):
    """Adjoint NUFFT over 5-D tensors using ``transforms.ifft2_regular``.

    NOTE(review): ``nufft_adjoint`` is redefined later in this module, so this
    version is shadowed at import time.

    Args:
        input: non-uniform samples, passed unchanged to ``interp.gridding``.
        coord: sample coordinates; ``coord.shape[-1]`` is used as ``ndim``.
        out_shape: requested output shape (converted to a list).
        oversamp: grid oversampling factor.
        width: interpolation kernel width.
        n: first argument to ``_get_kaiser_bessel_kernel``.
        device: torch device handed to every helper.

    Returns:
        The cropped, rescaled and apodized tensor.
    """
    ndim = coord.shape[-1]
    ratio = width / oversamp
    beta = numpy.pi * ((ratio * (oversamp - 0.5)) ** 2 - 0.8) ** 0.5
    out_shape = list(out_shape)
    os_shape = _get_oversamp_shape(out_shape, ndim, oversamp)

    # Working copies of both shapes; these same objects are reused below.
    crop_shape = out_shape.copy()
    grid_shape = os_shape.copy()
    coord = _scale_coord(coord, crop_shape, oversamp, device)
    kernel = _get_kaiser_bessel_kernel(n, width, beta, coord.dtype, device)
    grid = interp.gridding(input, grid_shape, width, kernel, coord, device)

    # Dim 2 is moved last for the transform and moved back afterwards.
    image = transforms.ifft2_regular(grid.permute(0, 1, 3, 4, 2)).permute(0, 1, 4, 2, 3)

    # Crop, then rescale by prod(oversampled) / sqrt(prod(output)).
    image = util.resize(image, crop_shape, device=device)
    scale = util.prod(grid_shape[-ndim:]) / util.prod(crop_shape[-ndim:]) ** 0.5
    image = image * scale

    return _apodize(image, ndim, oversamp, width, beta, device)
def gridding(input, shape, width, kernel, coord, device):
    """Grid non-uniform samples onto a zero-initialised tensor of ``shape``.

    The leading (non-spatial) dims of ``shape`` and the point dims of
    ``coord`` are each flattened into one axis before ``_gridding2`` fills
    the grid; the result is reshaped back to ``shape``.

    NOTE(review): always dispatches to ``_gridding2`` regardless of ``ndim``
    (presumably a 2-D kernel -- confirm behaviour for ``ndim == 3``).
    """
    ndim = coord.shape[-1]
    n_batch = util.prod(shape[:-ndim])
    n_points = util.prod(coord.shape[:-1])

    samples = input.reshape([n_batch, n_points])
    points = coord.reshape([n_points, ndim])
    grid = torch.zeros([n_batch, *shape[-ndim:]], dtype=samples.dtype, device=device)

    return _gridding2(grid, samples, width, kernel, points).reshape(shape)
def interpolate(input, width, kernel, coord, device):
    """Sample a gridded tensor at the points in ``coord``.

    Leading (non-spatial) dims of ``input`` and the point dims of ``coord``
    are each flattened; ``_interpolate2`` fills a zero tensor of shape
    ``(batch, points)``, which is reshaped to ``batch_shape + pts_shape``.
    """
    ndim = coord.shape[-1]
    batch_shape = input.shape[:-ndim]
    n_batch = util.prod(batch_shape)
    pts_shape = coord.shape[:-1]
    n_points = util.prod(pts_shape)

    grid = input.reshape([n_batch, *input.shape[-ndim:]])
    points = coord.reshape([n_points, ndim])
    samples = torch.zeros([n_batch, n_points], dtype=grid.dtype, device=device)

    samples = _interpolate2(samples, grid, width, kernel, points)
    return samples.reshape(batch_shape + pts_shape)
def nufft(input, coord, ndim=2, oversamp=1.25, width=4.0, n=128, device='cuda'):
    """Forward NUFFT: apodize, normalize, zero-pad, FFT, interpolate at ``coord``.

    Args:
        input: image-domain tensor; its last ``ndim`` dims are transformed.
        coord: sample coordinates, scaled via ``_scale_coord``.
        ndim: number of transformed dims; only 2 and 3 have an FFT here.
            (Taken as a parameter rather than ``coord.shape[-1]``.)
        oversamp: grid oversampling factor.
        width: interpolation kernel width.
        n: first argument to ``_get_kaiser_bessel_kernel``.
        device: torch device handed to every helper.

    Returns:
        Whatever ``interp.interpolate`` returns for the transformed grid.

    Raises:
        ValueError: if ``ndim`` is not 2 or 3. Previously the FFT step was
            silently skipped, so un-transformed data was interpolated.
    """
    if ndim not in (2, 3):
        raise ValueError(f"nufft supports ndim 2 or 3, got {ndim}")

    beta = numpy.pi * (((width / oversamp) * (oversamp - 0.5)) ** 2 - 0.8) ** 0.5
    os_shape = _get_oversamp_shape(input.shape, ndim, oversamp)
    # Presumably guards the caller's tensor against in-place changes by _apodize -- confirm.
    output = input.clone()

    # Apodize
    output = _apodize(output, ndim, oversamp, width, beta, device)

    # Normalize by sqrt(number of image-domain samples), then zero-pad.
    output = output / util.prod(input.shape[-ndim:]) ** 0.5
    output = util.resize(output, os_shape, device=device)

    # FFT
    if ndim == 2:
        output = transforms.rfft2(output)
    else:
        output = transforms.rfft3(output)

    # Interpolate
    coord = _scale_coord(coord, input.shape, oversamp, device)
    kernel = _get_kaiser_bessel_kernel(n, width, beta, coord.dtype, device)
    # NOTE(review): ndim is passed positionally here, but an
    # interpolate(input, width, kernel, coord, device) signature takes no
    # ndim -- confirm which interp.interpolate this module imports.
    return interp.interpolate(output, width, kernel, coord, ndim, device)
def nufft_adjoint(input, coord, oshape, ndim=2, oversamp=1.25, width=4.0, n=128, device='cuda'):
    """Adjoint NUFFT: grid, inverse FFT, crop, rescale, apodize, select channel 0.

    This is the last ``nufft_adjoint`` in the module, so it is the one bound at
    import time.

    Args:
        input: non-uniform samples, passed unchanged to ``interp.gridding``.
        coord: sample coordinates, scaled via ``_scale_coord``.
        oshape: requested output shape (converted to a list). Dim 1 is
            overridden to 2 for the working shapes, presumably real/imag
            channels -- confirm.
        ndim: number of transformed dims; only 2 and 3 have an IFFT here.
        oversamp: grid oversampling factor.
        width: interpolation kernel width.
        n: first argument to ``_get_kaiser_bessel_kernel``.
        device: torch device handed to every helper.

    Returns:
        The apodized tensor indexed at 0 along dim 1.

    Raises:
        ValueError: if ``ndim`` is not 2 or 3. Previously both the IFFT and
            the final channel selection were silently skipped.
    """
    if ndim not in (2, 3):
        raise ValueError(f"nufft_adjoint supports ndim 2 or 3, got {ndim}")

    beta = numpy.pi * (((width / oversamp) * (oversamp - 0.5)) ** 2 - 0.8) ** 0.5
    oshape = list(oshape)
    os_shape = _get_oversamp_shape(oshape, ndim, oversamp)

    # Gridding, with dim 1 of both working shapes forced to size 2.
    oshape2 = oshape.copy()
    oshape2[1] = 2
    os_shape2 = os_shape.copy()
    os_shape2[1] = 2
    coord = _scale_coord(coord, oshape2, oversamp, device)
    kernel = _get_kaiser_bessel_kernel(n, width, beta, coord.dtype, device)
    # NOTE(review): ndim is passed positionally here, but a
    # gridding(input, shape, width, kernel, coord, device) signature takes no
    # ndim -- confirm which interp.gridding this module imports.
    output = interp.gridding(input, os_shape2, width, kernel, coord, ndim, device)

    # IFFT
    if ndim == 2:
        output = transforms.ifft2(output)
    else:
        output = transforms.ifft3(output)

    # Crop, then rescale by prod(oversampled) / sqrt(prod(output)).
    output = util.resize(output, oshape2, device=device)
    scale = util.prod(os_shape[-ndim:]) / (util.prod(oshape[-ndim:]) ** 0.5)
    output = output * scale

    # Apodize
    output = _apodize(output, ndim, oversamp, width, beta, device)
    if ndim == 2:
        return output[:, 0, :, :]
    return output[:, 0, :, :, :]