Exemple #1
0
def _atlas_align(dat, mat, rigid=True, pth_atlas=None):
    r"""Affinely align image to some atlas space.

    Parameters
    ----------
    dat : [N, ...], tensor_like
        List of image volumes. The caller's list is not modified.
    mat : [N, ...], tensor_like
        List of affine matrices. The caller's list is not modified.
    rigid : bool, default=True
        Do rigid alignment, else does rigid+isotropic scaling.
    pth_atlas : str, optional
        Path to atlas image to match to. Uses Brain T1w atlas by default.

    Returns
    -------
    mat_a : (N, 4, 4) tensor_like
        Transformation aligning to MNI space as M_mni\M_mov.
    mat_mni : (4, 4), tensor_like
        Affine matrix of MNI image.
    dim_mni : (3,), tuple, list, tensor_like
        Image dimensions of MNI image.
    mat_cso : (N, 4, 4) tensor_like
        CSO transformation.

    """
    if pth_atlas is None:
        # Get path to nitorch's T1w intensity atlas
        pth_atlas = fetch_data('atlas_t1')
    # Get number of input images
    N = len(dat)
    # Load the atlas on the same device as the input images
    dat_mni, mat_mni, _ = _format_input(pth_atlas, device=dat[0].device,
                                        rand=True, cutoff=(0.0005, 0.9995))
    # Append the atlas at the end of *copies* of the inputs, so that the
    # caller's lists are left untouched.
    dat = list(dat) + [dat_mni[0]]
    mat = list(mat) + [mat_mni[0]]
    # Align to MNI atlas (the atlas, at index N, is the fixed image).
    group = 'CSO'
    _, mat_mni, dim_mni, q = _affine_align(
        dat, mat, group=group, samp=(3, 1.5), cost_fun='nmi', fix=N,
        verbose=False, mean_space=False)
    # Drop the atlas' own parameters
    q = q[:N, ...]
    # Get matrix representation
    mat_cso = expm(q, affine_basis(group=group))
    if rigid:
        # Extract only rigid part
        group = 'SE'
        q = q[..., :6]
    # Get matrix representation
    mat_a = expm(q, affine_basis(group=group))

    return mat_a, mat_mni, dim_mni, mat_cso
Exemple #2
0
def _init_reg(x, sett):
    """Initialise registration.

    Optionally co-registers every observation to a fixed one (index
    ``fix``) and/or aligns that fixed observation to an atlas. It then
    writes the resulting orientation matrices back into ``x`` and
    initialises the parameters used by the unified rigid registration.

    Parameters
    ----------
    x : list of list
        Observations indexed as ``x[c][n]`` (channel ``c``, observation
        ``n``). Each element must expose ``.dat`` and ``.mat``; its ``.mat``
        and ``.rigid_q`` attributes are overwritten in place.
    sett : settings object
        Reads ``device``, ``do_coreg``, ``do_atlas_align`` and
        ``atlas_rigid``; ``rigid_basis`` is set on it.

    Returns
    -------
    x : list of list
        The input object, modified in place.
    sett : settings object
        The input object, with ``rigid_basis`` set.

    """
    # Flatten observations in channel-major order; the alignment routines
    # receive (and return) images in this same order.
    obs = [xn for xc in x for xn in xc]
    # Total number of observations
    N = len(obs)
    # Set rigid affine basis
    sett.rigid_basis = affine_basis(
        group='SE', device=sett.device, dtype=torch.float64)
    fix = 0  # Fixed image index

    # Make input for nitorch affine align: [data, orientation matrix] pairs
    imgs = [[xn.dat, xn.mat] for xn in obs]

    if sett.do_coreg and N > 1:
        # Align images, pairwise, to fixed image (fix)
        t0 = _print_info('init-reg', sett, 'co', 'begin', N)
        mat_a = affine_align(imgs, fix=fix, device=sett.device)[1]
        # Apply coreg transform: mat <- mat_a[i] \ mat.
        # `Tensor.solve` is removed from recent PyTorch; `torch.linalg.solve`
        # takes (A, B) in the opposite order and returns only the solution.
        for i, img in enumerate(imgs):
            img[1] = torch.linalg.solve(mat_a[i, ...], img[1])
        _print_info('init-reg', sett, 'co', 'finished', N, t0)

    if sett.do_atlas_align:
        # Align fixed image to atlas space, and apply transformation to
        # all images
        t0 = _print_info('init-reg', sett, 'atlas', 'begin', N)
        imgs1 = [imgs[fix]]
        _, mat_a, _, _ = atlas_align(imgs1, rigid=sett.atlas_rigid,
                                     device=sett.device)
        _print_info('init-reg', sett, 'atlas', 'finished', N, t0)
        # Apply atlas registration transform: mat <- mat_a \ mat
        for img in imgs:
            img[1] = torch.linalg.solve(mat_a, img[1])

    # Modify image affine (label uses the same as the image, so no need to
    # modify that one)
    for xn, img in zip(obs, imgs):
        xn.mat = img[1]

    # Init rigid parameters (for unified rigid registration)
    for xn in obs:
        xn.rigid_q = torch.zeros(sett.rigid_basis.shape[0],
                                 device=sett.device, dtype=torch.float64)

    return x, sett
