Example #1
0
def _atlas_align(dat, mat, rigid=True, pth_atlas=None):
    r"""Affinely align images to some atlas space.

    The atlas is appended to (a copy of) the inputs and registered jointly
    with them using a CSO model (translations + rotations + isotropic
    scaling), the atlas being the fixed image.

    Parameters
    ----------
    dat : [N, ...], tensor_like
        List of image volumes. The list itself is not modified.
    mat : [N, ...], tensor_like
        List of affine matrices. The list itself is not modified.
    rigid : bool, default=True
        If True, `mat_a` only contains the rigid part of the fit; otherwise
        it contains rigid + isotropic scaling.
    pth_atlas : str, optional
        Path to atlas image to match to. Uses nitorch's T1w intensity atlas
        (``fetch_data('atlas_t1')``) by default.

    Returns
    -------
    mat_a : (N, 4, 4) tensor_like
        Transformation aligning to MNI space as M_mni\M_mov.
    mat_mni : (4, 4), tensor_like
        Affine matrix of MNI image.
    dim_mni : (3,), tuple, list, tensor_like
        Image dimensions of MNI image.
    mat_cso : (N, 4, 4) tensor_like
        CSO transformation.

    """
    if pth_atlas is None:
        # Get path to nitorch's T1w intensity atlas
        pth_atlas = fetch_data('atlas_t1')
    # Number of input images; the atlas will sit at index N
    N = len(dat)
    # Load the atlas on the same device as the input data
    dat_mni, mat_mni, _ = _format_input(pth_atlas, device=dat[0].device,
                                        rand=True, cutoff=(0.0005, 0.9995))
    # Build new lists so the caller's `dat` and `mat` are left untouched
    dat = list(dat) + [dat_mni[0]]
    mat = list(mat) + [mat_mni[0]]
    # Align to MNI atlas (kept fixed at index N)
    group = 'CSO'
    _, mat_mni, dim_mni, q = _affine_align(dat, mat,
         group=group, samp=(3, 1.5), cost_fun='nmi', fix=N,
         verbose=False, mean_space=False)
    # Drop the atlas' parameters
    q = q[:N, ...]
    # Full CSO matrix representation (basis on the same device as q)
    mat_cso = expm(q, affine_basis(group=group, device=q.device))
    if rigid:
        # Extract only the rigid part: the first 6 CSO parameters are
        # used as SE parameters
        group = 'SE'
        q = q[..., :6]
    # Get matrix representation
    mat_a = expm(q, affine_basis(group=group, device=q.device))

    return mat_a, mat_mni, dim_mni, mat_cso
Example #2
0
def write_transforms(options):
    """Write transformations (affine and nonlin) on disk.

    The last ``struct.NonLinear`` found in ``options.transformations`` is
    written as a volume; the last transformation of any other type is
    treated as the affine one and written as a transform file.

    Parameters
    ----------
    options : object
        Must expose ``transformations`` (an iterable of transformation
        structures, each with an ``output`` path).
    """
    nonlin = None
    affine = None
    for trf in options.transformations:
        if isinstance(trf, struct.NonLinear):
            nonlin = trf
        else:
            affine = trf

    if affine is not None:
        # Matrix exponential of the Lie parameters
        lin = linalg.expm(affine.dat, affine.basis)
        if torch.is_tensor(affine.shift):
            # include shift: t += (L - I) @ shift
            # (spatial dimension taken from the matrix, not hard-coded)
            ndim = lin.shape[-1] - 1
            shift = affine.shift.to(dtype=lin.dtype, device=lin.device)
            eye = torch.eye(ndim, dtype=lin.dtype, device=lin.device)
            lin[:-1, -1] += torch.matmul(lin[:-1, :-1] - eye, shift)
        io.transforms.savef(lin.cpu(), affine.output, type=2)

    if nonlin is not None:
        nonlin_aff = nonlin.affine
        shape = nonlin.shape
        if isinstance(nonlin, struct.FFD):
            # FFD parameters live on their own grid (nonlin.dat.shape[:-1]),
            # which may differ from `shape`: resize the affine accordingly
            factor = [s / g for s, g in zip(shape, nonlin.dat.shape[:-1])]
            nonlin_aff, _ = spatial.affine_resize(nonlin_aff, shape, factor)
        io.volumes.savef(nonlin.dat.cpu(), nonlin.output,
                         affine=nonlin_aff.cpu())
