Esempio n. 1
0
class PyOpenCLNUFFT:
    """Non-uniform FFT (gridding) operator built on PyOpenCL kernels.

    ``NUFFT`` (forward) and ``NUFFTH`` (adjoint) both work through the single
    device scratch array ``tmp_fft_array`` (shape ``fft_shape``) and enqueue
    every kernel on ``queue``. The two are exposed as ``fwd_NUFFT`` and
    ``adj_NUFFT`` respectively.
    """

    # Attributes holding device resources, released by ``__del__`` in this
    # order (the order of the original explicit ``del`` statements).
    _DEVICE_ATTRS = ("traj", "dcf", "tmp_fft_array", "cl_kerneltable", "fft",
                     "deapo_cl", "check", "queue", "ctx")

    def __init__(self,
                 ctx,
                 queue,
                 par,
                 kwidth=3,
                 overgridfactor=2,
                 fft_dim=(1, 2),
                 klength=200,
                 DTYPE=np.complex64,
                 DTYPE_real=np.float32):
        """Upload trajectory data, build the FFT plan and compile the kernels.

        Args:
            ctx: PyOpenCL context used for all buffer allocations.
            queue: PyOpenCL command queue every kernel is enqueued on.
            par: Parameter dict. Reads ``"NScan"``, ``"NC"``, ``"NSlice"``,
                ``"N"`` (grid edge length), ``"traj"``, ``"dcf"`` and
                ``"Nproj"``.
            kwidth: Interpolation kernel width passed to ``calckbkernel``;
                half of it is stored as ``self.kwidth``.
            overgridfactor: Factor passed to ``calckbkernel`` and to the
                deapodization kernels.
            fft_dim: Axes of the scratch array the FFT plan transforms. The
                FFT scaling uses the product of the sizes from axis
                ``fft_dim[0]`` onward.
            klength: Kernel lookup-table length passed to ``calckbkernel``.
            DTYPE: Complex dtype of the scratch array.
            DTYPE_real: Real dtype of the kernel table, the deapodization
                factors and the scalar kernel arguments.
        """
        print("Setting up PyOpenCL NUFFT.")
        self.DTYPE = DTYPE
        self.DTYPE_real = DTYPE_real
        self.fft_shape = (par["NScan"] * par["NC"] * par["NSlice"], par["N"],
                          par["N"])
        self.traj = par["traj"]
        self.dcf = par["dcf"]
        self.Nproj = par["Nproj"]
        self.ctx = ctx
        self.queue = queue

        self.overgridfactor = overgridfactor
        self.kerneltable, self.kerneltable_FT, self.u = calckbkernel(
            kwidth, overgridfactor, par["N"], klength)
        self.kernelpoints = self.kerneltable.size
        self.fft_scale = DTYPE_real(
            np.sqrt(np.prod(self.fft_shape[fft_dim[0]:])))
        self.deapo = 1 / self.kerneltable_FT.astype(DTYPE_real)
        self.kwidth = kwidth / 2
        self.cl_kerneltable = cl.Buffer(
            self.ctx,
            cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR,
            hostbuf=self.kerneltable.astype(DTYPE_real).data)
        self.deapo_cl = cl.Buffer(self.ctx,
                                  cl.mem_flags.READ_ONLY
                                  | cl.mem_flags.COPY_HOST_PTR,
                                  hostbuf=self.deapo.data)
        self.dcf = clarray.to_device(self.queue, self.dcf)
        self.traj = clarray.to_device(self.queue, self.traj)
        self.tmp_fft_array = (clarray.empty(self.queue, (self.fft_shape),
                                            dtype=DTYPE))
        # Alternating +1/-1 vector consumed by the ``fftshift`` kernel.
        self.check = np.ones(par["N"], dtype=DTYPE_real)
        self.check[1::2] = -1
        self.check = clarray.to_device(self.queue, self.check)
        # Number of leading-axis entries per scan (NC * NSlice); the first axis
        # is an exact multiple of NScan, so integer division is exact.
        self.par_fft = int(self.fft_shape[0] // par["NScan"])
        # The plan is sized for one scan's worth of the scratch array; it is
        # then re-applied to each scan chunk in ``_shifted_fft``.
        self.fft = FFT(ctx,
                       queue,
                       self.tmp_fft_array[0:self.par_fft, ...],
                       out_array=self.tmp_fft_array[0:self.par_fft, ...],
                       axes=fft_dim)
        self.gridsize = par["N"]
        self.fwd_NUFFT = self.NUFFT
        self.adj_NUFFT = self.NUFFTH
        kernel_path = resource_filename('rrsg_cgreco',
                                        'kernels/opencl_nufft_kernels.c')
        # Context manager so the kernel source file handle is closed promptly.
        with open(kernel_path) as kernel_file:
            kernel_source = kernel_file.read()
        self.prg = Program(self.ctx, kernel_source)

    def __del__(self):
        """Drop the references to device resources held by this instance.

        ``dict.pop`` with a default is used so that an instance whose
        ``__init__`` raised part-way (or whose ``__del__`` already ran) does
        not raise ``AttributeError`` during teardown.
        """
        for name in self._DEVICE_ATTRS:
            self.__dict__.pop(name, None)

    def _zero_tmp(self, s, sg, wait_for):
        """Enqueue zeroing of ``tmp_fft_array`` after all pending events."""
        self.tmp_fft_array.add_event(
            self.prg.zero_tmp(self.queue, (self.tmp_fft_array.size, ),
                              None,
                              self.tmp_fft_array.data,
                              wait_for=(s.events + sg.events +
                                        self.tmp_fft_array.events + wait_for)))

    def _fftshift(self):
        """Enqueue the ``fftshift`` kernel over the whole scratch array."""
        self.tmp_fft_array.add_event(
            self.prg.fftshift(self.queue, self.fft_shape, None,
                              self.tmp_fft_array.data, self.check.data))

    def _shifted_fft(self, nbatch, forward):
        """Enqueue fftshift, one in-buffer FFT per scan chunk, then fftshift."""
        self._fftshift()
        for j in range(nbatch):
            lo = j * self.par_fft
            hi = (j + 1) * self.par_fft
            self.tmp_fft_array.add_event(
                self.fft.enqueue_arrays(
                    data=self.tmp_fft_array[lo:hi, ...],
                    result=self.tmp_fft_array[lo:hi, ...],
                    forward=forward)[0])
        self._fftshift()

    def NUFFTH(self, sg, s, wait_for=None):
        """Adjoint NUFFT: grid the k-space samples ``s`` into ``sg``.

        Steps: zero the scratch grid, grid ``s`` onto it using the trajectory,
        density compensation and kernel lookup table, inverse FFT
        (``forward=False``) with fftshifts, then deapodize and scale into
        ``sg``.

        Args:
            sg: Output device array; indexed as 5-D (``sg.shape[0..4]``).
            s: Input device array; ``s.shape[0]`` is the number of per-scan
                chunks transformed. Assumed at least 3-D — TODO confirm layout.
            wait_for: Optional sequence of events to wait on before starting.

        Returns:
            The event returned by the final ``deapo_adj`` kernel launch.
        """
        wait_for = [] if wait_for is None else list(wait_for)
        self._zero_tmp(s, sg, wait_for)
        # Grid k-space
        self.tmp_fft_array.add_event(
            self.prg.grid_lut(self.queue, (s.shape[0], s.shape[1] * s.shape[2],
                                           s.shape[-2] * self.gridsize),
                              None,
                              self.tmp_fft_array.data,
                              s.data,
                              self.traj.data,
                              np.int32(self.gridsize),
                              self.DTYPE_real(self.kwidth / self.gridsize),
                              self.dcf.data,
                              self.cl_kerneltable,
                              np.int32(self.kernelpoints),
                              wait_for=(wait_for + sg.events + s.events +
                                        self.tmp_fft_array.events)))
        # Inverse FFT
        self._shifted_fft(s.shape[0], forward=False)
        # Deapodization and scaling into the output
        return self.prg.deapo_adj(self.queue,
                                  (sg.shape[0] * sg.shape[1] * sg.shape[2],
                                   sg.shape[3], sg.shape[4]),
                                  None,
                                  sg.data,
                                  self.tmp_fft_array.data,
                                  self.deapo_cl,
                                  np.int32(self.tmp_fft_array.shape[-1]),
                                  self.DTYPE_real(self.fft_scale),
                                  self.DTYPE_real(self.overgridfactor),
                                  wait_for=wait_for + sg.events + s.events +
                                  self.tmp_fft_array.events)

    def NUFFT(self, s, sg, wait_for=None):
        """Forward NUFFT: resample the image ``sg`` onto the k-space ``s``.

        Steps: zero the scratch grid, deapodize and scale ``sg`` into it,
        forward FFT (``forward=True``) with fftshifts, then interpolate onto
        the trajectory into ``s``.

        Args:
            s: Output device array; ``s.shape[0]`` is the number of per-scan
                chunks transformed. Assumed at least 3-D — TODO confirm layout.
            sg: Input device array; indexed as 5-D (``sg.shape[0..4]``).
            wait_for: Optional sequence of events to wait on before starting.

        Returns:
            The event returned by the final ``invgrid_lut`` kernel launch.
        """
        wait_for = [] if wait_for is None else list(wait_for)
        self._zero_tmp(s, sg, wait_for)
        # Deapodization and Scaling
        self.tmp_fft_array.add_event(
            self.prg.deapo_fwd(
                self.queue, (sg.shape[0] * sg.shape[1] * sg.shape[2],
                             sg.shape[3], sg.shape[4]),
                None,
                self.tmp_fft_array.data,
                sg.data,
                self.deapo_cl,
                np.int32(self.tmp_fft_array.shape[-1]),
                self.DTYPE_real(1 / self.fft_scale),
                self.DTYPE_real(self.overgridfactor),
                wait_for=wait_for + sg.events + self.tmp_fft_array.events))
        # Forward FFT
        self._shifted_fft(s.shape[0], forward=True)
        # Resample on Spoke
        return self.prg.invgrid_lut(
            self.queue,
            (s.shape[0], s.shape[1] * s.shape[2], s.shape[-2] * self.gridsize),
            None,
            s.data,
            self.tmp_fft_array.data,
            self.traj.data,
            np.int32(self.gridsize),
            self.DTYPE_real(self.kwidth / self.gridsize),
            self.dcf.data,
            self.cl_kerneltable,
            np.int32(self.kernelpoints),
            wait_for=s.events + wait_for + self.tmp_fft_array.events)