Exemple #3
0
 def free(self):
     """Free the next batch/ladder of parameters.

     Each call unfreezes the next stage of the parameter ladder:
     translations, then rotations, then isotropic scaling, then the full
     affine. The leading, currently free parameters are held in
     ``self.optdat`` (a ``torch.nn.Parameter``) and the full parameter
     vector in ``self.dat``. ``self.basis`` is switched to the matching
     affine basis ('T', 'SE', 'CSO', 'Aff+'). Does nothing when
     ``self.freeable()`` is False.
     """
     if not self.freeable():
         return
     has_optdat = hasattr(self, 'optdat')
     nb_prm = len(self.optdat) if has_optdat else 0  # currently free
     nb_t = self.dim                        # translations
     nb_r = self.dim * (self.dim - 1) // 2  # rotations
     nb_z = self.dim                        # zooms
     # Rebuild a detached full parameter vector, taking the leading
     # (currently optimised) values from optdat.
     self.dat = self.dat.detach()
     if has_optdat:
         self.optdat = self.optdat.detach()
         self.dat = torch.cat([self.optdat, self.dat[nb_prm:]])
     if nb_prm == 0:
         print('Free translations')
         self.optdat = torch.nn.Parameter(self.dat[:nb_t],
                                          requires_grad=True)
         self.dat = torch.cat([self.optdat, self.dat[nb_t:]])
         self.basis = spatial.affine_basis('T', self.dim)
     elif nb_prm == nb_t:
         print('Free rotations')
         self.optdat = torch.nn.Parameter(self.dat[:nb_t + nb_r],
                                          requires_grad=True)
         self.dat = torch.cat([self.optdat, self.dat[nb_t + nb_r:]])
         self.basis = spatial.affine_basis('SE', self.dim)
     elif nb_prm == nb_t + nb_r:
         print('Free isotropic scaling')
         self.optdat = torch.nn.Parameter(self.dat[:nb_t + nb_r + 1],
                                          requires_grad=True)
         self.dat = torch.cat([self.optdat, self.dat[nb_t + nb_r + 1:]])
         self.basis = spatial.affine_basis('CSO', self.dim)
     elif nb_prm == nb_t + nb_r + 1:
         print('Free full affine')
         # Spread the isotropic-scaling parameter over nb_z consecutive
         # slots (presumably the per-axis zooms of the 'Aff+' basis --
         # confirm the layout): rescale it by 1/sqrt(nb_z) and copy it into
         # the next nb_z - 1 entries. Looping over nb_z (instead of hard-
         # coding two copies) keeps this correct for any self.dim.
         iz = nb_t + nb_r
         self.dat[iz] /= nb_z ** 0.5
         for k in range(1, nb_z):
             self.dat[iz + k] = self.dat[iz]
         self.optdat = torch.nn.Parameter(self.dat, requires_grad=True)
         self.dat = self.optdat
         self.basis = spatial.affine_basis('Aff+', self.dim)
Exemple #4
0
class Translation(Linear):
    """Translation-only transformation model (Lie basis ``'T'``)."""

    name = 'translation'
    basis = spatial.affine_basis('T', 3)
    nb_prm = staticmethod(lambda dim: dim)

    def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
        """Evaluate the matching loss (and optionally its gradient).

        The gradient is obtained with autograd.

        Parameters
        ----------
        logaff : (..., nb) tensor
            Lie parameters.
        grad : bool, default=False
            Whether to compute and return the gradient wrt `logaff`.
        hess : bool, default=False
            Accepted for interface compatibility; ignored (no Hessian is
            computed or returned).
        in_line_search : bool, default=False
            If True, skip plotting, the verbose progress print and the
            iteration counter update.

        Returns
        -------
        ll : () tensor
            Loss value (objective to minimize).
        g : (..., nb) tensor, optional
            Gradient wrt Lie parameters (only if `grad`).

        """
        # select correct gradient mode: autograd is needed iff `grad`, so
        # re-enter under enable_grad / no_grad when the mode does not match.
        if grad:
            logaff.requires_grad_()
            if logaff.grad is not None:
                logaff.grad.zero_()
        if grad and not torch.is_grad_enabled():
            with torch.enable_grad():
                return self(logaff, grad, in_line_search=in_line_search)
        elif not grad and torch.is_grad_enabled():
            with torch.no_grad():
                return self(logaff, grad, in_line_search=in_line_search)

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                   and (self.n_iter - 1) % logplot == 0

        fixed = self.fixed

        # forward: build the moving-voxel <- fixed-voxel affine and warp
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape, **utils.backend(logaff))
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # log-likelihood in observed space
        llx = self.loss.loss(warped, fixed)
        del warped

        # print objective
        lll = llx  # keep the tensor: it is needed for backward()
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [lll]
        if grad is not False:
            lll.backward()
            grad = logaff.grad.clone()
            out.append(grad)
        logaff.requires_grad_(False)
        return tuple(out) if len(out) > 1 else out[0]
def register(fixed=None,
             moving=None,
             dim=None,
             loss='mse',
             basis='CSO',
             optim='ogm',
             max_iter=500,
             lr=1,
             ls=6,
             plot=False,
             klosure=RegisterStep,
             logaff=None,
             verbose=True):
    """Affine registration between two images using Lie groups.

    Parameters
    ----------
    fixed : (..., K, *spatial) tensor, optional
        Fixed image. If either `fixed` or `moving` is None, both are
        replaced by the demo images from `phantoms.demo_register`.
    moving : (..., K, *spatial) tensor, optional
        Moving image
    dim : int, default=`fixed.dim() - 1`
        Number of spatial dimensions
    loss : {'mse', 'cat'} or OptimizationLoss, default='mse'
        'mse': Mean-squared error
        'cat': Categorical cross-entropy
    basis : str, default='CSO'
        Name of the affine Lie group, passed to `spatial.affine_basis`.
    optim : {'gn', 'gd', 'momentum', 'nesterov', 'ogm', 'lbfgs'}, default='ogm'
        'gn'        : Gauss-Newton
        'gd'        : Gradient descent
        'momentum'  : Gradient descent with momentum
        'nesterov'  : Nesterov-accelerated gradient descent
        'ogm'       : Optimized gradient descent (Kim & Fessler)
        'lbfgs'     : Limited-memory BFGS
    max_iter : int, default=500
        Maximum number of Gauss-Newton or Gradient descent iterations
    lr : float, default=1
        Learning rate.
    ls : int, default=6
        Number of line search iterations.
    plot : bool, default=False
        Plot progress
    klosure : callable, default=RegisterStep
        Factory for the objective closure. It is called as
        ``klosure(moving, fixed, loss, basis=..., verbose=..., plot=...,
        max_iter=...)``.
    logaff : tensor, optional
        Initial Lie parameters. Defaults to zeros, one per basis element.
    verbose : bool, default=True
        Print progress.

    Returns
    -------
    logaff : tensor
        Optimized Lie parameters of the affine transform, as returned by
        the optimizer's ``iter`` method.

    """

    # If no inputs provided: demo "circle to square"
    if fixed is None or moving is None:
        fixed, moving = phantoms.demo_register(cat=(loss == 'cat'))

    # init tensors
    fixed, moving = utils.to_max_backend(fixed, moving)
    dim = dim or (fixed.dim() - 1)
    basis = spatial.affine_basis(basis, dim, **utils.backend(fixed))
    if logaff is None:
        # one parameter per basis element
        logaff = torch.zeros(len(basis), **utils.backend(fixed))

    # init optimizer
    optim = regutils.make_iteroptim_affine(optim, lr, ls, max_iter)

    # init loss
    loss = losses.make_loss(loss, dim)

    # optimize
    if verbose:
        print(
            f'{"it":3s} | {"fit":^12s} + {"reg":^12s} = {"obj":^12s} | {"gain":^12s}'
        )
        print('-' * 63)
    closure = klosure(moving,
                      fixed,
                      loss,
                      basis=basis,
                      verbose=verbose,
                      plot=plot,
                      max_iter=optim.max_iter)
    logaff = optim.iter(logaff, closure)
    if verbose:
        print('')
    return logaff