Example #3
0
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """Evaluate the affine registration objective and its derivatives.

        Parameters
        ----------
        logaff : (..., nb) tensor
            Lie parameters.
        grad : bool, default=False
            Whether to compute and return the gradient wrt `logaff`.
        hess : bool, default=False
            Whether to compute and return the Hessian wrt `logaff`.
            When True, the gradient is also computed and returned, even if
            `grad` is False.
        gradmov : bool, default=False
            Whether to compute and return the gradient wrt `moving`.
            Only computed when `grad` or `hess` is True.
        hessmov : bool, default=False
            Whether to compute and return the Hessian wrt `moving`.
            Only computed when `hess` is True.
        in_line_search : bool, default=False
            If True, disables plotting, printing and the update of the
            iteration counter.

        Returns
        -------
        ll : float
            Loss value (objective to minimize).
        g : (..., nb) tensor, optional
            Gradient wrt Lie parameters.
        h : (..., nb, nb) tensor, optional
            Diagonal approximation of the Hessian wrt Lie parameters.
        gm : (..., *spatial, dim) tensor, optional
            Gradient wrt moving.
        hm : (..., *spatial, ?) tensor, optional
            Hessian wrt moving.

        Only the outputs that were requested are returned, in this order;
        if none were, `ll` is returned alone (not in a tuple).

        NOTE(review): if `gradmov` is True but neither `grad` nor `hess`
        is, or `hessmov` is True but `hess` is not, the flag value (True)
        is returned in place of a tensor.
        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        # Plot about 20 times over `max_iter` iterations
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        # Build the Lie algebra basis on first use and cache it on `self`
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        # Derivatives of the matrix exponential wrt `logaff`, used below in
        # the chain rule (presumably `_expm` returns (expm, gradient) when
        # grad_X=True).
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        # Voxel-to-voxel mapping: affine_moving \ (aff @ affine_fixed)
        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        # Identity grid of the fixed lattice, computed once and cached
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                # Push the image-space gradient back onto the moving lattice
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            # Requesting the Hessian also computes (and returns) the gradient
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                # NOTE(review): unlike the Hessian branch above, `grid=self.id`
                # is not passed here — confirm this is intended.
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            # Chain rule through the top `dim` rows of the (dim+1)x(dim+1)
            # matrix, flattened to dim*(dim+1) entries
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                # Replace the Hessian by the diagonal matrix of its absolute
                # row sums (presumably a majoriser, for a more stable step)
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            # NOTE(review): `n_iter` only advances when verbose, and it also
            # drives `do_plot` above.
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        # Pack the requested outputs; a lone loss is returned unwrapped
        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Example #4
0
    def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
        """Evaluate the affine registration objective (autograd version).

        The gradient wrt `logaff` is obtained by backpropagation.

        Parameters
        ----------
        logaff : (..., nb) tensor
            Lie parameters. Its `requires_grad` flag is set in place when
            `grad` is True and reset to False before returning, so it
            presumably needs to be a leaf tensor.
        grad : bool, default=False
            Whether to compute and return the gradient wrt `logaff`.
        hess : bool, default=False
            Accepted for interface compatibility only: no Hessian is
            computed or returned by this implementation.
        in_line_search : bool, default=False
            If True, disables plotting, printing and the update of the
            iteration counter.

        Returns
        -------
        ll : () tensor
            Loss value (objective to minimize).
        g : (..., nb) tensor, optional
            Gradient wrt Lie parameters (only if `grad`).

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        # select correct gradient mode; re-enter under the right context
        # if the current autograd mode does not match the request
        if grad:
            logaff.requires_grad_()
            if logaff.grad is not None:
                logaff.grad.zero_()
        if grad and not torch.is_grad_enabled():
            with torch.enable_grad():
                return self(logaff, grad, hess, in_line_search=in_line_search)
        elif not grad and torch.is_grad_enabled():
            with torch.no_grad():
                return self(logaff, grad, hess, in_line_search=in_line_search)

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        # Plot about 20 times over `max_iter` iterations
        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                   and (self.n_iter - 1) % logplot == 0

        fixed = self.fixed

        # forward
        # Build the Lie algebra basis on first use and cache it on `self`
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        # Voxel-to-voxel mapping: affine_moving \ (aff @ affine_fixed)
        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # Identity grid of the fixed lattice, computed once and cached
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape, **utils.backend(logaff))
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # log-likelihood in observed space
        llx = self.loss.loss(warped, fixed)
        del warped

        # print objective (keep the tensor `lll` for backpropagation)
        lll = llx
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [lll]
        if grad is not False:
            lll.backward()
            grad = logaff.grad.clone()
            out.append(grad)
        logaff.requires_grad_(False)
        return tuple(out) if len(out) > 1 else out[0]
