def __init__(self, ctx, queue, par, kwidth=3, overgridfactor=2,
             fft_dim=(1, 2), klength=200, DTYPE=np.complex64,
             DTYPE_real=np.float32):
    """Set up device buffers, the FFT plan and the OpenCL program for a NUFFT.

    Parameters
    ----------
    ctx : pyopencl.Context
        OpenCL context in which all buffers and the program are created.
    queue : pyopencl.CommandQueue
        Command queue used for host-to-device transfers and the FFT plan.
    par : dict
        Reconstruction parameters. The keys read here are ``"NScan"``,
        ``"NC"``, ``"NSlice"``, ``"N"``, ``"traj"``, ``"dcf"`` and
        ``"Nproj"``.
    kwidth : int, optional
        Kernel width passed to ``calckbkernel``; half of it is stored in
        ``self.kwidth``.
    overgridfactor : int or float, optional
        Oversampling factor passed to ``calckbkernel`` and stored.
    fft_dim : tuple of int, optional
        Axes transformed by the FFT. ``fft_dim[0]`` also selects the first
        axis of ``fft_shape`` included in the FFT scaling factor.
    klength : int, optional
        Kernel table length passed to ``calckbkernel``.
    DTYPE : numpy dtype, optional
        Complex dtype of the FFT work array.
    DTYPE_real : numpy dtype, optional
        Real dtype of the kernel table, deapodization array, sign vector
        and FFT scale.
    """
    print("Setting up PyOpenCL NUFFT.")
    self.DTYPE = DTYPE
    self.DTYPE_real = DTYPE_real
    # Leading axis flattens scans, coils and slices; each entry is an
    # N x N plane.
    self.fft_shape = (par["NScan"] * par["NC"] * par["NSlice"],
                      par["N"], par["N"])
    self.traj = par["traj"]
    self.dcf = par["dcf"]
    self.Nproj = par["Nproj"]
    self.ctx = ctx
    self.queue = queue

    self.overgridfactor = overgridfactor
    # NOTE(review): judging by the helper's name this is a Kaiser-Bessel
    # kernel table and its Fourier transform -- confirm in calckbkernel.
    self.kerneltable, self.kerneltable_FT, self.u = calckbkernel(
        kwidth, overgridfactor, par["N"], klength)
    self.kernelpoints = self.kerneltable.size

    # Square root of the number of elements from axis fft_dim[0] onward.
    self.fft_scale = DTYPE_real(
        np.sqrt(np.prod(self.fft_shape[fft_dim[0]:])))
    self.deapo = 1 / self.kerneltable_FT.astype(DTYPE_real)
    self.kwidth = kwidth / 2

    # Read-only device copies of the kernel table and deapodization array;
    # COPY_HOST_PTR copies the host data at buffer creation time.
    self.cl_kerneltable = cl.Buffer(
        self.ctx,
        cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR,
        hostbuf=self.kerneltable.astype(DTYPE_real).data)
    self.deapo_cl = cl.Buffer(
        self.ctx,
        cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR,
        hostbuf=self.deapo.data)

    # Replace the host-side density compensation and trajectory with
    # device arrays.
    self.dcf = clarray.to_device(self.queue, self.dcf)
    self.traj = clarray.to_device(self.queue, self.traj)
    self.tmp_fft_array = clarray.empty(
        self.queue, self.fft_shape, dtype=DTYPE)

    # Alternating +1/-1 vector of length N.
    # NOTE(review): presumably used by the kernels as a sign-modulation
    # (fftshift) factor -- confirm against the kernel source.
    self.check = np.ones(par["N"], dtype=DTYPE_real)
    self.check[1::2] = -1
    self.check = clarray.to_device(self.queue, self.check)

    # Number of planes per scan. fft_shape[0] is NScan * NC * NSlice, so
    # integer division is exact and avoids a float round-trip.
    self.par_fft = int(self.fft_shape[0] // par["NScan"])
    # The FFT plan covers one scan's worth of planes of the work array.
    self.fft = FFT(
        ctx, queue,
        self.tmp_fft_array[0:self.par_fft, ...],
        out_array=self.tmp_fft_array[0:self.par_fft, ...],
        axes=fft_dim)

    self.gridsize = par["N"]
    self.fwd_NUFFT = self.NUFFT
    self.adj_NUFFT = self.NUFFTH

    # Read the kernel source inside a context manager so the file handle
    # is closed deterministically instead of waiting for garbage collection.
    kernel_path = resource_filename(
        'rrsg_cgreco', 'kernels/opencl_nufft_kernels.c')
    with open(kernel_path) as kernel_file:
        kernel_source = kernel_file.read()
    self.prg = Program(self.ctx, kernel_source)
import numpy as np
import pyopencl as cl
import pyopencl.array as cla
from gpyfft.fft import FFT

# Example: transform the last two axes of a (4, 1024, 1024) complex64 array
# on an OpenCL device with gpyfft, then copy the result back to the host.

# NOTE: create_some_context() may prompt for a device interactively unless
# the PYOPENCL_CTX environment variable is set.
context = cl.create_some_context()
queue = cl.CommandQueue(context)

data_host = np.zeros((4, 1024, 1024), dtype=np.complex64)
# TODO: fill data_host with real input before transforming; an all-zero
# input simply produces an all-zero spectrum.
data_gpu = cla.to_device(queue, data_host)

# Axes 2 and 1 are transformed; axis 0 is not, so the 4 planes are
# processed as a batch. No out_array is given, so the transform is in place.
transform = FFT(context, queue, data_gpu, axes=(2, 1))

# enqueue() returns a tuple of events. Wait on all of them rather than
# assuming the plan was issued as exactly one launch, which the former
# single-element unpacking required.
events = transform.enqueue()
cl.wait_for_events(events)

# Result lives in data_gpu because the transform ran in place.
result_host = data_gpu.get()