Exemple #8
0
def diffeo(source,
           target,
           group='SE',
           image_loss=None,
           vel_loss=None,
           pull=None,
           preproc=False,
           max_iter=1000,
           device=None,
           origin='center',
           init=None,
           lr=1e-4,
           optim_affine=True,
           scheduler=ReduceLROnPlateau):
    """Jointly optimise an affine and a diffeomorphic transform.

    Parameters
    ----------
    source : path or tensor or (tensor, affine)
        Source (moving) image.
    target : path or tensor or (tensor, affine)
        Target (fixed) image.
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine Lie group to optimise.
    image_loss : callable(moved, target) or float, optional
        Image matching loss. If not callable, ``MutualInfoLoss()`` is used,
        multiplied by this factor (1 if None).
    vel_loss : callable(vel) or float, optional
        Velocity regulariser. If not callable, ``BendingLoss(bound='dft')``
        is used, multiplied by this factor (1 if None).
    pull : dict, optional
        Interpolation options (the dict itself is not modified):
        interpolation : default='linear'
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
    preproc : bool, default=False
        Not used by the current implementation.
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world origin ('native') or about the centre of
        the target field of view ('center').
    init : tensor_like, optional
        Initial affine parameters, reshaped to (batch, nb_prm).
        Defaults to zeros.
    lr : float or (float, float), default=1e-4
        Learning rates of the affine parameters and of the velocity.
    optim_affine : bool, default=True
        Whether to optimise the affine parameters.
    scheduler : class, default=ReduceLROnPlateau
        Learning-rate scheduler, instantiated as
        ``scheduler(optim, cooldown=5)`` and stepped every 10 iterations
        on the average loss. ``None`` disables scheduling.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *spatial, D) tensor
        Initial velocity of the diffeomorphic transform.
        The full warp is `(aff @ aff_src).inv() @ aff_trg @ exp(vel)`
    moved : tensor
        Source image moved to target space.

    """
    # Interpolation options, copied so that the caller's dict is untouched
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    source = rescale(source)
    target = rescale(target)

    # Shift origin to the centre of the target field of view
    if origin == 'center':
        # Only spatial dimensions (batch and channel excluded)
        shift = torch.as_tensor(target.shape[2:], **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *target.shape[2:], dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    def warp(q, vel):
        """Pull `source` into target space through exp(vel) and expm(q)."""
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff, grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare losses. Each default loss gets its own factor variable:
    # closures look names up when called, so a shared name would make the
    # image loss use the velocity factor.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: vel_factor * vel_loss_fn(
            core.utils.last2channel(x))

    # Optimizer: lr[0] for the affine parameters, lr[1] for the velocity
    lr = core.utils.make_list(lr, 2)
    opt_prm = [{'params': velocity, 'lr': lr[1]}]
    if optim_affine:
        opt_prm.insert(0, {'params': parameters})
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):
        optim.zero_grad(set_to_none=True)
        moved = warp(parameters, velocity)
        loss_val = image_loss(moved, target) + vel_loss(velocity)
        loss_val.backward()
        optim.step()
        with torch.no_grad():
            loss_avg += loss_val.detach()

        # Every 10 iterations: step the scheduler on the average loss and
        # report progress
        if n_iter % 10 == 0:
            loss_avg /= 10
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_avg)
                else:
                    scheduler.step()
            print('{:4d} {:12.6f} | lr={:g}'.format(
                n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                  end='\r')
            loss_avg = 0

    print('')
    with torch.no_grad():
        moved = warp(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Undo the centring: with s = shift, [R, t] -> [R, t - s + R s]
            aff[..., :-1, -1] -= shift
            rot_shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += rot_shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return parameters, aff, velocity, moved
Exemple #9
0
class Affine(Linear):
    """Full affine transformation model (Lie basis ``'Aff+'``)."""

    name = 'affine'
    basis = spatial.affine_basis('Aff+', 3)

    @staticmethod
    def nb_prm(dim):
        """Return the number of affine parameters, ``dim * (dim + 1)``."""
        return dim * (dim + 1)
Exemple #10
0
def _mean_space(Mat, Dim, vx=None):
    """Compute a (mean) model space from individual spaces.

    Steps: each subject's orientation matrix is re-ordered (axis
    permutation + flips) to be as close as possible to axial; the
    re-ordered matrices are averaged with ``meanm``; if the average
    contains shears, it is replaced by the closest rigid + zoom matrix
    (Gauss-Newton fit); the voxel size is set; finally the field of view
    is chosen to contain the corners of every input image.

    Args:
        Mat (torch.tensor): N subjects' orientation matrices (N, 4, 4).
            NOTE: ``Mat.type(dtype)`` below does not copy (dtype is Mat's
            own dtype), so the axial re-ordering is written into the
            caller's tensor.
        Dim (torch.tensor): N subjects' dimensions (N, 3).
        vx (torch.tensor|tuple|float|int, optional): Voxel size (3,).
            Defaults to None, meaning it is taken from the mean matrix.
            A scalar is used for all three axes; non-finite entries of a
            tensor are also taken from the mean matrix.

    Returns:
        mat (torch.tensor): Mean orientation matrix (4, 4).
        dim (torch.tensor): Mean dimensions (3,).
        vx (torch.tensor): Mean voxel size (3,).

    Authors:
        John Ashburner, as part of the SPM12 software.

    """
    device = Mat.device
    dtype = Mat.dtype
    N = Mat.shape[0]  # Number of subjects
    inf = float('inf')
    one = torch.tensor(1.0, device=device, dtype=dtype)
    # Normalise vx to a (3,) tensor; inf entries mean "take the voxel size
    # from the mean matrix" (see "Set required voxel size" below).
    if vx is None:
        vx = torch.tensor([inf, inf, inf], device=device, dtype=dtype)
    if isinstance(vx, float) or isinstance(vx, int):
        vx = (vx, ) * 3
    if isinstance(vx, tuple) and len(vx) == 3:
        vx = torch.tensor([vx[0], vx[1], vx[2]], device=device, dtype=dtype)
    # Cast to Mat's dtype (a no-op for Mat itself)
    Mat = Mat.type(dtype)
    Dim = Dim.type(dtype)
    # Get affine basis
    basis = 'SE'
    # NOTE(review): 2D is detected here, but the shear-removal step below
    # indexes 6 rigid parameters, which presumably assumes 3D -- confirm.
    dim = 3 if Dim[0, 2] > 1 else 2
    B = affine_basis(basis, dim, device=device, dtype=dtype)

    # Find combination of 90 degree rotations and flips that brings all
    # the matrices closest to axial
    Mat0 = Mat.clone()  # un-reordered matrices, used for the FoV below
    # The 6 permutations of the 3 axes
    pmatrix = torch.tensor(
        [[0, 1, 2], [1, 0, 2], [2, 0, 1], [2, 1, 0], [0, 2, 1], [1, 2, 0]],
        device=device)

    for n in range(N):  # Loop over subjects
        vx1 = voxel_size(Mat[n, ...])
        # Voxel sizes divided out of the matrix; keep the 3x3 linear part
        R = Mat[n, ...].mm(
            torch.diag(torch.cat((vx1, one[..., None]))).inverse())[:-1, :-1]
        minss = inf
        minR = torch.eye(3, dtype=dtype, device=device)
        for i in range(6):  # Permute (= 'rotate + flip') axes
            R1 = torch.zeros((3, 3), dtype=dtype, device=device)
            R1[pmatrix[i, 0], 0] = 1
            R1[pmatrix[i, 1], 1] = 1
            R1[pmatrix[i, 2], 2] = 1
            for j in range(8):  # Mirror (= 'flip') axes
                # Bits 0, 1 and 2 of j select a sign (-1 or +1) per axis
                fd = [(j & 1) * 2 - 1, (j & 2) - 1, (j & 4) / 2 - 1]
                F = torch.diag(torch.tensor(fd, dtype=dtype, device=device))
                R2 = F.mm(R1)
                # Squared Frobenius distance of R @ inv(R2) to the identity
                ss = torch.sum((R.mm(R2.inverse()) -
                                torch.eye(3, dtype=dtype, device=device))**2)
                if ss < minss:
                    minss = ss
                    minR = R2
        # Build the 4x4 voxel re-ordering matrix from the best
        # permutation/flip, with a translation term derived from the
        # image dimensions, and apply it to this subject's matrix.
        rdim = torch.abs(minR.mm(Dim[n, ...][..., None] - 1))
        R2 = minR.inverse()
        R22 = R2.mm((torch.div(
            torch.sum(R2, dim=0, keepdim=True).t(), 2, rounding_mode='floor') -
                     1) * rdim)
        minR = torch.cat((R2, R22), dim=1)
        minR = torch.cat(
            (minR, torch.tensor([0, 0, 0, 1], device=device,
                                dtype=dtype)[None, ...]),
            dim=0)
        Mat[n, ...] = Mat[n, ...].mm(minR)

    # Average of the matrices in Mat
    mat = meanm(Mat)

    # If average involves shears, then find the closest matrix that does not
    # require them.
    # NOTE(review): C_ix is not used anywhere below.
    C_ix = torch.tensor(
        [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
        device=device)  # column-major ordering from (4, 4) tensor
    # Presumably SPM's spm_imatrix parameter ordering, in which p[9:12]
    # are the shears -- confirm against _imatrix.
    p = _imatrix(mat)
    if torch.sum(p[[9, 10, 11]]**2) > 1e-8:
        # Generators of the three per-axis zooms
        B2 = torch.zeros((3, 4, 4), device=device, dtype=dtype)
        B2[0, 0, 0] = 1
        B2[1, 1, 1] = 1
        B2[2, 2, 2] = 1

        # Gauss-Newton least-squares fit of M = R(p[:6]) @ Z(p[6:]) to mat
        # (6 rigid + 3 zoom parameters), stopping when the squared
        # gradient norm drops below 1e-8 or after 10000 iterations.
        p = torch.zeros(9, device=device, dtype=dtype)
        for n_iter in range(10000):
            # Rotations + Translations
            R, dR = _expm(p[[0, 1, 2, 3, 4, 5]], B, grad_X=True)
            # Zooms
            Z, dZ = _expm(p[[6, 7, 8]], B2, grad_X=True)

            M = R.mm(Z)
            # Jacobian of M wrt the 9 parameters (product rule)
            dM = torch.zeros((4, 4, 9), device=device, dtype=dtype)
            for n in range(6):
                dM[..., n] = dR[n, ...].mm(Z)
            for n in range(3):
                dM[..., 6 + n] = R.mm(dZ[n, ...])
            dM = dM.reshape((16, 9))
            d = M.flatten() - mat.flatten()  # residual
            gr = dM.t().mm(d[..., None])  # gradient
            Hes = dM.t().mm(dM)  # Gauss-Newton approximation of the Hessian
            p = p - lmdiv(Hes, gr)[:, 0]
            if torch.sum(gr**2) < 1e-8:
                break
        mat = M.clone()

    # Set required voxel size
    vx_out = vx.clone()
    vx = voxel_size(mat)
    # Non-finite requested voxel sizes are taken from the mean matrix
    vx_out[~torch.isfinite(vx_out)] = vx[~torch.isfinite(vx_out)]
    # Rescale the columns of mat so that its voxel size becomes vx_out
    mat = mat.mm(torch.cat((vx_out / vx, one[..., None])).diag())
    vx = voxel_size(mat)

    # Ensure that the FoV covers all images: bounding box of the corners of
    # every image in the mean space, rounded outwards to whole voxels
    mn_all = torch.zeros([3, N], device=device, dtype=dtype)
    mx_all = torch.zeros([3, N], device=device, dtype=dtype)
    for n in range(N):
        dm = Dim[n, ...]
        # The 8 corners of image n, in 1-based homogeneous voxel coordinates
        corners = torch.tensor([[1, dm[0], 1, dm[0], 1, dm[0], 1, dm[0]],
                                [1, 1, dm[1], dm[1], 1, 1, dm[1], dm[1]],
                                [1, 1, 1, 1, dm[2], dm[2], dm[2], dm[2]],
                                [1, 1, 1, 1, 1, 1, 1, 1]],
                               device=device,
                               dtype=dtype)
        # Presumably inv(mat) @ Mat0[n]: image-n voxels -> mean-space voxels
        M = lmdiv(mat, Mat0[n])
        vx1 = M[:-1, :].mm(corners)
        mx_all[..., n] = torch.max(vx1, dim=1)[0]
        mn_all[..., n] = torch.min(vx1, dim=1)[0]
    mx = mx_all.max(dim=1)[0]
    mn = mn_all.min(dim=1)[0]
    mx = torch.ceil(mx)
    mn = torch.floor(mn)

    # Make output dimensions and orientation matrix
    dim = mx - mn + 1  # Output dimensions
    off = torch.tensor([0, 0, 0], device=device, dtype=dtype)
    # Translate so that output voxel 1 (1-based) lands on the minimum corner
    mat = mat.mm(
        torch.tensor([[1, 0, 0, mn[0] -
                       (off[0] + 1)], [0, 1, 0, mn[1] - (off[1] + 1)],
                      [0, 0, 1, mn[2] - (off[2] + 1)], [0, 0, 0, 1]],
                     device=device,
                     dtype=dtype))

    return mat, dim, vx
Exemple #11
0
def make_basis(name, dim, **backend):
    """Build the affine basis of group `name` in `dim` dimensions.

    `name` is first normalised with `convert_basis`; any extra keyword
    arguments (presumably tensor options such as ``dtype``/``device``)
    are forwarded unchanged to `spatial.affine_basis`.
    """
    return spatial.affine_basis(convert_basis(name), dim, **backend)
Exemple #12
0
def diffeo(source, target, group='SE', origin='center',
           image_loss=None, vel_loss=None, pull=None, optim_affine=True,
           max_iter=1000, lr=0.1, min_lr=1e-7, init=None, device=None):
    """Diffeomorphic registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source :  tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    group : str, default='SE'
        Affine sub-group to optimize. The name is normalised with
        `affine_group_converter` (e.g. 'tr', 'rot', 'rigid', 'sim',
        'lin', 'aff', or a group name such as 'SE').
    origin : {'native', 'center'}, default='center'
        Whether to rotate about the origin of the world-space ('native')
        or the center of the target field-of-view ('center').
        When the origin of the world-space is far off (say you are
        registering smaller blocks cropped from a larger MRI), it can
        be beneficial to rotate about the center of the FOV.
    image_loss : float or callable(mov, fix) -> loss, default=MutualInfoLoss()
        Either a factor to multiply the default mutual-information loss
        with, or a loss function that takes two inputs of shape
        (batch, channel, *spatial).
    vel_loss : float or callable(vel) -> loss, default=BendingLoss()
        Either a factor to multiply the bending loss with, or a loss
        function that takes the velocity field of shape
        (batch, *spatial, D).
    pull : dict, optional
        interpolation : int, default=1
            Interpolation order
        bound : bound_like, default='dct2'
            Boundary condition
        extrapolate : bool, default=False
            Extrapolate out-of-bound data using the boundary conditions.
        The dictionary passed by the caller is not modified.
    optim_affine : bool, default=True
        Whether to optimise the affine parameters jointly with the velocity.
    max_iter : int, default=1000
        Maximum number of iterations
    lr : float or (float, float), default=0.1
        Initial learning rate(s): the first value is used for the affine
        parameters, the second for the velocity. A single value is used
        for both.
    min_lr : float or (float, float), default=1e-7
        Minimum learning rate(s), same layout as `lr`. The optimization
        is stopped once every optimised group has reached its minimum.
    init : ([batch], nb_prm) tensor_like, default=0
        Initial guess for the affine parameters.
    device : {'cpu', 'cuda', 'cuda:<id>'}, optional
        Backend to use

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *shape, D) tensor
        Initial velocity
    moved : tensor
        Source image moved to target space.
    """
    group = affine_group_converter(group)
    # Copy so that filling in defaults does not mutate the caller's dict.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    # Spatial part of the target shape (drops the batch and channel axes).
    spatial_shape = target.shape[2:]

    # Shift origin to the centre of the target field of view
    if origin == 'center':
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *spatial_shape, dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    def warp(q, vel):
        """Resample `source` into target space through affine `q` and velocity `vel`."""
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff, grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare losses. Each default loss gets its own factor variable: a
    # shared name would be looked up when the lambda runs, so both lambdas
    # would end up using whichever factor was assigned last.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        vel_factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: vel_factor * vel_loss_fn(core.utils.last2channel(x))

    # Optimiser: index 0 of lr/min_lr refers to the affine parameters,
    # index 1 to the velocity, whether or not the affine part is optimised.
    lr = core.utils.make_list(lr, 2)
    min_lr = core.utils.make_list(min_lr, 2)
    param_groups = []
    group_min_lr = []
    if optim_affine:
        param_groups.append({'params': [parameters], 'lr': lr[0]})
        group_min_lr.append(min_lr[0])
    param_groups.append({'params': [velocity], 'lr': lr[1]})
    group_min_lr.append(min_lr[1])
    optim = torch.optim.Adam(param_groups, lr=lr[0])
    scheduler = ReduceLROnPlateau(optim)

    def forward():
        moved = warp(parameters, velocity)
        return image_loss(moved, target) + vel_loss(velocity)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        loss_val = forward()
        loss_val.backward()
        # Gradients are already populated; passing a closure to Adam would
        # only re-run (and discard) a full forward pass.
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val
            if n_iter % 10 == 0:
                loss_avg /= 10
                scheduler.step(loss_avg)

                print('{:4d} {:12.6f} | lr={:g} '
                      .format(n_iter, loss_avg.item(),
                              optim.param_groups[0]['lr']),
                      end='\r')
                loss_avg = 0

        if all(pg['lr'] < pg_min
               for pg, pg_min in zip(optim.param_groups, group_min_lr)):
            print('\nConverged.')
            break

    print('')
    with torch.no_grad():
        moved = warp(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Undo the centring: translation t -> t + R @ shift - shift.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    return (parameters.detach(),
            aff.detach(),
            velocity.detach(),
            moved.detach())
Exemple #13
0
def affine(source,
           target,
           group='SE',
           loss=None,
           pull=None,
           preproc=True,
           max_iter=1000,
           device=None,
           origin='center',
           init=None,
           lr=0.1,
           scheduler=ReduceLROnPlateau):
    """Affine registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        Source (moving) image, shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        Target (fixed) image, shape (batch, channel, *spatial).
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    loss : callable(mov, fix) -> loss, default=MutualInfoLoss()
        Image similarity loss.
    pull : dict, optional
        interpolation : int, default=1
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
        The dictionary passed by the caller is not modified.
    preproc : bool, default=True
        Rescale both images with `rescale` before registering.
    max_iter : int, default=1000
        Number of optimisation iterations.
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world origin ('native') or about the centre of
        the target field of view ('center').
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float, default=0.1
        Adam learning rate.
    scheduler : Scheduler, default=ReduceLROnPlateau
        Learning-rate scheduler class, instantiated on the optimiser and
        stepped every 10 iterations. Pass None to disable.

    Returns
    -------
    q : tensor
        Parameters
    aff : (D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    moved : tensor
        Source image moved to target space.

    """
    # Copy so that filling in defaults does not mutate the caller's dict.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dct2')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    # Spatial part of the target shape (drops the batch and channel axes).
    spatial_shape = target.shape[2:]

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin to the centre of the target field of view
    if origin == 'center':
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Work on copies so the caller's orientation matrices are untouched.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=True)
    identity = spatial.identity_grid(spatial_shape, **backend)
    # Inserts singleton spatial axes into the batched (batch, D+1, D+1)
    # matrix so it broadcasts against the identity grid.
    expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))

    def warp(q):
        """Resample `source` into target space through affine parameters `q`."""
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], identity)
        return spatial.grid_pull(source, grid, **pull_opt)

    # Prepare loss and optimizer
    if loss is None:
        loss = nni.MutualInfoLoss()

    optim = torch.optim.Adam([parameters], lr=lr)
    if scheduler is not None:
        scheduler = scheduler(optim)

    # Optim loop
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        moved = warp(parameters)
        loss_val = loss(moved, target)
        loss_val.backward()
        optim.step()

        if n_iter % 10 == 0:
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_val)
                else:
                    scheduler.step()
            with torch.no_grad():
                print('{:4d} {:12.6f} | lr={:g}'.format(
                    n_iter, loss_val.item(), optim.param_groups[0]['lr']),
                      end='\r')

    print('')
    with torch.no_grad():
        moved = warp(parameters)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            # Undo the centring: translation t -> t + R @ shift - shift.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return parameters, aff, moved
