Example #1
0
def pull1d(img, grid, grad=False, **kwargs):
    """Resample an image along its last dimension at the positions in `grid`.

    Parameters
    ----------
    img : (K, *spatial) tensor
        Image to resample.
    grid : (*spatial) tensor or None
        Sampling coordinates along the last dimension. If None, no
        resampling is performed.
    grad : bool, default=False
        Also sample the derivative along the last dimension.
    **kwargs
        Forwarded to `grid_pull` / `grid_grad` (and `bound` to `diff1d`).
        `extrapolate` defaults to True and `bound` to 'dft'.

    Returns
    -------
    warped_img : (K, *spatial) tensor
    warped_grad : (K, *spatial) tensor, if `grad`

    Notes
    -----
    NOTE(review): the return arity is inconsistent. With ``grid=None`` a
    2-tuple ``(img, grad_or_None)`` is always returned, whereas with a
    grid and ``grad=False`` only ``warped_img`` is returned. Confirm which
    convention callers rely on before unifying.
    """
    if grid is None:
        if not grad:
            return img, None
        # No transform: the derivative is a centered finite difference
        # along the last axis.
        return img, diff1d(img, dim=-1, bound=kwargs.get('bound', 'dft'),
                           side='c')

    opts = {'extrapolate': True, 'bound': 'dft', **kwargs}
    # The image gets a singleton axis at -2 and the grid a trailing axis of
    # size 1 (one coordinate per sample), so the generic samplers see a
    # one-dimensional sampling problem.
    img1d = img.unsqueeze(-2)
    grid1d = grid.unsqueeze(-1)
    warped = grid_pull(img1d, grid1d, **opts).squeeze(-2)
    if not grad:
        return warped
    sampled_grad = grid_grad(img1d, grid1d, **opts).squeeze(-1).squeeze(-2)
    return warped, sampled_grad
Example #2
0
def smart_grad(tensor, grid, **opt):
    """Sample the spatial gradients of a volume on a grid.

    If `grid` is None, an identity grid built from the last three
    dimensions of `tensor` is used, i.e. gradients are sampled at the
    input's own voxel positions. No batch dimension is added or removed.

    Parameters
    ----------
    tensor : (channels, *input_shape) tensor
        Input volume. When `grid` is None, its last three dimensions are
        taken as the spatial shape, so this path assumes a 3D volume.
    grid : (*output_shape, D) tensor or None
        Sampling grid.
    **opt
        Forwarded to `spatial.grid_grad` (e.g. `bound`, `extrapolate`,
        `interpolation`).

    Returns
    -------
    grad : (channels, *output_shape, D) tensor
        Sampled gradients, one component per grid dimension.

    """
    if grid is None:
        # NOTE(review): hard-coded to 3 spatial dimensions; a 2D input
        # (channels, X, Y) would get a grid over (channels, X, Y).
        grid = spatial.identity_grid(tensor.shape[-3:],
                                     dtype=tensor.dtype,
                                     device=tensor.device)
    return spatial.grid_grad(tensor, grid, **opt)
Example #3
0
def _deform_1d(img, disp, grad=False):
    """Warp `img` with a one-dimensional displacement field.

    `disp` is a displacement: the identity grid is added to it (after a
    trailing coordinate axis of size 1 is appended) before sampling.
    Sampling uses ``bound=BND`` and ``extrapolate=True``.

    Returns
    -------
    warped : tensor
        Resampled image, with axis 0 restored to the front.
    grad : tensor or None
        Sampled derivative if `grad` is True, otherwise None.
    """
    sample_opts = dict(bound=BND, extrapolate=True)
    # Axis 0 is moved to position -2 for sampling and moved back afterwards.
    moved = utils.movedim(img, 0, -2)
    grid = spatial.add_identity_grid(disp.unsqueeze(-1))
    warped = spatial.grid_pull(moved, grid, **sample_opts)
    warped = utils.movedim(warped, -2, 0)
    if grad:
        derivative = spatial.grid_grad(moved, grid, **sample_opts)
        return warped, utils.movedim(derivative.squeeze(-1), -2, 0)
    return warped, None
