예제 #1
0
 def _rmatvec(self, x):
     # correct type of h if different from x and choose methods accordingly
     if type(self.h) != type(x):
         self.h = to_cupy_conditional(x, self.h)
         self.convolve = get_convolve(self.h)
         self.correlate = get_correlate(self.h)
     x = np.reshape(x, self.dims)
     y = self.correlate(x, self.h, mode='same', method=self.method)
     y = y.ravel()
     return y
예제 #2
0
    def __init__(self,
                 N,
                 h,
                 dims,
                 offset=None,
                 dirs=None,
                 method='fft',
                 dtype='float64'):
        """Build an N-dimensional convolution operator.

        Parameters
        ----------
        N : int
            Size of the flat model/data vector; must equal ``np.prod(dims)``.
        h : array
            Convolution filter. Its array module (from ``get_array_module``)
            is used to zero-pad it.
        dims : sequence of int
            Shape the flat input vector is reshaped to.
        offset : sequence of int, optional
            Per-axis index of the filter sample that is placed at the centre
            of the (zero-padded) filter. Defaults to zeros.
        dirs : sequence of int, optional
            Axes of ``dims`` that the filter's axes map onto; only used when
            the filter's number of axes differs from ``len(dims)``. Defaults
            to all axes of ``dims``.
        method : str, optional
            Forwarded as ``method=`` to the convolve/correlate functions.
        dtype : str, optional
            Operator dtype.

        Raises
        ------
        ValueError
            If ``np.prod(dims) != N``.
        """
        ncp = get_array_module(h)
        self.h = h
        self.nh = np.array(self.h.shape)
        self.dirs = np.arange(len(dims)) if dirs is None else np.array(dirs)

        # padding: shift the filter so that sample `offset` ends up at the
        # centre. NOTE: the builtin ``int`` replaces the ``np.int`` alias,
        # which was removed in NumPy 1.24; both map to the same dtype.
        if offset is None:
            offset = np.zeros(self.h.ndim, dtype=int)
        else:
            offset = np.array(offset, dtype=int)
        self.offset = 2 * (self.nh // 2 - offset)
        pad = [(0, 0) for _ in range(self.h.ndim)]
        dopad = False
        for inh, nh in enumerate(self.nh):
            if nh % 2 == 0:
                # even-length axis: pad one sample less so the padded
                # length is odd and has a single centre sample
                self.offset[inh] -= 1
            if self.offset[inh] != 0:
                # a positive shift pads before the filter, a negative one after
                pad[inh] = [
                    self.offset[inh] if self.offset[inh] > 0 else 0,
                    -self.offset[inh] if self.offset[inh] < 0 else 0
                ]
                dopad = True
        if dopad:
            self.h = ncp.pad(self.h, pad, mode='constant')
        self.nh = self.h.shape

        # if the filter's rank differs from len(dims), give it singleton axes
        # everywhere except along ``dirs`` so it has len(dims) axes
        if len(dims) != len(self.nh):
            dimsh = np.ones(len(dims), dtype=int)
            for idir, axis in enumerate(self.dirs):
                dimsh[axis] = self.nh[idir]
            self.h = self.h.reshape(dimsh)

        if np.prod(dims) != N:
            raise ValueError('product of dims must equal N!')
        self.dims = np.array(dims)
        self.reshape = True

        # convolve and correlate functions chosen from the processed filter,
        # matching how _rmatvec re-selects them from self.h
        self.convolve = get_convolve(self.h)
        self.correlate = get_correlate(self.h)
        self.method = method

        self.shape = (np.prod(self.dims), np.prod(self.dims))
        self.dtype = np.dtype(dtype)
        self.explicit = False