Exemple #14
0
class Rigid(Linear):
    """Rigid transform: the 'SE' group (translations + rotations) in 3D."""

    name = 'rigid'
    basis = spatial.affine_basis('SE', 3)

    @staticmethod
    def nb_prm(dim):
        # `dim` translations plus dim*(dim-1)/2 rotations.
        return dim * (dim + 1) // 2
Exemple #15
0
def _affine_align(dat, mat, cost_fun='nmi', group='SE', mean_space=False,
                  samp=(3, 1.5), optimiser='powell', fix=0, verbose=False,
                  fov=None, mx_int=1023, raw=False, jitter=False, fwhm=7.0):
    """Affine registration of a collection of images.

    Parameters
    ----------
    dat : [N, ...], tensor_like
        List of image volumes.
    mat : [N, ...], tensor_like
        List of affine matrices.
    cost_fun : str, default='nmi'
        Pairwise methods:
            * 'nmi'  : Normalised Mutual Information
            * 'mi'   : Mutual Information
            * 'ncc'  : Normalised Cross Correlation
            * 'ecc'  : Entropy Correlation Coefficient
        Groupwise methods:
            * 'njtv' : Normalised Joint Total variation
            * 'jtv'  : Joint Total variation
    group : str, default='SE'
        * 'T'   : Translations
        * 'SO'  : Special Orthogonal (rotations)
        * 'SE'  : Special Euclidean (translations + rotations)
        * 'D'   : Dilations (translations + isotropic scalings)
        * 'CSO' : Conformal Special Orthogonal
                  (translations + rotations + isotropic scalings)
        * 'SL'  : Special Linear (rotations + isovolumic zooms + shears)
        * 'GL+' : General Linear [det>0] (rotations + zooms + shears)
        * 'Aff+': Affine [det>0] (translations + rotations + zooms + shears)
    mean_space : bool, default=False
        Optimise a mean-space fit. A ValueError is raised if cost_fun is
        one of the histogram-based costs (`_costs_hist`).
    samp : (float, ), default=(3, 1.5)
        Optimisation sampling steps (mm).
    optimiser : str, default='powell'
        'powell' : Optimisation method.
    fix : int, default=0
        Index of the image to use as fixed image, not used if mean_space=True.
    verbose : bool, default=False
        Show registration results.
    fov : (2,) tuple, default=None
        A tuple with affine matrix (tensor_like) and dimensions (tuple) of mean space.
    mx_int : int, default=1023
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=1023 -> H.shape = (1024, 1024)). This is only done if
        cost_fun is histogram-based.
    raw : bool, default=False
        Do no processing of input images -> work on raw data.
    jitter : bool, default=False
        Add random jittering to resampling grid.
    fwhm : float, default=7
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.

    Returns
    -------
    mat_a : (N, 4, 4), tensor_like
        Affine alignment matrices (identity when N < 2).
    mat_fix : (4, 4), tensor_like
        Affine matrix of fixed image.
    dim_fix : (3,), tuple, list, tensor_like
        Image dimensions of fixed image.
    q : (N, Nq), tensor_like
        Lie parameters (all zeros when N < 2).

    """
    with torch.no_grad():
        device = dat[0].device
        # Parse algorithm options
        opt = {'optimiser': optimiser,
               'cost_fun': cost_fun,
               'samp': samp,
               'fix': fix,
               'mean_space': mean_space,
               'verbose': verbose,
               'fov': fov,
               'group' : group,
               'raw': raw,
               'jitter': jitter,
               'mx_int' : mx_int,
               'fwhm' : fwhm}
        if not isinstance(opt['samp'], (list, tuple)):
            opt['samp'] = (opt['samp'], )
        # Some very basic sanity checks
        N = len(dat) # Number of input scans
        mov = list(range(N))  # Indices of images
        if opt['cost_fun'] in _costs_hist and opt['mean_space']:
            raise ValueError('Option mean_space=True not defined for {} cost!'.format(opt['cost_fun']))
        # Get affine basis
        B = affine_basis(group=opt['group'], device=device)
        Nq = B.shape[0]
        # Load data
        dat = _data_loader(dat, mat, opt)
        # Define fixed image space (mat_fix, dim_fix, vx_fix)
        if opt['mean_space']:
            # Use a mean-space
            if opt['fov']:
                # Mean-space given
                mat_fix = opt['fov'][0]
                dim_fix = opt['fov'][1]
            else:
                # Compute mean-space
                mat_fix, dim_fix = _get_mean_space(dat, mat)
            arg_grid = dim_fix
        else:
            # Use one of the input images
            mat_fix = mat[opt['fix']]
            dim_fix = dat[opt['fix']].shape[:3]
            mov.remove(opt['fix'])
            arg_grid = dat[opt['fix']]
        # Get voxel size of fixed image
        vx_fix = voxel_size(mat_fix)
        # Initial guess for registration parameter (zeros -> identity)
        q = torch.zeros((N, Nq), dtype=torch.float64, device=device)
        # With fewer than two images there is nothing to register; q stays
        # zero and the identity transforms are returned below, with the
        # same 4-tuple as the normal path.
        if N >= 2:
            for s in opt['samp']:  # Loop over sub-sampling level
                # Get possibly sub-sampled fixed image, and its resampling grid
                dat_fix, grid = _get_dat_grid(
                    arg_grid, vx_fix, s, jitter=opt['jitter'], device=device)
                # Do optimisation
                q, _ = _fit_q(q, dat_fix, grid, mat_fix, dat, mat, mov,
                              B, s, opt)
    # To matrix form
    mat_a = torch.zeros((N, 4, 4),
                        dtype=torch.float64, device=device)
    for n in range(N):
        mat_a[n, ...] = expm(q[n, ...], basis=B)

    return mat_a, mat_fix, dim_fix, q