Example #5
0
def _test_cost(dat, mat, cost_fun='nmi', group='SE', mean_space=False,
               samp=2, ix_par=0, jitter=False, x_step=0.1, x_mn_mx=30,
               verbose=False, mx_int=1023, raw=False, fwhm=7.0):
    """Check cost function behaviour by keeping one image fixed and re-aligning
    a second image by modifying one of the affine parameters. Plots cost vs.
    alignment when finished.

    Parameters
    ----------
    dat : [2, ...], tensor_like
        List of exactly two image volumes.
    mat : [2, ...], tensor_like
        List of the two corresponding affine matrices.
    cost_fun : str, default='nmi'
        Cost function (see `_affine_align`).
    group : str, default='SE'
        Currently unused: the full 12-parameter 'Aff+' basis is always used.
    mean_space : bool, default=False
        Evaluate the cost in a mean space (not allowed for histogram costs).
    samp : float or sequence of float, default=2
        Sub-sampling (mm); only the last value is used.
    ix_par : int, default=0
        Index of the affine parameter that is varied.
    jitter : bool, default=False
        Add random jittering to the resampling grid.
    x_step : float, default=0.1
        Step between tested parameter values.
    x_mn_mx : float, default=30
        Parameter values are taken in [-x_mn_mx, x_mn_mx).
    verbose : bool, default=False
        Show the registration result at each tested value.
    mx_int : int, default=1023
        Max intensity (number of histogram bins - 1).
    raw : bool, default=False
        Do no processing of input images.
    fwhm : float, default=7
        Full-width at half max of the histogram smoothing kernel.

    Raises
    ------
    ValueError
        If `dat` does not contain exactly two images, or if `mean_space`
        is requested with a histogram-based cost.

    """
    with torch.no_grad():
        device = dat[0].device
        # Parse algorithm options
        opt = {'cost_fun': cost_fun,
               'samp': samp,
               'mean_space': mean_space,
               'verbose': verbose,
               'raw': raw,
               'mx_int': mx_int,
               'fwhm': fwhm}
        if not isinstance(opt['samp'], (list, tuple)):
            opt['samp'] = (opt['samp'], )
        # Some very basic sanity checks
        N = len(dat)
        if N != 2:
            raise ValueError('N != 2')
        mov = list(range(N))  # Indices of images
        fix_img = 0
        mov_img = 1
        if opt['cost_fun'] in _costs_hist and opt['mean_space']:
            raise ValueError('Option mean_space=True not defined for {} cost!'.format(opt['cost_fun']))
        # Load data
        dat = _data_loader(dat, mat, opt)
        # Get full 12 parameter affine basis (`group` is not used)
        B = affine_basis(group='Aff+', device=device)
        Nq = B.shape[0]
        # Range of tested parameter values
        x = torch.arange(start=-x_mn_mx, end=x_mn_mx, step=x_step, dtype=torch.float32)
        if opt['mean_space']:
            # Use mean-space, so make sure that maximum misalignment is represented
            # in the input to _get_mean_space()
            mat_mn = torch.zeros(Nq, dtype=torch.float64, device=device)
            mat_mx = torch.zeros(Nq, dtype=torch.float64, device=device)
            mat_mn[ix_par] = -x_mn_mx
            mat_mx[ix_par] = x_mn_mx
            mat1 = [expm(mat_mn, B).mm(mat[mov_img]),
                    expm(mat_mx, B).mm(mat[mov_img])]
            # Compute mean-space (the two extra entries stand in for the
            # extreme misaligned copies of the moving image)
            dat.append(torch.tensor(dat[mov_img].shape,
                       dtype=torch.float32, device=device))
            dat.append(torch.tensor(dat[mov_img].shape,
                       dtype=torch.float32, device=device))
            mat_fix, dim_fix = _get_mean_space(dat, mat + mat1)
            dat = dat[:2]
            arg_grid = dim_fix
        else:
            mat_fix = mat[fix_img]
            dim_fix = dat[fix_img].shape[:3]
            mov.remove(fix_img)
            arg_grid = dat[fix_img]
        # Get voxel size of fixed image
        vx_fix = voxel_size(mat_fix)
        # Initial guess
        q = torch.zeros((N, Nq), dtype=torch.float64, device=device)
        # Get subsampled fixed image and its resampling grid
        dat_fix, grid = _get_dat_grid(arg_grid,
            vx_fix, samp=opt['samp'][-1], jitter=jitter, device=device)
        # Iterate over a range of values
        costs = np.zeros(len(x))
        fig_ax = None  # Used for visualisation
        for i, xi in enumerate(x):
            # Change affine matrix a little bit. _compute_cost uses the k-th
            # row of q for the k-th entry of `mov`, so row 0 drives the
            # first image that is actually moved.
            q[fix_img, ix_par] = xi
            # Compute cost
            cost, res = _compute_cost(
                q, grid, dat_fix, mat_fix, dat, mat, mov, opt['cost_fun'], B,
                opt['mx_int'], opt['fwhm'], return_res=True)
            # Convert explicitly: numpy may be unable to convert a CUDA
            # tensor when assigning it into the array
            costs[i] = float(cost)
            if opt['verbose']:
                fig_ax = show_slices(res, fig_ax=fig_ax, fig_num=1, cmap='coolwarm', title='x=' + str(xi))
        # Plot results
        if plt is None:
            return
        fig, ax = plt.subplots(num=2)
        ax.plot(x, costs)
        ax.set(xlabel='Value q[' + str(ix_par) + ']', ylabel='Cost',
               title=opt['cost_fun'].upper() + ' cost function (mean_space=' + str(opt['mean_space']) + ')')
        ax.grid()
        plt.show()
