Example 1
    def forward(self, x, grid, **overload):
        """

        Parameters
        ----------
        x : (batch, channel, *spatial_in) tensor
            Input image to deform
        grid : (batch, *spatial_in, dim) tensor
            Transformation grid
        overload : dict
            All parameters defined at build time can be overridden
            at call time.

        Returns
        -------
        pushed : (batch, channel, *shape) tensor
            Pushed image.
        count : (batch, 1, *shape) tensor
            Splat count (sum of the interpolation weights pushed into
            each output voxel).

        """
        shape = overload.get('shape', self.shape)
        interpolation = overload.get('interpolation', self.interpolation)
        bound = overload.get('bound', self.bound)
        extrapolate = overload.get('extrapolate', self.extrapolate)
        push = spatial.grid_push(x, grid, shape,
                                 interpolation=interpolation,
                                 bound=bound,
                                 extrapolate=extrapolate)
        count = spatial.grid_count(grid, shape,
                                   interpolation=interpolation,
                                   bound=bound,
                                   extrapolate=extrapolate)
        return push, count
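
Note: grid_push is the adjoint of grid_pull (splatting vs. sampling), which is
why it keeps reappearing in the backward passes further down. A minimal sketch
of the dot-product test that checks this, assuming nitorch-style imports and
the signatures used throughout these examples:

import torch
from nitorch import spatial  # assumed import path

x = torch.randn(1, 1, 8, 8)   # (batch, channel, *spatial)
y = torch.randn(1, 1, 8, 8)
# identity grid plus a small random perturbation
g = spatial.identity_grid([8, 8])[None] + 0.1 * torch.randn(1, 8, 8, 2)

opt = dict(interpolation=1, bound='zero', extrapolate=True)
pulled = spatial.grid_pull(y, g, **opt)          # sample y at g
pushed = spatial.grid_push(x, g, [8, 8], **opt)  # splat x along g
# adjoint identity: <push(x), y> == <x, pull(y)>, up to float error
assert torch.allclose((pushed * y).sum(), (x * pulled).sum(), atol=1e-4)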
Example 2
    def forward(self, x, grid, shape=None):
        """

        Parameters
        ----------
        x : (batch, channel, *spatial_in) tensor
            Input image to deform
        grid : (batch, *spatial_in, dim) tensor
            Transformation grid
        shape : list[int], default=self.shape
            Output spatial shape. Defaults to the shape set at build time.

        Returns
        -------
        pushed : (batch, channel, *shape) tensor
            Pushed image.
        count : (batch, 1, *shape) tensor
            Splat count (sum of the interpolation weights pushed into
            each output voxel).

        """
        shape = shape or self.shape
        push = spatial.grid_push(x,
                                 grid,
                                 shape,
                                 interpolation=self.interpolation,
                                 bound=self.bound,
                                 extrapolate=self.extrapolate)
        count = spatial.grid_count(grid,
                                   shape,
                                   interpolation=self.interpolation,
                                   bound=self.bound,
                                   extrapolate=self.extrapolate)
        return push, count
Example 3
    def forward(self, x, grid, shape=None):
        """

        Parameters
        ----------
        x : (batch, channel, *spatial_in) tensor
            Input image to deform
        grid : (batch, *spatial_in, dim) tensor
            Transformation grid (one output coordinate per input voxel)
        shape : list[int], default=self.shape
            Output spatial shape.

        Returns
        -------
        pushed : (batch, channel, *shape) tensor
            Pushed image.

        """
        shape = shape or self.shape
        return spatial.grid_push(x,
                                 grid,
                                 shape,
                                 interpolation=self.interpolation,
                                 bound=self.bound,
                                 extrapolate=self.extrapolate)
Example 4
def push1d(img, grid, dim, **kwargs):
    """Push an image along dimension `dim` (adjoint of the matching 1D pull)."""
    if grid is None:
        return img
    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('bound', 'dft')
    # move the target dimension last, then add the channel dim expected by
    # grid_push (for the image) and the coordinate dim (for the grid)
    img = core.utils.movedim(img, dim, -1).unsqueeze(-2)
    grid = core.utils.movedim(grid, dim, -1).unsqueeze(-1)
    pushed = spatial.grid_push(img, grid, **kwargs)
    pushed = core.utils.movedim(pushed.squeeze(-2), -1, dim)
    return pushed
Example 5
    def push_radius(self, x, t):
        """Push gradient into the radius control points
        (= differentiate wrt radius control points)"""
        t = t.flatten()
        t = t.clamp(0, 1) * (len(self.coeff_radius) - 1)

        x = x.flatten()                   # [N]
        t = t.unsqueeze(-1)               # [N, 1]
        y = grid_push(x, t, [len(self.coeff_radius)],
                      bound=self.bound, interpolation=self.order)
        return y
Example 6
    def push_position(self, x, t):
        """Push gradient into the control points
        (= differentiate wrt control points)"""
        t = t.flatten()
        t = t.clamp(0, 1) * (len(self.coeff) - 1)

        x = x.reshape(-1, x.shape[-1]).T  # [D, N]
        t = t.unsqueeze(-1)               # [N, 1]
        y = grid_push(x, t, [len(self.coeff)],
                      bound=self.bound, interpolation=self.order)
        y = y.T                           # [K, D]
        return y
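
Both helpers exploit the fact that evaluating a 1D spline at parameters `t`
is a grid_pull with an (N, 1) grid, so the gradient with respect to the
control points is the matching grid_push. A hedged sketch of that round trip,
assuming the functional grid_pull/grid_push accept the same unbatched calling
convention as push_radius above and that grid_pull's autograd backward is
available:

import torch
from nitorch.spatial import grid_pull, grid_push  # assumed import path

coeff = torch.randn(16, requires_grad=True)   # K spline coefficients
t = torch.rand(100) * (len(coeff) - 1)        # sample positions in [0, K-1]

y = grid_pull(coeff, t.unsqueeze(-1), interpolation=3, bound='dct2')
g_auto, = torch.autograd.grad(y.sum(), coeff)
# manual adjoint: push the upstream gradient (here, ones) onto the coefficients
g_push = grid_push(torch.ones(100), t.unsqueeze(-1), [len(coeff)],
                   interpolation=3, bound='dct2')
assert torch.allclose(g_auto, g_push, atol=1e-5)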
Example 7
def smart_push_grid(vel, grid, *args, **kwargs):
    """Push a velocity/grid/displacement field.

    Notes
    -----
    Defaults differ from grid_push:
    - bound -> dft
    - extrapolate -> True

    Parameters
    ----------
    vel : ([batch], *spatial, ndim) tensor
        Velocity
    grid : ([batch], *spatial, ndim) tensor
        Transformation field
    kwargs : dict
        Options to ``grid_push``

    Returns
    -------
    pushed_vel : ([batch], *spatial, ndim) tensor
        Velocity

    """
    if grid is None or vel is None:
        return vel
    kwargs.setdefault('bound', 'dft')
    kwargs.setdefault('extrapolate', True)
    dim = vel.shape[-1]
    vel = utils.movedim(vel, -1, -dim - 1)
    vel_no_batch = vel.dim() == dim + 1
    grid_no_batch = grid.dim() == dim + 1
    if vel_no_batch:
        vel = vel[None]
    if grid_no_batch:
        grid = grid[None]
    vel = spatial.grid_push(vel, grid, *args, **kwargs)
    vel = utils.movedim(vel, -dim - 1, -1)
    if vel_no_batch and grid_no_batch:
        vel = vel[0]
    return vel
Example 8
def smart_push(tensor, grid, shape=None):
    """Pull iff grid is defined (+ add/remove batch dim).

    Parameters
    ----------
    tensor : (channels, *input_shape) tensor
        Input volume
    grid : (*input_shape, D) tensor or None
        Sampling grid
    shape : (D,) tuple[int], default=input_shape
        Output shape

    Returns
    -------
    pushed : (channels, *output_shape) tensor
        Pushed volume

    """
    if grid is None:
        return tensor
    return spatial.grid_push(tensor[None, ...], grid[None, ...], shape)[0]
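
A typical call site, sketched with hypothetical tensors and an assumed
nitorch import:

import torch
from nitorch import spatial  # assumed import path

vol = torch.randn(3, 32, 32)              # (channels, *input_shape)
grid = spatial.identity_grid([32, 32])    # (*input_shape, D) sampling grid
out = smart_push(vol, grid, [64, 64])     # (channels, 64, 64)
same = smart_push(vol, None)              # no grid: input returned unchanged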
Example 9
def push1d(img, grid, **kwargs):
    """Push an image by a transform along the last dimension

    This is the adjoint of `pull1d`.

    Parameters
    ----------
    img : (K, *spatial) tensor, Image
    grid : (*spatial) tensor, Sampling grid

    Returns
    -------
    pushed_img : (K, *spatial) tensor

    """
    if grid is None:
        return img
    kwargs.setdefault('extrapolate', True)
    kwargs.setdefault('bound', 'dft')
    img, grid = img.unsqueeze(-2), grid.unsqueeze(-1)
    pushed = grid_push(img, grid, **kwargs).squeeze(-2)
    return pushed
Example 10
    def forward(self, x, min=None, max=None, mask=None):
        """

        Parameters
        ----------
        x : (..., N, 2) tensor
            Input multivariate vector
        min : (..., 2) tensor, optional
        max : (..., 2) tensor, optional
        mask : (..., N) tensor, optional

        Returns
        -------
        h : (..., B, B) tensor
            Joint histogram (B = number of bins)

        """
        shape = x.shape
        x, min, max = self._prepare(x, min, max)

        # push data into the histogram
        #   hidden feature: tell pullpush to use +/- 0.5 tolerance when
        #   deciding if a coordinate is inbounds.
        extrapolate = self.extrapolate or 2
        if mask is None:
            h = spatial.grid_count(x[:, None], self.n, self.order, self.bound,
                                   extrapolate)[:, 0]
        else:
            mask = mask.to(x.device, x.dtype)
            h = spatial.grid_push(mask, x[:, None], self.n, self.order,
                                  self.bound, extrapolate)[:, 0]
        h = h.to(x.dtype)
        h = h.reshape([*shape[:-2], *h.shape[-2:]])

        if self.fwhm:
            h = spatial.smooth(h, fwhm=self.fwhm, bound=self.bound, dim=2)

        return h, min, max
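
The key idea: a histogram is a push. Map each sample to a fractional bin
coordinate, then splat one unit of mass per sample; with interpolation
order 0 this reduces to a hard count, while higher orders give
differentiable soft counts. A minimal 1D sketch, assuming nitorch-style
imports and APIs:

import torch
from nitorch import spatial  # assumed import path

samples = torch.rand(1, 1000)        # (batch, N) values in [0, 1]
bins = 32
coord = samples * (bins - 1)         # map values to bin coordinates
# (batch, N, 1) grid -> (batch, 1, bins) counts, one unit of mass per sample
h = spatial.grid_count(coord.unsqueeze(-1), [bins],
                       interpolation=1, bound='zero', extrapolate=True)[:, 0]
print(h.sum())  # ~1000: mass is conserved up to boundary effects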
Example 11
def _jhistc_forward(x,
                    bins,
                    w=None,
                    order=0,
                    bound='replicate',
                    extrapolate=True):
    """Build joint histogram.

    The input must already be a soft mapping to bins indices.

    Parameters
    ----------
    x : (b, n, 2) tensor
    bins : int or (int, int)
    w : ([b], n) tensor, optional
    order : int or (int, int), default=0
    bound : str, default='replicate'
    extrapolate : bool, default=True

    Returns
    -------
    h : (b, bins, bins) tensor

    """
    bins = py.make_list(bins, 2)
    extrapolate = 1 if extrapolate else 2  # 2: in-bounds only, with the
    #   +/- 0.5 tolerance described in the previous example
    opt = dict(shape=bins,
               interpolation=order,
               bound=bound,
               extrapolate=extrapolate)
    x = x.unsqueeze(-3)  # make 2d spatial
    if w is None:
        h = grid_count(x, **opt)
    else:
        w = w.unsqueeze(-2).unsqueeze(-2)  # make 2d spatial + add channel
        h = grid_push(w, x, **opt)
        h = h.squeeze(-3)  # drop channel
    return h
Example 12
    def __call__(self,
                 vel,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False):
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        dim = vel.shape[-1]
        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        in_line_search = not grad and not hess
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # forward
        if self.kernel is None:
            self.kernel = spatial.greens(vel.shape[-dim - 1:-1], **self.prm,
                                         **utils.backend(vel))
        grid = spatial.shoot(vel, self.kernel, steps=self.steps, **self.prm)
        warped = spatial.grid_pull(self.moving,
                                   grid,
                                   bound='dct2',
                                   extrapolate=True)

        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        vel,
                        cat=iscat,
                        dim=dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, self.fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients
        if grad is not False or hess is not False:
            if self.mugrad is None:
                self.mugrad = spatial.diff(self.moving,
                                           dim=list(range(-dim, 0)),
                                           bound='dct2')
            if grad is not False:
                grad = grad.neg_()  # "final inverse" to "initial"
                grad = spatial.grid_push(grad, grid)
                grad = jg(self.mugrad, grad)
            if hess is not False:
                hess = spatial.grid_push(hess, grid)
                hess = jhj(self.mugrad, hess)

        # add regularization term
        vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=True)
        llv = 0.5 * (vel * vgrad).sum()
        if grad is not False:
            grad += vgrad
        del vgrad

        # print objective
        llx = llx.item()
        llv = llv.item()
        ll = llx + llv
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                    end='\r')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\r')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Example 13
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Example 14
    def __call__(self,
                 vel,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """
        vel : (..., *spatial, dim) tensor, Displacement
        grad : Whether to compute and return the gradient wrt `vel`
        hess : Whether to compute and return the Hessian wrt `vel`
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., *spatial, dim) tensor, optional, Gradient wrt velocity
        h : (..., *spatial, ?) tensor, optional, Hessian wrt velocity
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        dim = vel.shape[-1]
        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        # in_line_search = not grad and not hess
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # forward
        if self.id is None:
            self.id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                            **utils.backend(vel))
        grid = self.id + vel
        warped = spatial.grid_pull(self.moving, grid, **pullopt)

        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        vel,
                        cat=iscat,
                        dim=dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess and not hessmov:
            llx = self.loss.loss(warped, self.fixed)
        elif not hess and not hessmov:
            llx, grad = self.loss.loss_grad(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            if grad is not False:
                grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)

        # add regularization term
        vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=False)
        llv = 0.5 * (vel * vgrad).sum()
        if grad is not False:
            grad += vgrad
        del vgrad

        # print objective
        llx = llx.item()
        llv = llv.item()
        ll = llx + llv
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                    end='\r')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\r')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Example 15
def smart_push(image, grid, **kwargs):
    """spatial.grid_push that accepts None grid"""
    if image is None or grid is None:
        return image
    return spatial.grid_push(image, grid, **kwargs)
Example 16
def grid_inv(grid, type='grid', lam=0.1, bound='dft', extrapolate=True):
    """Invert a dense deformation (or displacement) grid
    
    Notes
    -----
    The deformation/displacement grid must be expressed in 
    voxels, and map from/to the same lattice.
    
    Let `f = id + d` be the transformation. The inverse
    is obtained as `id - (k * (f.T @ d)) / (k * (f.T @ 1))`
    where `k` is a smoothing kernel, `f.T @ _` is the adjoint
    operation ("push") of `f @ _` ("pull"), and `1` is an
    image of ones.
    
    
    Parameters
    ----------
    grid : (..., *spatial, dim)
        Transformation (or displacement) grid
    type : {'grid', 'disp'}, default='grid'
        Type of deformation.
    lam : float, default=0.1
        Regularisation
    bound : str, default='dft'
    extrapolate : bool, default=True
        
    Returns
    -------
    grid_inv : (..., *spatial, dim)
        Inverse transformation (or displacement) grid
    
    """
    # get shape components
    dim = grid.shape[-1]
    shape = grid.shape[-(dim + 1):-1]
    batch = grid.shape[:-(dim + 1)]
    grid = grid.reshape([-1, *shape, dim])
    backend = dict(dtype=grid.dtype, device=grid.device)

    # get displacement
    identity = spatial.identity_grid(shape, **backend)
    if type == 'grid':
        disp = grid - identity
    else:
        disp = grid
        grid = disp + identity

    # push displacement
    push_opt = dict(bound=bound, extrapolate=extrapolate)
    disp = core.utils.movedim(disp, -1, 1)
    disp = spatial.grid_push(disp, grid, **push_opt)
    count = spatial.grid_count(grid, **push_opt)

    # Fill missing values using regularised least squares
    disp = spatial.solve_field_sym(count,
                                   disp,
                                   membrane=lam,
                                   bound='dft',
                                   dim=dim)
    disp = core.utils.movedim(disp, 1, -1)
    disp = disp.reshape([*batch, *shape, dim])

    if type == 'grid':
        return identity - disp
    else:
        return -disp
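
A quick sanity check for this inverse, sketched with assumed nitorch
helpers: sample the inverse at the forward grid; the composition should be
close to the identity wherever the warp is invertible.

import torch
from nitorch import spatial  # assumed import path

shape = [32, 32]
id_ = spatial.identity_grid(shape)
grid = id_ + 0.05 * torch.randn(*shape, 2)   # small (invertible) random warp
inv = grid_inv(grid, type='grid')

# compose: pull the inverse through the forward grid
comp = spatial.grid_pull(inv.movedim(-1, 0)[None], grid[None],
                         bound='dft', extrapolate=True)[0].movedim(0, -1)
print((comp - id_).abs().mean())  # small residual if the warp is invertible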
Example 17
def _proj_apply(operator,
                dat,
                po,
                method='super-resolution',
                bound='zero',
                interpolation='linear'):
    """ Applies operator A, At  or AtA (for denoising or super-resolution).

    Args:
        operator (string): Either 'A', 'At', 'AtA' or 'none'.
        dat (torch.tensor()): Image data (1, 1, X_in, Y_in, Z_in).
        po (_proj_op()): Encodes projection operator, has the following fields:
            po.mat_x: Low-res affine matrix.
            po.mat_y: High-res affine matrix.
            po.mat_yx: Intermediate affine matrix.
            po.rigid: Rigid alignment matrix.
            po.dim_x: Low-res image dimensions.
            po.dim_y: High-res image dimensions.
            po.dim_yx: Intermediate image dimensions.
            po.ratio: The ratio (low-res voxel_size)/(high-res voxel_size).
            po.smo_ker: Smoothing kernel (slice-profile).
            po.scl: Scaling applied along the thick-slice dimension
                (see _apply_scaling).
            po.dim_thick: Index of the thick-slice dimension.
        method (string): Either 'denoising' or 'super-resolution' (default).
        bound (str, optional): Bound for nitorch push/pull, defaults to 'zero'.
        interpolation (int, optional): Interpolation order, defaults to linear.

    Returns:
        dat (torch.tensor()): Projected image data (1, 1, X_out, Y_out, Z_out).

    """
    # Sanity check
    if operator not in ['A', 'At', 'AtA', 'none']:
        raise ValueError('Undefined operator')
    if method not in ['denoising', 'super-resolution']:
        raise ValueError('Undefined method')
    if operator == 'none':
        # No projection
        return dat
    # Get data type and device
    dtype = dat.dtype
    device = dat.device
    # Parse required projection info
    mat_x = po.mat_x
    mat_y = po.mat_y
    mat_yx = po.mat_yx
    rigid = po.rigid
    dim_x = po.dim_x
    dim_y = po.dim_y
    dim_yx = po.dim_yx
    ratio = po.ratio
    smo_ker = po.smo_ker
    scl = po.scl
    dim_thick = po.dim_thick
    if method == 'super-resolution':
        dim = dim_yx
        mat = torch.linalg.solve(mat_y, rigid.mm(mat_yx))  # mat_y\rigid*mat_yx
    elif method == 'denoising':
        dim = dim_x
        mat = torch.linalg.solve(mat_y, rigid.mm(mat_x))  # mat_y\rigid*mat_x
    # Smoothing operator
    if len(ratio) == 3:  # 3D
        conv = lambda x: F.conv3d(x, smo_ker, stride=ratio)
        conv_transpose = lambda x: F.conv_transpose3d(x, smo_ker, stride=ratio)
    else:  # 2D
        conv = lambda x: F.conv2d(x, smo_ker, stride=ratio)
        conv_transpose = lambda x: F.conv_transpose2d(x, smo_ker, stride=ratio)
    # Get grid
    grid = affine_grid(mat.type(dat.dtype), dim, jitter=False)[None, ...]
    # Apply projection
    if method == 'super-resolution':
        extrapolate = False
        if operator == 'A':
            dat = grid_pull(dat,
                            grid,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
            dat = conv(dat)
            if scl != 0:
                dat = _apply_scaling(dat, scl, dim_thick)
        elif operator == 'At':
            if scl != 0:
                dat = _apply_scaling(dat, scl, dim_thick)
            dat = conv_transpose(dat)
            dat = grid_push(dat,
                            grid,
                            shape=dim_y,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
        elif operator == 'AtA':
            dat = grid_pull(dat,
                            grid,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
            dat = conv(dat)
            if scl != 0:
                dat = _apply_scaling(dat, 2 * scl, dim_thick)
            dat = conv_transpose(dat)
            dat = grid_push(dat,
                            grid,
                            shape=dim_y,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
    elif method == 'denoising':
        extrapolate = False
        if operator == 'A':
            dat = grid_pull(dat,
                            grid,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
        elif operator == 'At':
            dat = grid_push(dat,
                            grid,
                            shape=dim_y,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
        elif operator == 'AtA':
            dat = grid_pull(dat,
                            grid,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
            dat = grid_push(dat,
                            grid,
                            shape=dim_y,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)

    return dat
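
Since At is intended to be the adjoint of A, the pair can be validated with
the usual dot-product test. A sketch only: the helper below is hypothetical,
and `po` is whatever projection object the surrounding code builds, with
`po.dim_y`/`po.dim_x` the high-/low-res shapes.

import torch

def _adjoint_test(po, method='super-resolution'):
    # hypothetical check: <A x, y> should equal <x, At y> up to float error
    x = torch.randn(1, 1, *po.dim_y)   # high-res image (input of A)
    y = torch.randn(1, 1, *po.dim_x)   # low-res image (input of At)
    Ax = _proj_apply('A', x, po, method=method)
    Aty = _proj_apply('At', y, po, method=method)
    return torch.allclose((Ax * y).sum(), (x * Aty).sum(), atol=1e-3)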