Example #4
0
    def grad_position(self, t):
        """Derivative of the evaluated position with respect to time.

        Parameters
        ----------
        t : tensor
            Times in [0, 1] (values outside are clamped).

        Returns
        -------
        tensor of shape (*t.shape, D)
        """
        batch_shape = t.shape
        scale = len(self.waypoints) - 1
        # Map (0, 1) onto the waypoint index range (0, n - 1).
        nodes = t.flatten().clamp(0, 1) * scale
        coeffs = self.coeff.T  # [D, K]
        dpos = grid_grad(coeffs, nodes.unsqueeze(-1),  # nodes: [N, 1]
                         interpolation=self.order, bound=self.bound)
        dpos = dpos.squeeze(-1).T  # [N, D]
        dpos = dpos.reshape([*batch_shape, dpos.shape[-1]])
        # Chain rule for the (0, 1) -> (0, n - 1) rescaling of time.
        dpos *= scale
        return dpos
Example #5
0
def _jhistc_backward(g,
                     x,
                     w=None,
                     order=0,
                     bound='replicate',
                     extrapolate=True,
                     gradx=True,
                     gradw=False):
    """Compute derivative of the joint histogram.

    The input must already be a soft mapping to bins indices.

    Parameters
    ----------
    g : (b, bins, bins) tensor
    x : (b, n, 2) tensor
    w : ([b], n) tensor, optional
    order : int, default=0
    bound : {'zero', 'nearest'}, default='nearest'
    extrapolate : bool, default=True
    gradx : bool, default=True
    gradw : bool, default=False

    Returns
    -------
    gx : (b, n, 2) tensor, if gradx
    gw : ([b], n) tensor, if gradw

    """
    extrapolate = 1 if extrapolate else 2
    opt = dict(interpolation=order, bound=bound, extrapolate=extrapolate)
    x = x.unsqueeze(-3)  # make 2d spatial
    g = g.unsqueeze(-3)  # add channel dimension
    out = []
    if gradx:
        gx = grid_grad(g, x, **opt)
        gx = gx.squeeze(-3).squeeze(-3)
        if w is not None:
            gx *= w.unsqueeze(-1)
        out.append(gx)
    if gradw and w is not None:
        gw = grid_pull(g, x, **opt)
        gw = gw.squeeze(-2).squeeze(-2)  # drop spatial + channel
        out.append(gw)
    elif gradw:
        out.append(None)
    return out[0] if len(out) == 1 else tuple(out)
Example #6
0
def pull1d(img, grid, dim, grad=False, **kwargs):
    """Resample `img` along axis `dim` at the coordinates in `grid`.

    Parameters
    ----------
    img : tensor
        Image to resample.
    grid : tensor or None
        Sampling coordinates along `dim`. If None, no resampling is done.
    dim : int
        Axis along which to resample.
    grad : bool, default=False
        Also sample the derivative along `dim`.
    **kwargs
        Forwarded to the samplers. `extrapolate` defaults to True and
        `bound` to 'dft'.

    Returns
    -------
    warped : tensor
    grad : tensor or None
        Sampled derivative if `grad` is True, else None.
    """
    if grid is None:
        if not grad:
            return img, None
        bound = kwargs.get('bound', 'dft')
        return img, spatial.diff1d(img, dim=dim, bound=bound, side='c')

    opts = {'extrapolate': True, 'bound': 'dft', **kwargs}
    move = core.utils.movedim
    # Bring `dim` last, then add a singleton axis to the image and a
    # size-1 coordinate axis to the grid so the samplers work in 1D.
    src = move(img, dim, -1).unsqueeze(-2)
    coords = move(grid, dim, -1).unsqueeze(-1)
    warped = spatial.grid_pull(src, coords, **opts).squeeze(-2)
    warped = move(warped, -1, dim)
    if not grad:
        return warped, None
    deriv = spatial.grid_grad(src, coords, **opts).squeeze(-1).squeeze(-2)
    return warped, move(deriv, -1, dim)
Example #7
0
    def eval_grad_position(self, t):
        """Evaluate the position and its derivative with respect to time.

        Parameters
        ----------
        t : tensor
            Times in [0, 1] (values outside are clamped).

        Returns
        -------
        position : tensor of shape (*t.shape, D)
        gradient : tensor of shape (*t.shape, D)
        """
        batch_shape = t.shape
        scale = len(self.waypoints) - 1
        # Map (0, 1) onto the waypoint index range (0, n - 1).
        nodes = (t.flatten().clamp(0, 1) * scale).unsqueeze(-1)  # [N, 1]
        coeffs = self.coeff.T  # [D, K]
        opts = dict(interpolation=self.order, bound=self.bound)
        pos = grid_pull(coeffs, nodes, **opts).T  # [N, D]
        dpos = grid_grad(coeffs, nodes, **opts).squeeze(-1).T  # [N, D]

        pos = pos.reshape([*batch_shape, pos.shape[-1]])
        dpos = dpos.reshape([*batch_shape, dpos.shape[-1]])
        # Chain rule for the (0, 1) -> (0, n - 1) rescaling of time.
        dpos *= scale
        return pos, dpos