Example #6
0
def _affine_align(dat, mat, cost_fun='nmi', group='SE', mean_space=False,
                  samp=(3, 1.5), optimiser='powell', fix=0, verbose=False,
                  fov=None, mx_int=1023, raw=False, jitter=False, fwhm=7.0):
    """Affine registration of a collection of images.

    Parameters
    ----------
    dat : [N, ...], tensor_like
        List of image volumes.
    mat : [N, ...], tensor_like
        List of affine matrices.
    cost_fun : str, default='nmi'
        Pairwise methods:
            * 'nmi'  : Normalised Mutual Information
            * 'mi'   : Mutual Information
            * 'ncc'  : Normalised Cross Correlation
            * 'ecc'  : Entropy Correlation Coefficient
        Groupwise methods:
            * 'njtv' : Normalised Joint Total variation
            * 'jtv'  : Joint Total variation
    group : str, default='SE'
        * 'T'   : Translations
        * 'SO'  : Special Orthogonal (rotations)
        * 'SE'  : Special Euclidean (translations + rotations)
        * 'D'   : Dilations (translations + isotropic scalings)
        * 'CSO' : Conformal Special Orthogonal
                  (translations + rotations + isotropic scalings)
        * 'SL'  : Special Linear (rotations + isovolumic zooms + shears)
        * 'GL+' : General Linear [det>0] (rotations + zooms + shears)
        * 'Aff+': Affine [det>0] (translations + rotations + zooms + shears)
    mean_space : bool, default=False
        Optimise a mean-space fit. Not allowed for histogram-based costs.
    samp : (float, ), default=(3, 1.5)
        Optimisation sampling steps (mm), from coarse to fine.
    optimiser : str, default='powell'
        'powell' : Optimisation method.
    fix : int, default=0
        Index of image to use as fixed image, not used if mean_space=True.
    verbose : bool, default=False
        Show registration results.
    fov : (2,) tuple, default=None
        A tuple with affine matrix (tensor_like) and dimensions (tuple) of mean space.
    mx_int : int, default=1023
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=1023 -> H.shape = (1024, 1024)). This is only done if
        cost_fun is histogram-based.
    raw : bool, default=False
        Do no processing of input images -> work on raw data.
    jitter : bool, default=False
        Add random jittering to resampling grid.
    fwhm : float, default=7
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.

    Returns
    -------
    mat_a : (N, 4, 4), tensor_like
        Affine alignment matrices.
    mat_fix : (4, 4), tensor_like
        Affine matrix of fixed image.
    dim_fix : (3,), tuple, list, tensor_like
        Image dimensions of fixed image.
    q : (N, Nq), tensor_like
        Lie parameters.

    Raises
    ------
    ValueError
        If `mean_space` is requested with a histogram-based cost.

    """
    with torch.no_grad():
        device = dat[0].device
        # Parse algorithm options
        opt = {'optimiser': optimiser,
               'cost_fun': cost_fun,
               'samp': samp,
               'fix': fix,
               'mean_space': mean_space,
               'verbose': verbose,
               'fov': fov,
               'group': group,
               'raw': raw,
               'jitter': jitter,
               'mx_int': mx_int,
               'fwhm': fwhm}
        if not isinstance(opt['samp'], (list, tuple)):
            opt['samp'] = (opt['samp'], )
        # Some very basic sanity checks
        N = len(dat)  # Number of input scans
        mov = list(range(N))  # Indices of images
        if opt['cost_fun'] in _costs_hist and opt['mean_space']:
            raise ValueError('Option mean_space=True not defined for {} cost!'.format(opt['cost_fun']))
        # Get affine basis
        B = affine_basis(group=opt['group'], device=device)
        Nq = B.shape[0]
        # Load data
        dat = _data_loader(dat, mat, opt)
        # Define fixed image space (mat_fix, dim_fix, vx_fix)
        if opt['mean_space']:
            # Use a mean-space
            if opt['fov']:
                # Mean-space given
                mat_fix = opt['fov'][0]
                dim_fix = opt['fov'][1]
            else:
                # Compute mean-space
                mat_fix, dim_fix = _get_mean_space(dat, mat)
            arg_grid = dim_fix
        else:
            # Use one of the input images
            mat_fix = mat[opt['fix']]
            dim_fix = dat[opt['fix']].shape[:3]
            mov.remove(opt['fix'])
            arg_grid = dat[opt['fix']]
        # Get voxel size of fixed image
        vx_fix = voxel_size(mat_fix)
        # Initial guess for registration parameter
        q = torch.zeros((N, Nq), dtype=torch.float64, device=device)
        if N < 2:
            # Nothing to register: return identity transformations, with
            # the same outputs as the main path
            mat_a = torch.zeros((N, 4, 4),
                                dtype=torch.float64, device=device)
            for n in range(N):
                mat_a[n, ...] = expm(q[n, ...], basis=B)

            return mat_a, mat_fix, dim_fix, q
        # Do registration
        for s in opt['samp']:  # Loop over sub-sampling level
            # Get possibly sub-sampled fixed image, and its resampling grid
            dat_fix, grid = _get_dat_grid(
                arg_grid, vx_fix, s, jitter=opt['jitter'], device=device)
            # Do optimisation (refines q from the previous level)
            q, _ = _fit_q(q, dat_fix, grid, mat_fix, dat, mat, mov,
                          B, s, opt)
    # To matrix form
    mat_a = torch.zeros((N, 4, 4),
                        dtype=torch.float64, device=device)
    for n in range(N):
        mat_a[n, ...] = expm(q[n, ...], basis=B)

    return mat_a, mat_fix, dim_fix, q