Exemple #16
0
class Similitude(Linear):
    """Similitude transform: the 'CSO' group in 3D.

    Translations + rotations + one isotropic scaling.
    """

    name = 'similitude'
    basis = spatial.affine_basis('CSO', 3)

    @staticmethod
    def nb_prm(dim):
        # Rigid parameters (dim*(dim+1)/2) plus a single isotropic scale.
        rigid_count = dim * (dim + 1) // 2
        return rigid_count + 1
Exemple #17
0
def _test_cost(dat, mat, cost_fun='nmi', group='SE', mean_space=False,
               samp=2, ix_par=0, jitter=False, x_step=0.1, x_mn_mx=30,
               verbose=False, mx_int=1023, raw=False, fwhm=7.0):
    """Check cost function behaviour by keeping one image fixed and re-aligning
    a second image by modifying one of the affine parameters. Plots cost vs.
    aligment when finished.

    Image 0 is the fixed image and image 1 the moving one. Parameter
    `ix_par` of the moving image's 12-parameter 'Aff+' Lie vector is swept
    over ``arange(-x_mn_mx, x_mn_mx, x_step)`` and the cost is evaluated
    at each value.

    Parameters
    ----------
    dat : [2, ...], tensor_like
        Exactly two image volumes (ValueError otherwise).
    mat : [2, ...], tensor_like
        Their affine matrices.
    cost_fun : str, default='nmi'
        Cost function name passed to `_compute_cost`.
    group : str, default='SE'
        Not used: the full 'Aff+' basis is always used.
    mean_space : bool, default=False
        Evaluate the cost in a mean space large enough to contain the
        most misaligned moving image.
    samp : float, default=2
        Sub-sampling step; the last element is used if a sequence is given.
    ix_par : int, default=0
        Index of the affine parameter to sweep.
    jitter : bool, default=False
        Add random jittering to resampling grid.
    x_step, x_mn_mx : float
        Sweep step and half-range.
    verbose : bool, default=False
        Show the intermediate results at every step.
    mx_int, raw, fwhm
        Forwarded to data loading / cost computation.

    """
    with torch.no_grad():
        device = dat[0].device
        # Parse algorithm options
        opt = {'cost_fun': cost_fun,
               'samp': samp,
               'mean_space': mean_space,
               'verbose': verbose,
               'raw': raw,
               'mx_int' : mx_int,
               'fwhm' : fwhm}
        if not isinstance(opt['samp'], (list, tuple)):
            opt['samp'] = (opt['samp'], )
        # Some very basic sanity checks
        N = len(dat)
        if N != 2:
            raise ValueError('N != 2')
        mov = list(range(N))  # Indices of images
        fix_img = 0
        mov_img = 1
        if opt['cost_fun'] in _costs_hist and opt['mean_space']:
            raise ValueError('Option mean_space=True not defined for {} cost!'.format(opt['cost_fun']))
        # Load data
        dat = _data_loader(dat, mat, opt)
        # Get full 12 parameter affine basis
        B = affine_basis(group='Aff+', device=device)
        Nq = B.shape[0]
        # Range of parameter
        x = torch.arange(start=-x_mn_mx, end=x_mn_mx, step=x_step, dtype=torch.float32)
        if opt['mean_space']:
            # Use mean-space, so make sure that maximum misalignment is represented
            # in the input to _get_mean_space()
            mat_mn = torch.zeros(Nq,
                dtype=torch.float64, device=device)
            mat_mx = torch.zeros(Nq,
                dtype=torch.float64, device=device)
            mat_mn[ix_par] = -x_mn_mx
            mat_mx[ix_par] = x_mn_mx
            mat1 = [expm(mat_mn, B).mm(mat[mov_img]),
                    expm(mat_mx, B).mm(mat[mov_img])]
            # Compute mean-space
            dat.append(torch.tensor(dat[mov_img].shape,
                       dtype=torch.float32, device=device))
            dat.append(torch.tensor(dat[mov_img].shape,
                       dtype=torch.float32, device=device))
            mat_fix, dim_fix = _get_mean_space(dat, mat + mat1)
            dat = dat[:2]
            arg_grid = dim_fix
        else:
            mat_fix = mat[fix_img]
            dim_fix = dat[fix_img].shape[:3]
            mov.remove(fix_img)
            arg_grid = dat[fix_img]
        # Get voxel size of fixed image
        vx_fix = voxel_size(mat_fix)
        # Initial guess
        q = torch.zeros((N, Nq), dtype=torch.float64, device=device)
        # Get subsampled fixed image and its resampling grid
        dat_fix, grid = _get_dat_grid(arg_grid,
            vx_fix, samp=opt['samp'][-1], jitter=jitter, device=device)
        # Iterate over a range of values
        costs = np.zeros(len(x))
        fig_ax = None  # Used for visualisation
        for i, xi in enumerate(x):
            # Perturb the MOVING image's parameter: the fixed image is not
            # in `mov` (non-mean-space), and the mean space above was sized
            # for a misaligned `mat[mov_img]`.
            q[mov_img, ix_par] = xi
            # Compute cost
            costs[i], res = _compute_cost(
                q, grid, dat_fix, mat_fix, dat, mat, mov, opt['cost_fun'], B, opt['mx_int'], opt['fwhm'], return_res=True)
            if opt['verbose']:
                fig_ax = show_slices(res, fig_ax=fig_ax, fig_num=1, cmap='coolwarm', title='x=' + str(xi))
        # Plot results
        if plt is None:
            return
        fig, ax = plt.subplots(num=2)
        ax.plot(x, costs)
        ax.set(xlabel='Value q[' + str(ix_par) + ']', ylabel='Cost',
               title=opt['cost_fun'].upper() + ' cost function (mean_space=' + str(opt['mean_space']) + ')')
        ax.grid()
        plt.show()
