Exemple #1
0
def _convolve_spatial2(im,
                       hs,
                       mode="constant",
                       grid_dim=None,
                       pad_factor=2,
                       plan=None,
                       return_plan=False):
    """
    spatially varying convolution of a 2d image with a 2d grid of psfs

    shape(im) = (Ny,Nx)
    shape(hs) = (Gy,Gx, Hy,Hx)   (when grid_dim is not given)

    the image is split into (Gy,Gx) blocks and hs[j,i] is the psf that
    belongs to the center of block (i,j)

    if grid_dim is given it is used as (Gy,Gx) and the psfs are copied
    block by block out of hs by the "fill_psf_grid2" kernel
    (presumably hs then has the same shape as im - TODO confirm)

    the block sizes are Ny // Gy and Nx // Gx, so each image dimension is
    expected to be divisible by the corresponding grid dimension

    every block is embedded into a patch with side length
    _next_power_of_2(pad_factor * block size)

    mode can be:
    "constant" - assumed values to be zero
    "wrap" - periodic boundary condition

    plan is an optional fft plan, a new one is created if it is None
    if return_plan is True the tuple (result, plan) is returned,
    otherwise only the result
    """

    Gs = tuple(grid_dim) if grid_dim else hs.shape[:2]

    address_modes = {
        "constant": "CLK_ADDRESS_CLAMP",
        "wrap": "CLK_ADDRESS_REPEAT",
    }

    Ny, Nx = im.shape
    Gy, Gx = Gs

    # extent of a single block of the grid
    Nblock_y = Ny // Gy
    Nblock_x = Nx // Gx

    # extent of the overlapping, padded patches
    Npatch_x = _next_power_of_2(pad_factor * Nblock_x)
    Npatch_y = _next_power_of_2(pad_factor * Nblock_y)

    kernel_file = abspath("kernels/conv_spatial2.cl")
    build_flags = ["-D", "ADDRESSMODE=%s" % address_modes[mode]]
    prog = OCLProgram(kernel_file, build_options=build_flags)

    stack_shape = (Gy, Gx, Npatch_y, Npatch_x)
    patch_size = Npatch_x * Npatch_y

    if plan is None:
        plan = fft_plan(stack_shape, axes=(-2, -1))

    x0s = Nblock_x * np.arange(Gx)
    y0s = Nblock_y * np.arange(Gy)

    patches_g = OCLArray.empty(stack_shape, np.complex64)

    def patch_offset(i, j):
        # flat start index of patch (j, i) inside the patch stack
        return np.int32((i + j * Gx) * patch_size)

    # prepare psfs
    if grid_dim:
        h_g = OCLArray.zeros(stack_shape, np.complex64)
        tmp_g = OCLArray.from_array(hs.astype(np.float32, copy=False))
        psf_shift_x = np.int32(-Nblock_x // 2 + Npatch_x // 2)
        psf_shift_y = np.int32(-Nblock_y // 2 + Npatch_y // 2)
        for i in range(Gx):
            for j in range(Gy):
                prog.run_kernel(
                    "fill_psf_grid2", (Nblock_x, Nblock_y), None,
                    tmp_g.data,
                    np.int32(Nx),
                    np.int32(i * Nblock_x),
                    np.int32(j * Nblock_y),
                    h_g.data,
                    np.int32(Npatch_x),
                    np.int32(Npatch_y),
                    psf_shift_x,
                    psf_shift_y,
                    patch_offset(i, j))
    else:
        padded = pad_to_shape(hs, stack_shape)
        shifted = np.fft.fftshift(padded, axes=(2, 3))
        h_g = OCLArray.from_array(shifted.astype(np.complex64))

    # prepare image
    im_g = OCLImage.from_array(im.astype(np.float32, copy=False))

    # shift from a block origin to the origin of its centered patch
    shift_x = Nblock_x // 2 - Npatch_x // 2
    shift_y = Nblock_y // 2 - Npatch_y // 2

    for i, x0 in enumerate(x0s):
        for j, y0 in enumerate(y0s):
            prog.run_kernel(
                "fill_patch2", (Npatch_x, Npatch_y), None,
                im_g,
                np.int32(x0 + shift_x),
                np.int32(y0 + shift_y),
                patches_g.data,
                patch_offset(i, j))

    # convolution: forward fft of both stacks, multiply, transform back
    fft(patches_g, inplace=True, plan=plan)
    fft(h_g, inplace=True, plan=plan)
    prog.run_kernel("mult_inplace", (patch_size * Gx * Gy, ), None,
                    patches_g.data, h_g.data)
    fft(patches_g, inplace=True, inverse=True, plan=plan)

    logger.debug("Nblock_x: {}, Npatch_x: {}".format(Nblock_x, Npatch_x))

    # accumulate
    res_g = OCLArray.empty(im.shape, np.float32)

    grid_args = (np.int32(Gx), np.int32(Gy),
                 np.int32(Npatch_x), np.int32(Npatch_y))

    # one launch per grid node, there are (Gy + 1) * (Gx + 1) of them
    for j in range(Gy + 1):
        for i in range(Gx + 1):
            prog.run_kernel("interpolate2", (Nblock_x, Nblock_y), None,
                            patches_g.data, res_g.data,
                            np.int32(i), np.int32(j),
                            *grid_args)

    res = res_g.get()

    return (res, plan) if return_plan else res
def convolve_spatial2(im, hs, mode="constant", plan=None, return_plan=False):
    """
    spatial varying convolution of a 2d image with a 2d grid of psfs

    shape(im) = (Ny,Nx)
    shape(hs) = (Gy,Gx, Hy,Hx)

    the input image im is subdivided into (Gy,Gx) blocks
    hs[j,i] is the psf at the center of each block (i,j)

    as of now each image dimension has to be divisible by the grid dim, i.e.
    Nx % Gx == 0
    Ny % Gy == 0

    mode can be:
    "constant" - assumed values to be zero
    "wrap" - periodic boundary condition

    plan is an optional fft plan for patches of shape (Npatch_y,Npatch_x),
    a new one is created via fft_plan if it is None

    returns the result read back from a float32 buffer of shape im.shape,
    or the tuple (result, plan) if return_plan is True

    raises
    ValueError if im is not 2d or hs is not 4d
    NotImplementedError if the image shape is not divisible by (Gy,Gx)
    KeyError if mode is none of the strings listed above
    """

    if im.ndim != 2 or hs.ndim != 4:
        raise ValueError("wrong dimensions of input!")

    if not np.all([n % g == 0 for n, g in zip(im.shape, hs.shape[:2])]):
        raise NotImplementedError(
            "shape of image has to be divisible by Gx Gy  = %s shape mismatch"
            % (str(hs.shape[:2])))

    mode_str = {"constant": "CLK_ADDRESS_CLAMP", "wrap": "CLK_ADDRESS_REPEAT"}

    Ny, Nx = im.shape
    Gy, Gx = hs.shape[:2]

    # the size of each block within the grid
    # (floor division: the sizes are used as kernel work sizes and index
    # offsets and therefore have to stay integers on Python 3 as well)
    Nblock_y, Nblock_x = Ny // Gy, Nx // Gx

    # the size of the overlapping patches with safety padding
    Npatch_x = _next_power_of_2(3 * Nblock_x)
    Npatch_y = _next_power_of_2(3 * Nblock_y)

    # number of elements of a single patch
    Npatch_total = Npatch_x * Npatch_y

    # pad every psf to the patch size and move its center to the origin
    hs = np.fft.fftshift(pad_to_shape(hs, (Gy, Gx, Npatch_y, Npatch_x)),
                         axes=(2, 3))

    prog = OCLProgram(abspath("kernels/conv_spatial.cl"),
                      build_options=["-D",
                                     "ADDRESSMODE=%s" % mode_str[mode]])

    if plan is None:
        plan = fft_plan((Npatch_y, Npatch_x))

    patches_g = OCLArray.empty((Gy, Gx, Npatch_y, Npatch_x), np.complex64)

    h_g = OCLArray.from_array(hs.astype(np.complex64))

    im_g = OCLImage.from_array(im.astype(np.float32, copy=False))

    # origins of the blocks within the image
    x0s = Nblock_x * np.arange(Gx)
    y0s = Nblock_y * np.arange(Gy)

    # shift from a block origin to the origin of its centered patch
    shift_x = Nblock_x // 2 - Npatch_x // 2
    shift_y = Nblock_y // 2 - Npatch_y // 2

    for i, _x0 in enumerate(x0s):
        for j, _y0 in enumerate(y0s):
            prog.run_kernel(
                "fill_patch2", (Npatch_x, Npatch_y), None, im_g,
                np.int32(_x0 + shift_x),
                np.int32(_y0 + shift_y), patches_g.data,
                # flat start index of patch (j, i) inside patches_g
                np.int32((i + j * Gx) * Npatch_total))

    # convolution
    fft(patches_g, inplace=True, batch=Gx * Gy, plan=plan)
    fft(h_g, inplace=True, batch=Gx * Gy, plan=plan)
    prog.run_kernel("mult_inplace", (Npatch_total * Gx * Gy, ), None,
                    patches_g.data, h_g.data)

    fft(patches_g, inplace=True, inverse=True, batch=Gx * Gy, plan=plan)

    # accumulate
    res_g = OCLArray.empty(im.shape, np.float32)

    # one launch per grid node, there are (Gx + 1) * (Gy + 1) of them
    for i in range(Gx + 1):
        for j in range(Gy + 1):
            prog.run_kernel("interpolate2", (Nblock_x, Nblock_y),
                            None, patches_g.data, res_g.data, np.int32(i),
                            np.int32(j), np.int32(Gx), np.int32(Gy),
                            np.int32(Npatch_x), np.int32(Npatch_y))

    res = res_g.get()

    if return_plan:
        return res, plan
    return res
def convolve_spatial3(im,
                      hs,
                      mode="constant",
                      plan=None,
                      return_plan=False,
                      pad_factor=2):
    """
    spatial varying convolution of a 3d image with a 3d grid of psfs

    shape(im) = (Nz,Ny,Nx)
    shape(hs) = (Gz,Gy,Gx, Hz,Hy,Hx)

    the input image im is subdivided into (Gx,Gy,Gz) blocks
    hs[k,j,i] is the psf at the center of each block (i,j,k)

    as of now each image dimension has to be divisible by the grid dim, i.e.
    Nx % Gx == 0
    Ny % Gy == 0
    Nz % Gz == 0

    mode can be:
    "constant" - assumed values to be zero
    "wrap" - periodic boundary condition

    every block is embedded into a patch with side length
    _next_power_of_2(pad_factor * block size)

    plan is an optional fft plan for a single patch, a new one is created
    via fft_plan if it is None

    returns the result read back from a float32 buffer of shape im.shape,
    or the tuple (result, plan) if return_plan is True

    raises
    ValueError if im is not 3d or hs is not 6d
    NotImplementedError if the image shape is not divisible by the grid
    KeyError if mode is none of the strings listed above
    """
    if im.ndim != 3 or hs.ndim != 6:
        raise ValueError("wrong dimensions of input!")

    if not np.all([n % g == 0 for n, g in zip(im.shape, hs.shape[:3])]):
        raise NotImplementedError(
            "shape of image has to be divisible by Gx Gy  = %s !" %
            (str(hs.shape[:3])))

    mode_str = {"constant": "CLK_ADDRESS_CLAMP", "wrap": "CLK_ADDRESS_REPEAT"}

    Ns = tuple(im.shape)
    Gs = tuple(hs.shape[:3])

    # the size of each block within the grid
    # (floor division: the sizes are used as kernel work sizes and index
    # offsets and therefore have to stay integers on Python 3 as well)
    Nblocks = [n // g for n, g in zip(Ns, Gs)]

    # the size of the overlapping patches with safety padding
    Npatchs = tuple([_next_power_of_2(pad_factor * nb) for nb in Nblocks])

    # number of elements of a single patch / number of patches
    Npatch_total = np.prod(Npatchs)
    Ngrid_total = np.prod(Gs)

    # pad every psf to the patch size and move its center to the origin
    hs = np.fft.fftshift(pad_to_shape(hs, Gs + Npatchs), axes=(3, 4, 5))

    prog = OCLProgram(abspath("kernels/conv_spatial.cl"),
                      build_options=["-D",
                                     "ADDRESSMODE=%s" % mode_str[mode]])

    if plan is None:
        plan = fft_plan(Npatchs)

    patches_g = OCLArray.empty(Gs + Npatchs, np.complex64)

    h_g = OCLArray.from_array(hs.astype(np.complex64))

    im_g = OCLImage.from_array(im.astype(np.float32, copy=False))

    # origins of the blocks along z, y and x
    Xs = [nb * np.arange(g) for nb, g in zip(Nblocks, Gs)]

    # shift from a block origin to the origin of its centered patch (z,y,x)
    shifts = [nb // 2 - npatch // 2 for nb, npatch in zip(Nblocks, Npatchs)]

    # this loops over all i,j,k
    for (k, _z0), (j, _y0), (i, _x0) in product(*[enumerate(X) for X in Xs]):
        prog.run_kernel(
            "fill_patch3", Npatchs[::-1], None, im_g,
            np.int32(_x0 + shifts[2]),
            np.int32(_y0 + shifts[1]),
            np.int32(_z0 + shifts[0]), patches_g.data,
            # flat start index of patch (k, j, i) inside patches_g
            np.int32(i * Npatch_total + j * Gs[2] * Npatch_total +
                     k * Gs[2] * Gs[1] * Npatch_total))

    # convolution
    fft(patches_g, inplace=True, batch=Ngrid_total, plan=plan)
    fft(h_g, inplace=True, batch=Ngrid_total, plan=plan)
    prog.run_kernel("mult_inplace", (Npatch_total * Ngrid_total, ), None,
                    patches_g.data, h_g.data)

    fft(patches_g, inplace=True, inverse=True, batch=Ngrid_total, plan=plan)

    # accumulate
    res_g = OCLArray.zeros(im.shape, np.float32)

    # one launch per grid node, there are (g + 1) of them along each axis
    for k, j, i in product(*[list(range(g + 1)) for g in Gs]):
        prog.run_kernel("interpolate3", Nblocks[::-1], None, patches_g.data,
                        res_g.data, np.int32(i), np.int32(j), np.int32(k),
                        np.int32(Gs[2]), np.int32(Gs[1]), np.int32(Gs[0]),
                        np.int32(Npatchs[2]), np.int32(Npatchs[1]),
                        np.int32(Npatchs[0]))

    res = res_g.get()

    if return_plan:
        return res, plan
    return res
def _convolve_spatial2(im, hs,
                      mode = "constant",
                      grid_dim = None,
                      pad_factor = 2,
                      plan = None,
                      return_plan = False):
    """
    spatial varying convolution of a 2d image with a 2d grid of psfs

    shape(im) = (Ny,Nx)
    shape(hs) = (Gy,Gx, Hy,Hx)   (when grid_dim is not given)

    the input image im is subdivided into (Gy,Gx) blocks
    hs[j,i] is the psf at the center of each block (i,j)

    if grid_dim is given it is used as (Gy,Gx) and the psfs are copied
    block by block out of hs by the "fill_psf_grid2" kernel
    (presumably hs then has the same shape as im - TODO confirm)

    as of now each image dimension has to be divisible by the grid dim, i.e.
    Nx % Gx == 0
    Ny % Gy == 0
    (this is not checked here, the block sizes are simply Ny // Gy, Nx // Gx)

    every block is embedded into a patch with side length
    _next_power_of_2(pad_factor * block size)

    mode can be:
    "constant" - assumed values to be zero
    "wrap" - periodic boundary condition

    plan is an optional fft plan for patches of shape (Npatch_y,Npatch_x),
    a new one is created via fft_plan if it is None

    returns the result read back from a float32 buffer of shape im.shape,
    or the tuple (result, plan) if return_plan is True

    raises KeyError if mode is none of the strings listed above
    """

    if grid_dim:
        Gs = tuple(grid_dim)
    else:
        Gs = hs.shape[:2]

    mode_str = {"constant":"CLK_ADDRESS_CLAMP",
                "wrap":"CLK_ADDRESS_REPEAT"}

    Ny, Nx = im.shape
    Gy, Gx = Gs

    # the size of each block within the grid
    # (floor division: the sizes are used as kernel work sizes and index
    # offsets and have to be integers on Python 2 and Python 3 alike)
    Nblock_y, Nblock_x = Ny//Gy, Nx//Gx

    # the size of the overlapping patches with safety padding
    Npatch_x = _next_power_of_2(pad_factor*Nblock_x)
    Npatch_y = _next_power_of_2(pad_factor*Nblock_y)

    # number of elements of a single patch
    Npatch_total = Npatch_x*Npatch_y

    prog = OCLProgram(abspath("kernels/conv_spatial2.cl"),
                      build_options=["-D","ADDRESSMODE=%s"%mode_str[mode]])

    if plan is None:
        plan = fft_plan((Npatch_y,Npatch_x))

    # origins of the blocks within the image
    x0s = Nblock_x*np.arange(Gx)
    y0s = Nblock_y*np.arange(Gy)

    patches_g = OCLArray.empty((Gy,Gx,Npatch_y,Npatch_x),np.complex64)

    #prepare psfs
    if grid_dim:
        h_g = OCLArray.zeros((Gy,Gx,Npatch_y,Npatch_x),np.complex64)
        tmp_g = OCLArray.from_array(hs.astype(np.float32, copy = False))
        for i in range(Gx):
            for j in range(Gy):
                prog.run_kernel("fill_psf_grid2",
                                (Nblock_x,Nblock_y),None,
                                tmp_g.data,
                                np.int32(Nx),
                                np.int32(i*Nblock_x),
                                np.int32(j*Nblock_y),
                                h_g.data,
                                np.int32(Npatch_x),
                                np.int32(Npatch_y),
                                np.int32(-Nblock_x//2+Npatch_x//2),
                                np.int32(-Nblock_y//2+Npatch_y//2),
                                # flat start index of psf (j, i) inside h_g
                                np.int32((i+j*Gx)*Npatch_total))
    else:
        # pad every psf to the patch size and move its center to the origin
        hs = np.fft.fftshift(pad_to_shape(hs,(Gy,Gx,Npatch_y,Npatch_x)),axes=(2,3))
        h_g = OCLArray.from_array(hs.astype(np.complex64))

    #prepare image
    im_g = OCLImage.from_array(im.astype(np.float32,copy=False))

    for i,_x0 in enumerate(x0s):
        for j,_y0 in enumerate(y0s):
            prog.run_kernel("fill_patch2",(Npatch_x,Npatch_y),None,
                            im_g,
                            np.int32(_x0+Nblock_x//2-Npatch_x//2),
                            np.int32(_y0+Nblock_y//2-Npatch_y//2),
                            patches_g.data,
                            # flat start index of patch (j, i) inside patches_g
                            np.int32((i+j*Gx)*Npatch_total))

    # convolution
    fft(patches_g,inplace=True, batch = Gx*Gy, plan = plan)
    fft(h_g,inplace=True, batch = Gx*Gy, plan = plan)
    prog.run_kernel("mult_inplace",(Npatch_total*Gx*Gy,),None,
                    patches_g.data, h_g.data)
    fft(patches_g,inplace=True, inverse = True, batch = Gx*Gy, plan = plan)

    #accumulate
    res_g = OCLArray.empty(im.shape,np.float32)

    # one launch per grid node, there are (Gy+1)*(Gx+1) of them
    for j in range(Gy+1):
        for i in range(Gx+1):
            prog.run_kernel("interpolate2",(Nblock_x,Nblock_y),None,
                            patches_g.data,res_g.data,
                            np.int32(i),np.int32(j),
                            np.int32(Gx),np.int32(Gy),
                            np.int32(Npatch_x),np.int32(Npatch_y))

    res = res_g.get()

    if return_plan:
        return res, plan
    return res
def _convolve_spatial3(im, hs,
                      mode = "constant",
                      grid_dim = None,
                      plan = None,
                      return_plan = False,
                      pad_factor = 2):
    """
    spatial varying convolution of a 3d image with a 3d grid of psfs

    shape(im) = (Nz,Ny,Nx)
    shape(hs) = (Gz,Gy,Gx, Hz,Hy,Hx)   if grid_dim is not given
    shape(hs) = shape(im)              if grid_dim = (Gz,Gy,Gx) is given

    the image is subdivided into (Gz,Gy,Gx) blocks; without grid_dim
    hs[k,j,i] is used as the psf of block (k,j,i), with grid_dim the psfs
    are copied block by block out of hs by the "fill_psf_grid3" kernel

    each image dimension has to be divisible by the grid dimension

    every block is embedded into a patch with side length
    _next_power_of_2(pad_factor * block size)

    mode can be:
    "constant" - selects CLK_ADDRESS_CLAMP
    "wrap" - selects CLK_ADDRESS_REPEAT

    plan is an optional fft plan for a single patch, a new one is created
    via fft_plan if it is None

    returns the result read back from a float32 buffer of shape im.shape,
    or the tuple (result, plan) if return_plan is True

    raises
    ValueError if im is not 3d, if hs is neither 6d nor (3d with grid_dim),
        or if grid_dim is set and hs.shape differs from im.shape
    NotImplementedError if the image shape is not divisible by the grid
    KeyError if mode is none of the strings listed above
    """

    if im.ndim !=3:
        raise ValueError("wrong dimensions of input!")

    if not (hs.ndim==6 or (hs.ndim==3 and grid_dim)):
        raise ValueError("wrong dimensions of psf grid!")

    if grid_dim:
        if hs.shape != im.shape:
            raise ValueError("if grid_dim is set, then im.shape = hs.shape !")
        Gs = tuple(grid_dim)
    else:
        # hs.ndim == 6 is already guaranteed by the check above
        Gs = hs.shape[:3]

    if not np.all([n%g==0 for n,g in zip(im.shape,Gs)]):
        # report the grid that was actually used (Gs), not a slice of hs.shape
        raise NotImplementedError("shape of image has to be divisible by Gx Gy  = %s shape mismatch"%(str(Gs)))

    mode_str = {"constant":"CLK_ADDRESS_CLAMP",
                "wrap":"CLK_ADDRESS_REPEAT"}

    Ns = im.shape

    # the size of each block within the grid
    # (floor division: the sizes are used as kernel work sizes and index
    # offsets and have to be integers on Python 2 and Python 3 alike)
    Nblocks = [n//g for n,g  in zip(Ns,Gs)]

    # the size of the overlapping patches with safety padding
    Npatchs = tuple([_next_power_of_2(pad_factor*nb) for nb in Nblocks])

    # number of elements of a single patch / number of patches
    Npatch_total = np.prod(Npatchs)
    Ngrid_total = np.prod(Gs)

    prog = OCLProgram(abspath("kernels/conv_spatial3.cl"),
                      build_options=["-D","ADDRESSMODE=%s"%mode_str[mode]])

    if plan is None:
        plan = fft_plan(Npatchs)

    # origins of the blocks along z, y and x
    Xs = [nb*np.arange(g) for nb, g in zip(Nblocks,Gs)]

    patches_g = OCLArray.empty(Gs+Npatchs,np.complex64)

    #prepare psfs
    if grid_dim:
        h_g = OCLArray.zeros(Gs+Npatchs,np.complex64)
        tmp_g = OCLArray.from_array(hs.astype(np.float32, copy = False))
        for (k,_z0), (j,_y0),(i,_x0) in product(*[enumerate(X) for X in Xs]):
            prog.run_kernel("fill_psf_grid3",
                        Nblocks[::-1],None,
                        tmp_g.data,
                        np.int32(im.shape[2]),
                        np.int32(im.shape[1]),
                        np.int32(i*Nblocks[2]),
                        np.int32(j*Nblocks[1]),
                        np.int32(k*Nblocks[0]),
                        h_g.data,
                        np.int32(Npatchs[2]),
                        np.int32(Npatchs[1]),
                        np.int32(Npatchs[0]),
                        np.int32(-Nblocks[2]//2+Npatchs[2]//2),
                        np.int32(-Nblocks[1]//2+Npatchs[1]//2),
                        np.int32(-Nblocks[0]//2+Npatchs[0]//2),
                        # flat start index of psf (k, j, i) inside h_g
                        np.int32(i*Npatch_total+
                         j*Gs[2]*Npatch_total+
                         k*Gs[2]*Gs[1]*Npatch_total))

    else:
        # pad every psf to the patch size and move its center to the origin
        hs = np.fft.fftshift(pad_to_shape(hs,Gs+Npatchs),axes=(3,4,5))
        h_g = OCLArray.from_array(hs.astype(np.complex64))

    im_g = OCLImage.from_array(im.astype(np.float32,copy=False))

    # this loops over all i,j,k
    for (k,_z0), (j,_y0),(i,_x0) in product(*[enumerate(X) for X in Xs]):
        prog.run_kernel("fill_patch3",Npatchs[::-1],None,
                    im_g,
                    np.int32(_x0+Nblocks[2]//2-Npatchs[2]//2),
                    np.int32(_y0+Nblocks[1]//2-Npatchs[1]//2),
                    np.int32(_z0+Nblocks[0]//2-Npatchs[0]//2),
                    patches_g.data,
                    # flat start index of patch (k, j, i) inside patches_g
                    np.int32(i*Npatch_total+
                             j*Gs[2]*Npatch_total+
                             k*Gs[2]*Gs[1]*Npatch_total))

    # convolution
    fft(patches_g,inplace=True, batch = Ngrid_total, plan = plan)
    fft(h_g,inplace=True, batch = Ngrid_total, plan = plan)
    prog.run_kernel("mult_inplace",(Npatch_total*Ngrid_total,),None,
                    patches_g.data, h_g.data)

    fft(patches_g,
        inplace=True,
        inverse = True,
        batch = Ngrid_total,
        plan = plan)

    #accumulate
    res_g = OCLArray.zeros(im.shape,np.float32)

    # one launch per grid node, there are (g+1) of them along each axis
    for k, j, i in product(*[range(g+1) for g in Gs]):
        prog.run_kernel("interpolate3",Nblocks[::-1],None,
                        patches_g.data,
                        res_g.data,
                        np.int32(i),np.int32(j),np.int32(k),
                        np.int32(Gs[2]),np.int32(Gs[1]),np.int32(Gs[0]),
                        np.int32(Npatchs[2]),np.int32(Npatchs[1]),np.int32(Npatchs[0]))

    res = res_g.get()

    if return_plan:
        return res, plan
    return res
Exemple #6
0
def _convolve_spatial3(im,
                       hs,
                       mode="constant",
                       grid_dim=None,
                       plan=None,
                       return_plan=False,
                       pad_factor=2):
    """
    spatial varying convolution of a 3d image with a 3d grid of psfs

    shape(im) = (Nz,Ny,Nx)
    shape(hs) = (Gz,Gy,Gx, Hz,Hy,Hx)   if grid_dim is not given
    shape(hs) = shape(im)              if grid_dim = (Gz,Gy,Gx) is given

    the image is subdivided into (Gz,Gy,Gx) blocks; without grid_dim
    hs[k,j,i] is used as the psf of block (k,j,i), with grid_dim the psfs
    are copied block by block out of hs by the "fill_psf_grid3" kernel

    each image dimension has to be divisible by the grid dimension

    every block is embedded into a patch with side length
    _next_power_of_2(pad_factor * block size)

    mode can be:
    "constant" - selects CLK_ADDRESS_CLAMP
    "wrap" - selects CLK_ADDRESS_REPEAT

    plan is an optional fft plan for the whole patch stack (transforming
    the last three axes), a new one is created via fft_plan if it is None

    returns the result read back from a float32 buffer of shape im.shape,
    or the tuple (result, plan) if return_plan is True

    raises
    ValueError if im is not 3d, if hs is neither 6d nor (3d with grid_dim),
        or if grid_dim is set and hs.shape differs from im.shape
    NotImplementedError if the image shape is not divisible by the grid
    KeyError if mode is none of the strings listed above
    """
    if im.ndim != 3:
        raise ValueError("wrong dimensions of input!")

    if not (hs.ndim == 6 or (hs.ndim == 3 and grid_dim)):
        raise ValueError("wrong dimensions of psf grid!")

    if grid_dim:
        if hs.shape != im.shape:
            raise ValueError("if grid_dim is set, then im.shape = hs.shape !")
        Gs = tuple(grid_dim)
    else:
        # hs.ndim == 6 is already guaranteed by the check above
        Gs = hs.shape[:3]

    if not np.all([n % g == 0 for n, g in zip(im.shape, Gs)]):
        # report the grid that was actually used (Gs), not a slice of hs.shape
        raise NotImplementedError(
            "shape of image has to be divisible by Gx Gy  = %s shape mismatch"
            % (str(Gs)))

    mode_str = {"constant": "CLK_ADDRESS_CLAMP", "wrap": "CLK_ADDRESS_REPEAT"}

    Ns = im.shape

    # the size of each block within the grid
    Nblocks = [n // g for n, g in zip(Ns, Gs)]

    # the size of the overlapping patches with safety padding
    Npatchs = tuple([_next_power_of_2(pad_factor * nb) for nb in Nblocks])

    # number of elements of a single patch (loop invariant)
    Npatch_total = np.prod(Npatchs)

    prog = OCLProgram(abspath("kernels/conv_spatial3.cl"),
                      build_options=["-D",
                                     "ADDRESSMODE=%s" % mode_str[mode]])

    if plan is None:
        plan = fft_plan(Gs + Npatchs, axes=(-3, -2, -1))

    # origins of the blocks along z, y and x
    Xs = [nb * np.arange(g) for nb, g in zip(Nblocks, Gs)]

    patches_g = OCLArray.empty(Gs + Npatchs, np.complex64)

    # prepare psfs
    if grid_dim:
        h_g = OCLArray.zeros(Gs + Npatchs, np.complex64)

        tmp_g = OCLArray.from_array(hs.astype(np.float32, copy=False))
        for (k, _z0), (j, _y0), (i,
                                 _x0) in product(*[enumerate(X) for X in Xs]):

            prog.run_kernel(
                "fill_psf_grid3", Nblocks[::-1], None, tmp_g.data,
                np.int32(im.shape[2]), np.int32(im.shape[1]),
                np.int32(i * Nblocks[2]), np.int32(j * Nblocks[1]),
                np.int32(k * Nblocks[0]), h_g.data, np.int32(Npatchs[2]),
                np.int32(Npatchs[1]), np.int32(Npatchs[0]),
                np.int32(-Nblocks[2] // 2 + Npatchs[2] // 2),
                np.int32(-Nblocks[1] // 2 + Npatchs[1] // 2),
                np.int32(-Nblocks[0] // 2 + Npatchs[0] // 2),
                # flat start index of psf (k, j, i) inside h_g
                np.int32(i * Npatch_total + j * Gs[2] * Npatch_total +
                         k * Gs[2] * Gs[1] * Npatch_total))

    else:
        # pad every psf to the patch size and move its center to the origin
        hs = np.fft.fftshift(pad_to_shape(hs, Gs + Npatchs), axes=(3, 4, 5))
        h_g = OCLArray.from_array(hs.astype(np.complex64))

    im_g = OCLImage.from_array(im.astype(np.float32, copy=False))

    # this loops over all i,j,k
    for (k, _z0), (j, _y0), (i, _x0) in product(*[enumerate(X) for X in Xs]):
        prog.run_kernel(
            "fill_patch3", Npatchs[::-1], None, im_g,
            np.int32(_x0 + Nblocks[2] // 2 - Npatchs[2] // 2),
            np.int32(_y0 + Nblocks[1] // 2 - Npatchs[1] // 2),
            np.int32(_z0 + Nblocks[0] // 2 - Npatchs[0] // 2), patches_g.data,
            # flat start index of patch (k, j, i) inside patches_g
            np.int32(i * Npatch_total + j * Gs[2] * Npatch_total +
                     k * Gs[2] * Gs[1] * Npatch_total))

    # convolution
    fft(patches_g, inplace=True, plan=plan)
    fft(h_g, inplace=True, plan=plan)
    prog.run_kernel("mult_inplace", (Npatch_total * np.prod(Gs), ), None,
                    patches_g.data, h_g.data)

    fft(patches_g, inplace=True, inverse=True, plan=plan)

    # accumulate
    res_g = OCLArray.zeros(im.shape, np.float32)

    # one launch per grid node, there are (g + 1) of them along each axis
    for k, j, i in product(*[list(range(g + 1)) for g in Gs]):
        prog.run_kernel("interpolate3", Nblocks[::-1], None, patches_g.data,
                        res_g.data, np.int32(i), np.int32(j), np.int32(k),
                        np.int32(Gs[2]), np.int32(Gs[1]), np.int32(Gs[0]),
                        np.int32(Npatchs[2]), np.int32(Npatchs[1]),
                        np.int32(Npatchs[0]))

    res = res_g.get()

    if return_plan:
        return res, plan
    return res
def convolve_spatial3(im, hs,
                      mode = "constant",
                      plan = None,
                      return_plan = False,
                      pad_factor = 2):
    """
    spatial varying convolution of a 3d image with a 3d grid of psfs

    shape(im) = (Nz,Ny,Nx)
    shape(hs) = (Gz,Gy,Gx, Hz,Hy,Hx)

    the input image im is subdivided into (Gx,Gy,Gz) blocks
    hs[k,j,i] is the psf at the center of each block (i,j,k)

    as of now each image dimension has to be divisible by the grid dim, i.e.
    Nx % Gx == 0
    Ny % Gy == 0
    Nz % Gz == 0

    mode can be:
    "constant" - assumed values to be zero
    "wrap" - periodic boundary condition

    every block is embedded into a patch with side length
    _next_power_of_2(pad_factor * block size)

    plan is an optional fft plan for a single patch, a new one is created
    via fft_plan if it is None

    returns the result read back from a float32 buffer of shape im.shape,
    or the tuple (result, plan) if return_plan is True

    raises
    ValueError if im is not 3d or hs is not 6d
    NotImplementedError if the image shape is not divisible by the grid
    KeyError if mode is none of the strings listed above
    """
    if im.ndim !=3 or hs.ndim !=6:
        raise ValueError("wrong dimensions of input!")

    if not np.all([n%g==0 for n,g in zip(im.shape,hs.shape[:3])]):
        raise NotImplementedError("shape of image has to be divisible by Gx Gy  = %s !"%(str(hs.shape[:3])))

    mode_str = {"constant":"CLK_ADDRESS_CLAMP",
                "wrap":"CLK_ADDRESS_REPEAT"}

    Ns = tuple(im.shape)
    Gs = tuple(hs.shape[:3])

    # the size of each block within the grid
    # (floor division: the sizes are used as kernel work sizes and index
    # offsets and have to be integers on Python 2 and Python 3 alike)
    Nblocks = [n//g for n,g  in zip(Ns,Gs)]

    # the size of the overlapping patches with safety padding
    Npatchs = tuple([_next_power_of_2(pad_factor*nb) for nb in Nblocks])

    # number of elements of a single patch / number of patches
    Npatch_total = np.prod(Npatchs)
    Ngrid_total = np.prod(Gs)

    # pad every psf to the patch size and move its center to the origin
    hs = np.fft.fftshift(pad_to_shape(hs,Gs+Npatchs),axes=(3,4,5))

    prog = OCLProgram(abspath("kernels/conv_spatial.cl"),
                      build_options=["-D","ADDRESSMODE=%s"%mode_str[mode]])

    if plan is None:
        plan = fft_plan(Npatchs)

    patches_g = OCLArray.empty(Gs+Npatchs,np.complex64)

    h_g = OCLArray.from_array(hs.astype(np.complex64))

    im_g = OCLImage.from_array(im.astype(np.float32,copy=False))

    # origins of the blocks along z, y and x
    Xs = [nb*np.arange(g) for nb, g in zip(Nblocks,Gs)]

    # this loops over all i,j,k
    for (k,_z0), (j,_y0),(i,_x0) in product(*[enumerate(X) for X in Xs]):
        prog.run_kernel("fill_patch3",Npatchs[::-1],None,
                    im_g,
                    np.int32(_x0+Nblocks[2]//2-Npatchs[2]//2),
                    np.int32(_y0+Nblocks[1]//2-Npatchs[1]//2),
                    np.int32(_z0+Nblocks[0]//2-Npatchs[0]//2),
                    patches_g.data,
                    # flat start index of patch (k, j, i) inside patches_g
                    np.int32(i*Npatch_total+
                             j*Gs[2]*Npatch_total+
                             k*Gs[2]*Gs[1]*Npatch_total))

    # convolution
    fft(patches_g,inplace=True, batch = Ngrid_total, plan = plan)
    fft(h_g,inplace=True, batch = Ngrid_total, plan = plan)
    prog.run_kernel("mult_inplace",(Npatch_total*Ngrid_total,),None,
                    patches_g.data, h_g.data)

    fft(patches_g,
        inplace=True,
        inverse = True,
        batch = Ngrid_total,
        plan = plan)

    #accumulate
    res_g = OCLArray.zeros(im.shape,np.float32)

    # one launch per grid node, there are (g+1) of them along each axis
    for k, j, i in product(*[range(g+1) for g in Gs]):
        prog.run_kernel("interpolate3",Nblocks[::-1],None,
                        patches_g.data,
                        res_g.data,
                        np.int32(i),np.int32(j),np.int32(k),
                        np.int32(Gs[2]),np.int32(Gs[1]),np.int32(Gs[0]),
                        np.int32(Npatchs[2]),np.int32(Npatchs[1]),np.int32(Npatchs[0]))

    res = res_g.get()

    if return_plan:
        return res, plan
    return res
def convolve_spatial2(im, hs,
                      mode = "constant",
                      plan = None,
                      return_plan = False):
    """
    spatial varying convolution of a 2d image with a 2d grid of psfs

    shape(im) = (Ny,Nx)
    shape(hs) = (Gy,Gx, Hy,Hx)

    the input image im is subdivided into (Gy,Gx) blocks
    hs[j,i] is the psf at the center of each block (i,j)

    as of now each image dimension has to be divisible by the grid dim, i.e.
    Nx % Gx == 0
    Ny % Gy == 0

    mode can be:
    "constant" - assumed values to be zero
    "wrap" - periodic boundary condition

    every block is embedded into a patch with side length
    _next_power_of_2(3 * block size)

    plan is an optional fft plan for patches of shape (Npatch_y,Npatch_x),
    a new one is created via fft_plan if it is None

    returns the result read back from a float32 buffer of shape im.shape,
    or the tuple (result, plan) if return_plan is True

    raises
    ValueError if im is not 2d or hs is not 4d
    NotImplementedError if the image shape is not divisible by (Gy,Gx)
    KeyError if mode is none of the strings listed above
    """

    if im.ndim !=2 or hs.ndim !=4:
        raise ValueError("wrong dimensions of input!")

    if not np.all([n%g==0 for n,g in zip(im.shape,hs.shape[:2])]):
        raise NotImplementedError("shape of image has to be divisible by Gx Gy  = %s shape mismatch"%(str(hs.shape[:2])))

    mode_str = {"constant":"CLK_ADDRESS_CLAMP",
                "wrap":"CLK_ADDRESS_REPEAT"}

    Ny, Nx = im.shape
    Gy, Gx = hs.shape[:2]

    # the size of each block within the grid
    # (floor division: the sizes are used as kernel work sizes and index
    # offsets and have to be integers on Python 2 and Python 3 alike)
    Nblock_y, Nblock_x = Ny//Gy, Nx//Gx

    # the size of the overlapping patches with safety padding
    Npatch_x, Npatch_y = _next_power_of_2(3*Nblock_x), _next_power_of_2(3*Nblock_y)

    # number of elements of a single patch
    Npatch_total = Npatch_x*Npatch_y

    # pad every psf to the patch size and move its center to the origin
    hs = np.fft.fftshift(pad_to_shape(hs,(Gy,Gx,Npatch_y,Npatch_x)),axes=(2,3))

    prog = OCLProgram(abspath("kernels/conv_spatial.cl"),
                      build_options=["-D","ADDRESSMODE=%s"%mode_str[mode]])

    if plan is None:
        plan = fft_plan((Npatch_y,Npatch_x))

    patches_g = OCLArray.empty((Gy,Gx,Npatch_y,Npatch_x),np.complex64)

    h_g = OCLArray.from_array(hs.astype(np.complex64))

    im_g = OCLImage.from_array(im.astype(np.float32,copy=False))

    # origins of the blocks within the image
    x0s = Nblock_x*np.arange(Gx)
    y0s = Nblock_y*np.arange(Gy)

    for i,_x0 in enumerate(x0s):
        for j,_y0 in enumerate(y0s):
            prog.run_kernel("fill_patch2",(Npatch_x,Npatch_y),None,
                    im_g,
                    np.int32(_x0+Nblock_x//2-Npatch_x//2),
                    np.int32(_y0+Nblock_y//2-Npatch_y//2),
                    patches_g.data,
                    # flat start index of patch (j, i) inside patches_g
                    np.int32((i+j*Gx)*Npatch_total))

    # convolution
    fft(patches_g,inplace=True, batch = Gx*Gy, plan = plan)
    fft(h_g,inplace=True, batch = Gx*Gy, plan = plan)
    prog.run_kernel("mult_inplace",(Npatch_total*Gx*Gy,),None,
                    patches_g.data, h_g.data)

    fft(patches_g,inplace=True, inverse = True, batch = Gx*Gy, plan = plan)

    #accumulate
    res_g = OCLArray.empty(im.shape,np.float32)

    # one launch per grid node, there are (Gx+1)*(Gy+1) of them
    for i in range(Gx+1):
        for j in range(Gy+1):
            prog.run_kernel("interpolate2",(Nblock_x,Nblock_y),None,
                            patches_g.data,res_g.data,
                            np.int32(i),np.int32(j),
                            np.int32(Gx),np.int32(Gy),
                            np.int32(Npatch_x),np.int32(Npatch_y))

    res = res_g.get()

    if return_plan:
        return res, plan
    return res