Example #7
0
def _compute_cost(q,
                  grid0,
                  dat_fix,
                  mat_fix,
                  dat,
                  mat,
                  mov,
                  cost_fun,
                  B,
                  mx_int,
                  fwhm,
                  return_res=False):
    """Compute registration cost function.

    Parameters
    ----------
    q : (N, Nq) tensor_like or array_like
        Lie algebra of affine registration fit. It is flattened, and the
        i-th block of Nq parameters is used for the i-th entry of `mov`.
        A numpy array is converted to a tensor, and the cost is then
        returned as a numpy array.
    grid0 : (X1, Y1, Z1) tensor_like
        Sub-sampled image data's resampling grid.
    dat_fix : (X1, Y1, Z1) tensor_like
        Fixed image data.
    mat_fix : (4, 4) tensor_like
        Fixed affine matrix.
    dat : [N,] tensor_like
        List of input images.
    mat : [N,] tensor_like
        List of affine matrices.
    mov : [N,] int
        Indices of moving images. Histogram-based costs only use the
        last moving image.
    cost_fun : str
        Cost function to compute (see run_affine_reg).
    B : (Nq, 4, 4) tensor_like
        Affine basis.
    mx_int : int
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=511 -> H.shape = (512, 512)).
    fwhm : float
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.
    return_res : bool, default=False
        Return registration results for plotting.

    Returns
    -------
    c : float
        Cost of aligning images with current estimate of q. If
        optimiser='powell', array_like, else tensor_like.
    res : tensor_like
        Registration results, for visualisation (only if return_res=True).

    Raises
    ------
    ValueError
        If `cost_fun` is not a known cost function.

    """
    if cost_fun not in _costs_hist and cost_fun not in _costs_edge:
        raise ValueError('Unknown cost function: {}'.format(cost_fun))

    # Init
    device = grid0.device
    q = q.flatten()
    was_numpy = isinstance(q, np.ndarray)
    if was_numpy:
        q = torch.from_numpy(q).to(device)  # To torch tensor
    Nq = B.shape[0]
    N = torch.tensor(len(dat), device=device,
                     dtype=torch.float32)  # For modulating NJTV cost

    # Left matrix division: use torch.linalg.solve when available, since
    # Tensor.solve is deprecated (and removed in recent PyTorch releases).
    has_linalg_solve = (hasattr(torch, 'linalg')
                        and hasattr(torch.linalg, 'solve'))

    if cost_fun in _costs_edge:
        jtv = dat_fix.clone()
        if cost_fun == 'njtv':
            njtv = -dat_fix.sqrt()

    for i, m in enumerate(mov):  # Loop over moving images
        # Get affine matrix from the i-th block of parameters
        mat_a = expm(q[torch.arange(i * Nq, i * Nq + Nq)], B)
        # Compose matrices: M = mat_mov \ (mat_a @ mat_fix)
        mat_af = mat_a.mm(mat_fix)
        if has_linalg_solve:
            M = torch.linalg.solve(mat[m], mat_af)
        else:
            M = torch.solve(mat_af, mat[m])[0]
        M = M.type(torch.float32)
        # Transform fixed grid
        grid = affine_matvec(M, grid0)
        # Resample to fixed grid
        dat_new = grid_pull(dat[m][None, None, ...],
                            grid[None, ...],
                            bound='dft',
                            extrapolate=False,
                            interpolation=1)[0, 0, ...]
        if cost_fun in _costs_edge:
            jtv += dat_new
            if cost_fun == 'njtv':
                njtv -= dat_new.sqrt()

    # Compute the cost function
    res = None
    if cost_fun in _costs_hist:
        # Histogram based costs
        # ----------
        # Compute joint histogram
        # OBS: This function expects both images to have the same max and min intensities,
        # this is ensured by the _data_loader() function.
        H = _hist_2d(dat_fix, dat_new, mx_int, fwhm)
        res = H

        # Get probabilities
        pxy = H / H.sum()
        px = pxy.sum(dim=0, keepdim=True)
        py = pxy.sum(dim=1, keepdim=True)

        # Compute cost
        if cost_fun == 'mi':
            # Mutual information
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            c = -mi
        elif cost_fun == 'ecc':
            # Entropy Correlation Coefficient
            # Maes, Collignon, Vandermeulen, Marchal & Suetens (1997).
            # "Multimodality image registration by maximisation of mutual
            # information". IEEE Transactions on Medical Imaging 16(2):187-198
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            ecc = -2 * mi / (torch.sum(px * px.log2()) +
                             torch.sum(py * py.log2()))
            c = -ecc
        elif cost_fun == 'nmi':
            # Normalised Mutual Information
            # Studholme,  Hill & Hawkes (1998).
            # "A normalized entropy measure of 3-D medical image alignment".
            # in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
            nmi = (torch.sum(px * px.log2()) +
                   torch.sum(py * py.log2())) / torch.sum(pxy * pxy.log2())
            c = -nmi
        elif cost_fun == 'ncc':
            # Normalised Cross Correlation
            i = torch.arange(1,
                             pxy.shape[0] + 1,
                             device=device,
                             dtype=torch.float32)
            j = torch.arange(1,
                             pxy.shape[1] + 1,
                             device=device,
                             dtype=torch.float32)
            m1 = torch.sum(py * i[..., None])
            m2 = torch.sum(px * j[None, ...])
            sig1 = torch.sqrt(torch.sum(py[..., 0] * (i - m1)**2))
            sig2 = torch.sqrt(torch.sum(px[0, ...] * (j - m2)**2))
            i, j = torch.meshgrid(i - m1, j - m2)
            ncc = torch.sum(pxy * i * j) / (sig1 * sig2)
            c = -ncc
        else:
            raise ValueError('Unknown cost function: {}'.format(cost_fun))
    else:
        # Normalised Joint Total Variation
        # M Brudfors, Y Balbastre, J Ashburner (2020).
        # "Groupwise Multimodal Image Registration using Joint Total Variation".
        # in MIUA 2020.
        jtv.sqrt_()
        if cost_fun == 'njtv':
            njtv += torch.sqrt(N) * jtv
            res = njtv
            c = torch.sum(njtv)
        else:
            res = jtv
            c = torch.sum(jtv)

    if was_numpy:
        # Back to numpy array
        c = c.cpu().numpy()

    if return_res:
        return c, res
    else:
        return c