Example #8
0
    def pull_grad(self, grid, rotate=False):
        """Sample the image gradients at dense coordinates.

        Parameters
        ----------
        grid : (*spatial, dim) tensor or None
            Dense transformation field. If None, `self.grad()` is returned.
        rotate : bool, default=False
            Multiply the sampled gradients by the transposed Jacobian of
            `grid`.

        Returns
        -------
        grad : ([C], *spatial, dim) tensor

        """
        if grid is None:
            return self.grad()
        sampled = spatial.grid_grad(self.dat, grid, bound=self.bound,
                                    extrapolate=self.extrapolate)
        if not rotate:
            return sampled
        jac_t = spatial.grid_jacobian(grid).transpose(-1, -2)
        return linalg.matvec(jac_t, sampled)
Example #9
0
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """Evaluate the affine registration objective and its derivatives.

        The moving image is resampled into the fixed space using the
        affine matrix built from the Lie parameters `logaff`. The image
        loss is evaluated and, optionally, its derivatives are pulled back
        to the Lie parameters by the chain rule.

        Parameters
        ----------
        logaff : (..., nb) tensor
            Lie parameters (coordinates in `self.basis`).
        grad : bool
            Whether to compute and return the gradient wrt `logaff`.
        hess : bool
            Whether to compute and return the Hessian wrt `logaff`.
        gradmov : bool
            Whether to compute and return the gradient wrt `moving`.
        hessmov : bool
            Whether to compute and return the Hessian wrt `moving`.
        in_line_search : bool
            If True, plotting and verbose printing are skipped. `n_iter`
            is only incremented in the printing branch, so it is not
            incremented either.

        Returns
        -------
        ll : float
            Loss value (objective to minimize), taken from `llx.item()`.
        g : (..., nb) tensor, optional
            Gradient wrt Lie parameters.
        h : tensor, optional
            Diagonal approximation of the Hessian wrt Lie parameters: the
            absolute row sums of gaff @ H @ gaff^T placed on a diagonal.
        gm : (..., *spatial, dim) tensor, optional
            Gradient wrt moving.
        hm : (..., *spatial, ?) tensor, optional
            Hessian wrt moving.

        Only requested outputs are returned. A bare `ll` is returned when
        nothing else is requested.

        NOTE(review): `gradmov`/`hessmov` are only computed inside the
        `grad`/`hess` branches. Requesting `gradmov` without `grad` or
        `hess` (or `hessmov` without `hess`) leaves the flag as the
        literal `True`, which is then appended to the outputs.
        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        # Sampling options shared by every pull / push / grad call below.
        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        # Plot every `logplot` iterations (about 20 plots over max_iter).
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        # Build the basis lazily (once) if it was given as a name/spec.
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        # Affine matrix, plus its derivatives wrt `logaff` (no autograd).
        aff = linalg.expm(logaff, self.basis)
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        # Compose with the orientation matrices:
        # aff <- affine_moving \ (aff @ affine_fixed)
        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        # Identity grid over the fixed image, cached after the first call.
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        # (the boolean flags `grad`/`hess`/`gradmov`/`hessmov` are
        # overwritten by the corresponding tensors when computed)
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            # Flatten the (dim, dim+1) matrix derivatives and contract
            # them with the derivatives of the matrix wrt `logaff`.
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                # Keep a diagonal matrix of absolute row sums.
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        # Only outputs whose flag is not the literal `False` are returned.
        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Example #10
0
    def backward(self, x, g, min=None, max=None, hess=False, mask=None):
        """Backpropagate a histogram derivative to the input samples.

        Parameters
        ----------
        x : (..., N, 2) tensor
            Input multidimensional vector
        g : (..., B, B) tensor
            Gradient with respect to the histogram
        min : (..., 2) tensor, optional
            Lower bound of the histogram range. Presumably defaulted by
            `self._prepare` when None; confirm.
        max : (..., 2) tensor, optional
            Upper bound of the histogram range (same handling as `min`).
        hess : bool, default=False
            If True, `g` is treated as a Hessian-like quantity. It is first
            multiplied by the adjoint of the sampling gradient applied to
            ones, then its spatial gradient is sampled at `x`, and the
            affine scaling factor is squared.
        mask : (..., N) tensor, optional
            Multiplies the output (broadcast over the last axis).

        Returns
        -------
        g : (..., N, 2) tensor
            Gradient (or Hessian-like term if `hess`) with respect to x

        """
        if self.fwhm:
            g = spatial.smooth(g, fwhm=self.fwhm, bound=self.bound, dim=2)

        shape = x.shape
        x, min, max = self._prepare(x, min, max)
        min = min.unsqueeze(-2)
        max = max.unsqueeze(-2)
        g = g.reshape([-1, *g.shape[-2:]])

        # NOTE(review): a falsy `self.extrapolate` is mapped to 2,
        # presumably a distinct extrapolation mode of the sampler.
        extrapolate = self.extrapolate or 2
        if not hess:
            g = spatial.grid_grad(g[:, None], x[:, None], self.order,
                                  self.bound, extrapolate)
            g = g[:, 0].reshape(shape)
        else:
            # 1) Absolute value of adjoint of gradient
            # we want shapes
            #   o : [batch=1, channel=1, spatial=[1, vox], dim=2]
            #   g : [batch=1, channel=1, spatial=[B(mov), B(fix)]]
            #   x : [batch=1, spatial=[1, vox], dim=2]
            #    -> [batch=1, channel=1, spatial=[B(mov), B(fix)]]
            order = _spatial.inter_to_nitorch([self.order], True)
            bound = _spatial.bound_to_nitorch([self.bound], True)
            o = torch.ones_like(x)
            g.requires_grad_()  # triggers push
            o, = _spatial.grid_grad_backward(o[:, None, None], g[:, None],
                                             x[:, None], bound, order,
                                             extrapolate)
            g.requires_grad_(False)
            g *= o[:, 0]
            # 2) Absolute value of gradient
            #   g : [batch=1, channel=1, spatial=[B(mov), B(fix)]]
            #   x : [batch=1, spatial=[1, vox], dim=2]
            #    -> [batch=1, channel=1, spatial=[1, vox], 2]
            g = _spatial.grid_grad(g[:, None], x[:, None], bound, order,
                                   extrapolate)
            g = g.reshape(shape)

        # adjoint of affine function (bins = n * (x - min) / (max - min))
        nn = torch.as_tensor(self.n, dtype=x.dtype, device=x.device)
        factor = nn / (max - min)
        if hess:
            factor = factor.square_()
        g = g.mul_(factor)
        if mask is not None:
            g *= mask[..., None]

        return g
Example #11
0
def fit_affine_tpm(dat,
                   tpm,
                   affine=None,
                   affine_tpm=None,
                   weights=None,
                   basis='affine',
                   fwhm=None,
                   joint=False,
                   prm=None,
                   max_iter_gn=100,
                   max_iter_em=32,
                   max_line_search=6,
                   progressive=False,
                   verbose=1):
    """Affine registration of images to a tissue probability map (TPM).

    Gauss-Newton optimisation of the objective returned by `em_prior`,
    with a backtracking (halving) line search.

    Parameters
    ----------
    dat : (B, J|1, *spatial) tensor
    tpm : (B|1, K, *spatial) tensor
    affine : (4, 4) tensor, optional
        Defaults to `spatial.affine_default` of the data shape.
    affine_tpm : (4, 4) tensor, optional
        Defaults to `spatial.affine_default` of the TPM shape.
    weights : (B, 1, *spatial) tensor, optional
    basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
    fwhm : float, default=J/32
    joint : bool, default=False
        Fit a single set of parameters shared across the batch.
    prm : (B|1, F) tensor, optional
        Initial parameters. Defaults to zeros.
    max_iter_gn : int, default=100
    max_iter_em : int, default=32
    max_line_search : int, default=6
    progressive : bool, default=False
        First fit the next simpler basis and use its parameters to
        initialise this one.
    verbose : int, default=1

    Returns
    -------
    mi : (B,) tensor
    aff : (B, 4, 4) tensor
    prm : (B, F) tensor

    """
    dim = dat.dim() - 2

    # ------------------------------------------------------------------
    #       RECURSIVE PROGRESSIVE FIT
    # ------------------------------------------------------------------
    if progressive:
        nb_se = dim * (dim + 1) // 2
        nb_aff = dim * (dim + 1)
        basis_recursion = {'Aff+': 'CSO', 'CSO': 'SE', 'SE': 'T'}
        basis_nb_feat = {'Aff+': nb_aff, 'CSO': nb_se + 1, 'SE': nb_se}
        basis = convert_basis(basis)
        next_basis = basis_recursion.get(basis, None)
        if next_basis:
            *_, prm = fit_affine_tpm(dat,
                                     tpm,
                                     affine,
                                     affine_tpm,
                                     weights,
                                     basis=next_basis,
                                     fwhm=fwhm,
                                     joint=joint,
                                     prm=prm,
                                     max_iter_gn=max_iter_gn,
                                     max_iter_em=max_iter_em,
                                     max_line_search=max_line_search,
                                     verbose=verbose)
            B = len(dat)
            F = basis_nb_feat[basis]
            prm0 = prm
            prm = prm0.new_zeros([1 if joint else B, F])
            if basis == 'SE':
                # Copy the `dim` parameters fitted with basis 'T'.
                prm[:, :dim] = prm0[:, :dim]
            else:
                # Copy the parameters shared with the 'SE' basis.
                prm[:, :nb_se] = prm0[:, :nb_se]
                if basis == 'Aff+':
                    # Spread the extra CSO parameter (column nb_se) over the
                    # next `dim` parameters. Index with `None` to keep a
                    # (batch, 1) column: a bare (batch,) vector would be
                    # broadcast along the `dim` axis instead of the batch
                    # axis (wrong values, or an error when batch != dim).
                    prm[:, nb_se:nb_se + dim] = \
                        prm0[:, nb_se, None] * (dim**(-0.5))

    basis_name = basis

    # ------------------------------------------------------------------
    #       PREPARE
    # ------------------------------------------------------------------

    B = len(dat)
    if affine is None:
        affine = spatial.affine_default(dat.shape[-dim:])
    if affine_tpm is None:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:])
    affine = affine.to(**utils.backend(tpm))
    affine_tpm = affine_tpm.to(**utils.backend(tpm))
    shape = dat.shape[-dim:]

    tpm = tpm.to(dat.device)
    basis = make_basis(basis, dim, **utils.backend(tpm))
    F = len(basis)

    if prm is None:
        prm = tpm.new_zeros([1 if joint else B, F])
    aff, gaff = linalg._expm(prm, basis, grad_X=True)

    em_opt = dict(fwhm=fwhm,
                  max_iter=max_iter_em,
                  weights=weights,
                  verbose=verbose - 2)
    drv_opt = dict(weights=weights)
    pull_opt = dict(bound='replicate', extrapolate=True)

    # ------------------------------------------------------------------
    #       OPTIMIZE
    # ------------------------------------------------------------------
    prior = None
    mi = torch.as_tensor(-float('inf'))
    delta = torch.zeros_like(prm)
    for n_iter in range(max_iter_gn):

        # --------------------------------------------------------------
        #       LINE SEARCH
        # --------------------------------------------------------------
        # Snapshot the accepted iterate (including `aff`) for rollback.
        prior0, prm0, mi0, aff0 = prior, prm, mi, aff
        armijo = 1
        success = False
        for n_ls in range(max_line_search):

            # --- take a step ------------------------------------------
            prm = prm0 - armijo * delta

            # --- build transformation field ---------------------------
            aff, gaff = linalg._expm(prm, basis, grad_X=True)
            phi = lmdiv(affine_tpm, mm(aff, affine))
            phi = spatial.affine_grid(phi, shape)

            # --- warp TPM ---------------------------------------------
            mov = spatial.grid_pull(tpm, phi, **pull_opt)

            # --- mutual info ------------------------------------------
            mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
            mi = mi / Nm

            success = mi.sum() > mi0.sum()
            if verbose >= 2:
                end = '\n' if verbose >= 3 else '\r'
                happy = ':D' if success else ':('
                print(f'(search) | {n_ls:02d} | {mi.mean():12.6g} | {happy}',
                      end=end)
            if success:
                break
            armijo *= 0.5

        # --------------------------------------------------------------
        #       DID IT WORK?
        # --------------------------------------------------------------

        if not success:
            # Roll back everything, including `aff`: otherwise it would
            # hold the transform of the last *rejected* step and not
            # match the returned `prm`.
            prior, prm, mi, aff = prior0, prm0, mi0, aff0
            break

        space = ' ' * max(0, 6 - len(basis_name))
        if verbose >= 1:
            end = '\n' if verbose >= 2 else '\r'
            print(
                f'({basis_name[:6]}){space} | {n_iter:02d} | {mi.mean():12.6g}',
                end=end)

        if mi.mean() - mi0.mean() < 1e-5:
            break

        # --------------------------------------------------------------
        #       GAUSS-NEWTON
        # --------------------------------------------------------------

        # --- derivatives ----------------------------------------------
        g, h = derivatives_intensity(mov, dat, prior, **drv_opt)

        # --- chain rule -----------------------------------------------
        gmov = spatial.grid_grad(tpm, phi, **pull_opt)
        if joint and len(mov) == 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)
        else:
            gmov = gmov.expand([B, *gmov.shape[1:]])
        gaff = lmdiv(affine_tpm, mm(gaff, affine))
        g, h = chain_rule(g, h, gmov, gaff, maj=False)
        del gmov

        if joint and len(g) > 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)

        # --- Gauss-Newton ---------------------------------------------
        delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)

    if verbose == 1:
        print('')
    return mi, aff, prm
Example #12
0
    def __call__(self,
                 vel,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """Evaluate the nonlinear registration objective and derivatives.

        The moving image is resampled at `self.id + vel`, the image loss
        is evaluated, and a regularisation term
        0.5 * <vel, regulariser_grid(vel)> is added.

        Parameters
        ----------
        vel : (..., *spatial, dim) tensor
            Displacement.
        grad : bool
            Whether to compute and return the gradient wrt `vel`.
        hess : bool
            Whether to compute and return the Hessian wrt `vel`.
        gradmov : bool
            Whether to compute and return the gradient wrt `moving`.
        hessmov : bool
            Whether to compute and return the Hessian wrt `moving`.
        in_line_search : bool
            If True, plotting and verbose printing are skipped. `n_iter`
            is only incremented in the printing branch, so it is not
            incremented either.

        Returns
        -------
        ll : float
            Loss value (objective to minimize): data term + regulariser.
        g : (..., *spatial, dim) tensor, optional
            Gradient wrt velocity (includes the regulariser gradient).
        h : (..., *spatial, ?) tensor, optional
            Hessian wrt velocity. Only `grad` receives the regulariser
            term in this method; `h` is the data term only.
        gm : (..., *spatial, dim) tensor, optional
            Gradient wrt moving.
        hm : (..., *spatial, ?) tensor, optional
            Hessian wrt moving.

        Only requested outputs are returned. A bare `ll` is returned when
        nothing else is requested.

        NOTE(review): requesting `gradmov` without `grad`, `hess` or
        `hessmov` leaves the flag as the literal `True`, which is then
        appended to the outputs. Requesting `hessmov` forces `grad` and
        `hess` to be computed, so they are returned as well.
        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        dim = vel.shape[-1]
        # Sampling options shared by every pull / push / grad call below.
        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        # in_line_search = not grad and not hess
        # Plot every `logplot` iterations (about 20 plots over max_iter).
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # forward
        # Identity grid over the spatial shape of `vel`, cached on first use.
        if self.id is None:
            self.id = spatial.identity_grid(vel.shape[-dim - 1:-1],
                                            **utils.backend(vel))
        grid = self.id + vel
        warped = spatial.grid_pull(self.moving, grid, **pullopt)

        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        vel,
                        cat=iscat,
                        dim=dim)

        # gradient/Hessian of the log-likelihood in observed space
        # (the boolean flags are overwritten by tensors when computed)
        if not grad and not hess and not hessmov:
            llx = self.loss.loss(warped, self.fixed)
        elif not hess and not hessmov:
            llx, grad = self.loss.loss_grad(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, self.fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            if grad is not False:
                grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)

        # add regularization term
        vgrad = spatial.regulariser_grid(vel, **self.prm, kernel=False)
        llv = 0.5 * (vel * vgrad).sum()
        if grad is not False:
            grad += vgrad
        del vgrad

        # print objective
        llx = llx.item()
        llv = llv.item()
        ll = llx + llv
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g}',
                    end='\r')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\r')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        # Only outputs whose flag is not the literal `False` are returned.
        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Example #13