Exemple #18
0
 def basis(self):
     """Return the affine basis, building and caching it on first use.

     The basis is created with ``spatial.affine_basis`` from
     ``self._basis_name`` and ``self.dim``, using the backend keyword
     arguments derived from ``self.dat``, and stored in ``self._basis``.
     """
     if self._basis is not None:
         return self._basis
     group_name, ndim = self._basis_name, self.dim
     backend_kwargs = utils.backend(self.dat)
     self._basis = spatial.affine_basis(group_name, ndim, **backend_kwargs)
     return self._basis
Exemple #19
0
def ffd(source,
        target,
        grid_shape=10,
        group='SE',
        image_loss=None,
        def_loss=None,
        pull=None,
        preproc=True,
        max_iter=1000,
        device=None,
        origin='center',
        init=None,
        lr=1e-4,
        optim_affine=True,
        scheduler=ReduceLROnPlateau):
    """FFD (= cubic spline) registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
        Source (moving) image, shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        Target (fixed) image, shape (batch, channel, *spatial).
    grid_shape : int or sequence[int], default=10
        Shape of the spline control-point grid.
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
        Affine sub-group to optimize.
    image_loss : float or callable(mov, fix) -> loss, default=MutualInfoLoss()
        Either a factor to multiply the default mutual-information loss
        with, or a loss function.
    def_loss : float or callable(disp) -> loss, default=BendingLoss()
        Either a factor to multiply the default bending loss with, or a
        loss function. It receives the displacement of the first batch
        element, shape (*spatial, D).
    pull : dict, optional
        interpolation : int, default=1
        bound : bound_like, default='dft'
        extrapolate : bool, default=False
        The dictionary passed by the caller is not modified.
    preproc : bool, default=True
        Rescale both images with `rescale` before registering.
    max_iter : int, default=1000
        Number of optimisation iterations.
    device : device, optional
    origin : {'native', 'center'}, default='center'
        Rotate about the world origin ('native') or about the centre of
        the target field of view ('center').
    init : tensor_like, default=0
        Initial affine parameters, reshaped to (batch, nb_prm).
    lr : float or (float, float), default=1e-4
        Learning rate(s): the first value is used for the spline grid,
        the second for the affine parameters.
    optim_affine : bool, default=True
        Whether to optimise the affine parameters jointly with the grid.
    scheduler : Scheduler, default=ReduceLROnPlateau
        Learning-rate scheduler class, instantiated with ``cooldown=5`` and
        stepped every 10 iterations. Pass None to disable.

    Returns
    -------
    q : tensor
        Affine parameters
    aff : (D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    grid_prm : (batch, *grid_shape, D) tensor
        Spline grid parameters.
    moved : tensor
        Source image moved to target space.

    """
    # Copy so that filling in defaults does not mutate the caller's dict.
    pull_opt = dict(pull or {})
    pull_opt.setdefault('interpolation', 'linear')
    pull_opt.setdefault('bound', 'dft')
    pull_opt.setdefault('extrapolate', False)

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2
    # Spatial part of the target shape (drops the batch and channel axes).
    spatial_shape = target.shape[2:]

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin to the centre of the target field of view
    if origin == 'center':
        shift = torch.as_tensor(spatial_shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        # Work on copies so the caller's orientation matrices are untouched.
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        affine_parameters = torch.as_tensor(init, **backend).clone().detach()
        affine_parameters = affine_parameters.reshape([batch, nb_prm])
    else:
        affine_parameters = torch.zeros([batch, nb_prm], **backend)
    affine_parameters = nn.Parameter(affine_parameters,
                                     requires_grad=optim_affine)
    grid_shape = core.pyutils.make_list(grid_shape, dim)
    grid_parameters = torch.zeros([batch, *grid_shape, dim], **backend)
    grid_parameters = nn.Parameter(grid_parameters, requires_grad=True)
    # Loop-invariant: computed once instead of at every iteration.
    identity = spatial.identity_grid(spatial_shape, **backend)
    # Inserts singleton spatial axes into the batched (batch, D+1, D+1)
    # matrix so it broadcasts against the sampling grid.
    expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))

    def warp(q, grid):
        """Resample `source` through affine `q` composed with `grid`."""
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff[expd], grid)
        return spatial.grid_pull(source, grid, **pull_opt)

    def exp(prm):
        """Upsample spline parameters to a dense displacement and grid."""
        disp = spatial.resize_grid(prm,
                                   type='displacement',
                                   shape=spatial_shape,
                                   interpolation=3,
                                   bound='dft')
        grid = disp + identity
        return disp, grid

    # Prepare losses. Each default loss gets its own factor variable: a
    # shared name would be looked up when the lambda runs, so both lambdas
    # would end up using whichever factor was assigned last.
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        image_factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: image_factor * image_loss_fn(x, y)

    if not callable(def_loss):
        def_loss_fn = nni.BendingLoss(bound='dft')
        def_factor = 1. if def_loss is None else def_loss
        # The loss receives one batch element (*spatial, D); restore a
        # batch axis so last2channel yields (1, D, *spatial).
        def_loss = lambda x: def_factor * def_loss_fn(
            core.utils.last2channel(x[None]))

    lr = core.utils.make_list(lr, 2)
    opt_prm = [{
        'params': affine_parameters,
        'lr': lr[1]
    }, {
        'params': grid_parameters,
        'lr': lr[0]
    }] if optim_affine else [grid_parameters]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        zero_grad_([affine_parameters, grid_parameters])
        disp, grid = exp(grid_parameters)
        moved = warp(affine_parameters, grid)
        # NOTE(review): only the first batch element is regularised.
        loss_val = image_loss(moved, target) + def_loss(disp[0])
        loss_val.backward()
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val

            # Every 10 iterations: average, step the scheduler, report.
            if n_iter % 10 == 0:
                loss_avg /= 10
                if scheduler is not None:
                    if isinstance(scheduler, ReduceLROnPlateau):
                        scheduler.step(loss_avg)
                    else:
                        scheduler.step()
                print('{:4d} {:12.6f} | lr={:g}'.format(
                    n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                      end='\r')
                loss_avg = 0

    print('')
    with torch.no_grad():
        # Rebuild the grid from the final parameters: the loop's last grid
        # predates the final optimiser step (and is undefined if max_iter == 0).
        _, grid = exp(grid_parameters)
        moved = warp(affine_parameters, grid)
        aff = core.linalg.expm(affine_parameters, basis)
        if origin == 'center':
            # Undo the centring: translation t -> t + R @ shift - shift.
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return affine_parameters, aff, grid_parameters, moved