Example #8
0
def write_data(options):
    """Write updated and/or resliced images to disk.

    The first ``struct.Linear`` transformation and the first nonlinear
    (``struct.FFD`` or ``struct.Diffeo``) transformation found in
    ``options.transformations`` are applied to every moving image of
    ``options.losses``. Their inverse (``inv=True``) is applied to every fixed
    image that asks to be updated or resliced.

    Parameters
    ----------
    options : object
        Must expose ``losses``, ``transformations`` and ``verbose``.
    """
    # All transformations are evaluated on the CPU in single precision
    backend = dict(dtype=torch.float, device='cpu')

    # The inverse displacement is only needed if a fixed image is written
    need_inv = any(loss.fixed and (loss.fixed.resliced or loss.fixed.updated)
                   for loss in options.losses)

    # affine matrix (first linear transformation only)
    lin = None
    for trf in options.transformations:
        if isinstance(trf, struct.Linear):
            q = trf.dat.to(**backend)
            B = trf.basis.to(**backend)
            lin = linalg.expm(q, B)
            if torch.is_tensor(trf.shift):
                # include shift: t += (L - I) @ shift
                # (spatial dimension taken from the matrix, not hard-coded)
                ndim = lin.shape[-1] - 1
                shift = trf.shift.to(**backend)
                eye = torch.eye(ndim, **backend)
                lin[:-1, -1] += torch.matmul(lin[:-1, :-1] - eye, shift)
            break

    # non-linear displacement field (first nonlinear transformation only)
    disp = None       # forward displacement
    inv_disp = None   # inverse displacement (only computed if need_inv)
    disp_aff = None   # affine matrix attached to the nonlinear transformation
    for trf in options.transformations:
        if isinstance(trf, struct.FFD):
            disp = trf.dat.to(**backend)
            disp = ffd_exp(disp, trf.shape, returns='disp')
            if need_inv:
                inv_disp = grid_inv(disp)
            disp_aff = trf.affine.to(**backend)
            break
        elif isinstance(trf, struct.Diffeo):
            vel = trf.dat.to(**backend)
            if need_inv:
                inv_disp = spatial.exp(vel[None], displacement=True,
                                       inverse=True)[0]
            disp = spatial.exp(vel[None], displacement=True)[0]
            disp_aff = trf.affine.to(**backend)
            break

    # loop over image pairs
    for match in options.losses:

        moving = match.moving
        fixed = match.fixed
        prm = dict(interpolation=moving.interpolation,
                   bound=moving.bound,
                   extrapolate=moving.extrapolate,
                   device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=disp, affine=disp_aff)
        if moving.updated:
            update(moving, moving.updated, lin=lin, nonlin=nonlin, **prm)
        if moving.resliced:
            reslice(moving, moving.resliced, like=fixed, lin=lin, nonlin=nonlin, **prm)
        if not fixed:
            continue
        # Fixed images are written with the inverse transformation
        prm = dict(interpolation=fixed.interpolation,
                   bound=fixed.bound,
                   extrapolate=fixed.extrapolate,
                   device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=inv_disp, affine=disp_aff)
        if fixed.updated:
            update(fixed, fixed.updated, inv=True, lin=lin, nonlin=nonlin, **prm)
        if fixed.resliced:
            reslice(fixed, fixed.resliced, inv=True, like=moving, lin=lin, nonlin=nonlin, **prm)