0
def _rigid_match(dat_x,
                 dat_y,
                 po,
                 tau,
                 rigid,
                 sett,
                 CtC=None,
                 diff=False,
                 verbose=0):
    """ Computes the rigid matching term, and its gradient and Hessian (if requested).

    Args:
        dat_x (torch.tensor): Observed data (X0, Y0, Z0).
        dat_y (torch.tensor): Reconstructed data (X1, Y1, Z1).
        po (ProjOp): Projection operator.
        tau (torch.tensor): Noice precision.
        CtC (torch.tensor, optional): CtC(ones), used for super-res gradient calculation.
            Defaults to None.
        rigid (torch.tensor): Rigid transformation matrix (4, 4).
        diff (bool, optional): Compute derivatives, defaults to False.
        verbose (bool, optional): Show registration results, defaults to 0.
            0: No verbose
            1: Print convergence info to console
            2: Plot registration results using matplotlib

    Returns:
        ll (torch.tensor): Log-likelihood.
        gr (torch.tensor): Gradient (dim_x, 3).
        Hes (torch.tensor): Hessian (dim_x, 6).

    """
    # Projection info
    mat_x = po.mat_x
    mat_y = po.mat_y
    mat_yx = po.mat_yx
    dim_x = po.dim_x
    dim_yx = po.dim_yx
    ratio = po.ratio
    smo_ker = po.smo_ker
    dim_thick = po.dim_thick
    scl = po.scl

    # Init output
    ll = None
    gr = None
    Hes = None

    if sett.method == 'super-resolution':
        extrapolate = False
        dim = dim_yx
        mat = mat_yx
    elif sett.method == 'denoising':
        extrapolate = False
        dim = dim_x
        mat = mat_x

    # Get grid
    mat = rigid.mm(mat).solve(mat_y)[0]  # mat_y\rigid*mat
    grid = affine_grid(mat.type(torch.float32), dim, jitter=False)

    # Warp y and compute spatial derivatives
    dat_yx = grid_pull(dat_y,
                       grid[None, ...],
                       bound=sett.bound,
                       extrapolate=extrapolate,
                       interpolation=sett.interpolation)[0, 0, ...]
    if sett.method == 'super-resolution':
        dat_yx = F.conv3d(dat_yx[None, None, ...], smo_ker, stride=ratio)[0, 0,
                                                                          ...]
        if scl != 0:
            dat_yx = _apply_scaling(dat_yx, scl, dim_thick)
    if diff:
        gr = grid_grad(dat_y,
                       grid[None, ...],
                       bound=sett.bound,
                       extrapolate=extrapolate,
                       interpolation=sett.interpolation)[0, 0, ...]

    if verbose >= 2:  # Show images
        show_slices(torch.stack((dat_x, dat_yx, (dat_x - dat_yx)**2), 3),
                    fig_num=666,
                    colorbar=False,
                    flip=False)

    # Double and mask
    msk = (dat_x != 0)

    # Compute matching term
    ll = 0.5 * tau * torch.sum(
        (dat_x[msk] - dat_yx[msk])**2, dtype=torch.float64)

    if diff:
        # Difference
        diff = dat_yx - dat_x
        msk = msk & (dat_yx != 0)
        diff[~msk] = 0
        # Hessian
        Hes = torch.zeros(dim + (6, ),
                          device=dat_x.device,
                          dtype=torch.float32)
        Hes[:, :, :, 0] = gr[:, :, :, 0] * gr[:, :, :, 0]
        Hes[:, :, :, 1] = gr[:, :, :, 1] * gr[:, :, :, 1]
        Hes[:, :, :, 2] = gr[:, :, :, 2] * gr[:, :, :, 2]
        Hes[:, :, :, 3] = gr[:, :, :, 0] * gr[:, :, :, 1]
        Hes[:, :, :, 4] = gr[:, :, :, 0] * gr[:, :, :, 2]
        Hes[:, :, :, 5] = gr[:, :, :, 1] * gr[:, :, :, 2]
        if sett.method == 'super-resolution':
            Hes *= CtC[..., None]
            diff = F.conv_transpose3d(diff[None, None, ...],
                                      smo_ker,
                                      stride=ratio)[0, 0, ...]
        # Gradient
        gr *= diff[..., None]

    return ll, gr, Hes