Example #9
0
    def forward():
        """Forward pass up to the loss.

        Reads `options` and `backend` from the enclosing scope. Builds the
        current linear and nonlinear transformations, accumulates their
        penalty terms (`trf.losses`), then, for every image pair that has a
        fixed image and at every pyramid level, warps the moving image onto
        the fixed lattice and adds the matching loss divided by the number
        of levels.

        Returns
        -------
        loss : tensor or int
            Sum of all penalty and matching terms (the int 0 if there are
            none).
        """

        loss = 0

        # affine matrix (first linear transformation only)
        # NOTE(review): because of the `break` below, `trf.update()` is not
        # called on transformations listed after the first struct.Linear —
        # confirm the nonlinear transformations do not need it.
        A = None
        for trf in options.transformations:
            trf.update()
            if isinstance(trf, struct.Linear):
                q = trf.optdat.to(**backend)
                # print(q.tolist())
                B = trf.basis.to(**backend)
                A = linalg.expm(q, B)
                if torch.is_tensor(trf.shift):
                    # include shift: t += (L - I) @ shift
                    shift = trf.shift.to(**backend)
                    eye = torch.eye(options.dim, **backend)
                    A = A.clone()  # needed because expm is a custom autograd.Function
                    A[:-1, -1] += torch.matmul(A[:-1, :-1] - eye, shift)
                # penalty terms on the Lie parameters
                for loss1 in trf.losses:
                    loss += loss1.call(q)
                break

        # non-linear displacement field (first free nonlinear transformation)
        d = None
        d_aff = None
        for trf in options.transformations:
            if not trf.isfree():
                continue
            if isinstance(trf, struct.FFD):
                d = trf.dat.to(**backend)
                d = ffd_exp(d, trf.shape, returns='disp')
                # penalty on the displacement obtained from the FFD
                for loss1 in trf.losses:
                    loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break
            elif isinstance(trf, struct.Diffeo):
                d = trf.dat.to(**backend)
                if not trf.smalldef:
                    # penalty on velocity fields
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d = spatial.exp(d[None], displacement=True)[0]
                if trf.smalldef:
                    # penalty on exponentiated transform
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break

        # loop over image pairs
        for match in options.losses:
            if not match.fixed:
                continue
            # Each level's matching term is divided by the number of levels
            nb_levels = len(match.fixed.dat)
            prm = dict(interpolation=match.moving.interpolation,
                       bound=match.moving.bound,
                       extrapolate=match.moving.extrapolate)
            # loop over pyramid levels; each level is a (data, affine) pair
            for moving, fixed in zip(match.moving.dat, match.fixed.dat):
                moving_dat, moving_aff = moving
                fixed_dat, fixed_aff = fixed

                moving_dat = moving_dat.to(**backend)
                moving_aff = moving_aff.to(**backend)
                fixed_dat = fixed_dat.to(**backend)
                fixed_aff = fixed_aff.to(**backend)

                # affine-corrected moving space
                if A is not None:
                    Ms = affine_matmul(A, moving_aff)
                else:
                    Ms = moving_aff

                if d is not None:
                    # fixed to param
                    Mt = affine_lmdiv(d_aff, fixed_aff)
                    # Shortcut when the fixed lattice presumably coincides
                    # with the displacement lattice (see `samespace`)
                    if samespace(Mt, d.shape[:-1], fixed_dat.shape[1:]):
                        g = smalldef(d)
                    else:
                        g = affine_grid(Mt, fixed_dat.shape[1:])
                        g = g + pull_grid(d, g)
                    # param to moving
                    Ms = affine_lmdiv(Ms, d_aff)
                    g = affine_matvec(Ms, g)
                else:
                    # fixed to moving
                    Mt = fixed_aff
                    Ms = affine_lmdiv(Ms, Mt)
                    g = affine_grid(Ms, fixed_dat.shape[1:])

                # pull moving image
                warped_dat = pull(moving_dat, g, **prm)
                loss += match.call(warped_dat, fixed_dat) / float(nb_levels)

                # import matplotlib.pyplot as plt
                # plt.subplot(1, 2, 1)
                # plt.imshow(fixed_dat[0, :, :, fixed_dat.shape[-1]//2].detach())
                # plt.axis('off')
                # plt.subplot(1, 2, 2)
                # plt.imshow(warped_dat[0, :, :, warped_dat.shape[-1]//2].detach())
                # plt.axis('off')
                # plt.show()

        return loss