Example #1
def _warp_image1(image,
                 target,
                 shape=None,
                 affine=None,
                 nonlin=None,
                 backward=False,
                 reslice=False):
    """Returns the warped image, with channel dimension last"""
    # build transform
    aff_right = target
    aff_left = spatial.affine_inv(image.affine)
    aff = None
    if affine:
        # exp = affine.iexp if backward else affine.exp
        exp = affine.exp
        aff = exp(recompute=False, cache_result=True)
        if backward:
            aff = spatial.affine_inv(aff)
    if nonlin:
        if affine:
            if affine.position[0].lower() in ('ms' if backward else 'fs'):
                aff_right = spatial.affine_matmul(aff, aff_right)
            if affine.position[0].lower() in ('fs' if backward else 'ms'):
                aff_left = spatial.affine_matmul(aff_left, aff)
        exp = nonlin.iexp if backward else nonlin.exp
        phi = exp(recompute=False, cache_result=True)
        aff_left = spatial.affine_matmul(aff_left, nonlin.affine)
        aff_right = spatial.affine_lmdiv(nonlin.affine, aff_right)
        if _almost_identity(aff_right) and nonlin.shape == shape:
            phi = nonlin.add_identity(phi)
        else:
            tmp = spatial.affine_grid(aff_right, shape)
            phi = regutils.smart_pull_grid(phi, tmp).add_(tmp)
            del tmp
        if not _almost_identity(aff_left):
            phi = spatial.affine_matvec(aff_left, phi)
    else:
        # no nonlin: single affine even if position == 'symmetric'
        if reslice:
            if aff is None:  # no affine component: compose headers only
                aff = aff_right
            else:
                aff = spatial.affine_matmul(aff, aff_right)
            aff = spatial.affine_matmul(aff_left, aff)
            phi = spatial.affine_grid(aff, shape)
        else:
            phi = None

    # warp image
    if phi is not None:
        warped = image.pull(phi)
    else:
        warped = image.dat

    # move channel dimension last (or drop it if singleton)
    if len(warped) == 1:
        warped = warped[0]
    else:
        warped = utils.movedim(warped, 0, -1)
    return warped
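
The composition above follows the pattern "left affine ∘ (displacement + right-affine grid)". A minimal sketch of that coordinate algebra in plain PyTorch; identity_grid and affine_matvec below are hypothetical stand-ins for the nitorch spatial helpers:

import torch

def identity_grid(shape, **backend):
    # dense grid of voxel coordinates, shape (*shape, 3)
    ax = [torch.arange(s, **backend) for s in shape]
    return torch.stack(torch.meshgrid(*ax, indexing='ij'), dim=-1)

def affine_matvec(aff, coords):
    # apply a (4, 4) homogeneous affine to (..., 3) coordinates
    return coords @ aff[:3, :3].T + aff[:3, 3]

shape = (8, 8, 8)
backend = dict(dtype=torch.float32)
phi = torch.zeros(*shape, 3, **backend)          # stand-in displacement field
aff_left = aff_right = torch.eye(4, **backend)   # stand-in affines
grid = affine_matvec(aff_right, identity_grid(shape, **backend))
grid = grid + phi       # (resampling phi onto the new grid is elided here)
grid = affine_matvec(aff_left, grid)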
Example #2
def load_and_pull(volume, aff, shape, dtype=None, device=None):
    """

    Parameters
    ----------
    volume : Volume3D
    aff : (D+1,D+1) tensor
    shape : (D,) tuple

    Returns
    -------
    dat : tensor

    """

    backend = dict(dtype=dtype or aff.dtype, device=device or aff.device)
    aff = aff.to(**backend)
    identity = torch.eye(aff.shape[-1], **backend)
    fdata = volume.fdata(cache=False, **backend)
    inshape = fdata.shape
    inaff = volume.affine.to(**backend)
    aff = core.linalg.lmdiv(inaff, aff)
    if torch.allclose(aff, identity) and tuple(shape) == tuple(inshape):
        return fdata
    else:
        grid = spatial.affine_grid(aff, shape)
        return spatial.grid_pull(fdata[None, None, ...], grid[None, ...])[0, 0]
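
The early return above skips interpolation when the voxel-to-voxel transform is an identity. A stand-alone version of that test, assuming only plain PyTorch:

import torch

def is_noop(aff, inshape, outshape, tol=1e-6):
    # True if pulling through `aff` would reproduce the input exactly
    eye = torch.eye(aff.shape[-1], dtype=aff.dtype, device=aff.device)
    return torch.allclose(aff, eye, atol=tol) and tuple(inshape) == tuple(outshape)

print(is_noop(torch.eye(4), (32, 32, 32), (32, 32, 32)))  # True -> return fdata as-is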
Example #3
def build_from_target(target):
    """Compose all transformations, starting from the final orientation"""
    grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
    for trf in reversed(options.transformations):
        if isinstance(trf, Linear):
            grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
        else:
            mat = trf.affine.to(**backend)
            if trf.inv:
                vx0 = spatial.voxel_size(mat)
                vx1 = spatial.voxel_size(target.affine.to(**backend))
                factor = vx0 / vx1
                disp, mat = spatial.resize_grid(trf.dat[None],
                                                factor,
                                                affine=mat,
                                                interpolation=trf.spline)
                disp = spatial.grid_inv(disp[0], type='disp')
                order = 1
            else:
                disp = trf.dat
                order = trf.spline
            imat = spatial.affine_inv(mat)
            grid = spatial.affine_matvec(imat, grid)
            grid += helpers.pull_grid(disp, grid, interpolation=order)
            grid = spatial.affine_matvec(mat, grid)
    return grid
Example #4
def _msk_fov(dat, mat, mat0, dim0):
    """Mask field-of-view (FOV) of image data according to other image's
    FOV.

    Parameters
    ----------
    dat : (X, Y, Z), tensor
        Image data.
    mat : (4, 4), tensor
        Image's affine.
    mat0 : (4, 4), tensor
        Other image's affine.
    dim0 : (3, ), list/tuple
        Other image's dimensions.

    Returns
    -------
    dat : (X, Y, Z), tensor
        Masked image data.

    """
    dim = dat.shape
    M = lmdiv(mat0, mat)  # mat0\mat1
    grid = affine_grid(M, dim)
    msk = (grid[..., 0] >= 1) & (grid[..., 0] <= dim0[0]) & \
          (grid[..., 1] >= 1) & (grid[..., 1] <= dim0[1]) & \
          (grid[..., 2] >= 1) & (grid[..., 2] <= dim0[2])
    dat[~msk] = 0

    return dat
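
The same in-bounds test can be written as a loop over dimensions. A stand-alone sketch, where grid holds (1-based, as above) voxel coordinates in the other image's space:

import torch

def fov_mask(grid, dim_other):
    # grid: (*spatial, 3) coordinates; True where they fall inside dim_other
    msk = torch.ones(grid.shape[:-1], dtype=torch.bool, device=grid.device)
    for d in range(3):
        msk &= (grid[..., d] >= 1) & (grid[..., d] <= dim_other[d])
    return msk

grid = torch.rand(8, 8, 8, 3) * 40         # random stand-in coordinates
print(fov_mask(grid, (32, 32, 32)).shape)  # torch.Size([8, 8, 8])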
Example #5
    def forward(self, affine, **overload):
        """

        Parameters
        ----------
        affine : (batch, ndim[+1], ndim+1) tensor
            Affine matrix
        overload : dict
            All parameters of the module can be overridden at call time.

        Returns
        -------
        grid : (batch, *shape, ndim) tensor
            Dense transformation grid

        """

        nb_dim = affine.shape[-1] - 1
        info = {'dtype': affine.dtype, 'device': affine.device}
        shape = make_list(overload.get('shape', self.shape), nb_dim)
        shift = overload.get('shift', self.shift)

        if shift:
            affine_shift = torch.cat((
                torch.eye(nb_dim, **info),
                -torch.as_tensor(shape, **info)[:, None]/2),
                dim=1)
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

        grid = spatial.affine_grid(affine, shape)
        return grid
Example #6
def _init_y_dat(x, y, sett):
    """ Make initial guesses of reconstucted image(s) using b-spline interpolation,
        with averaging if more than one observation per channel.
    """
    dim_y = y[0].dim
    mat_y = y[0].mat
    for c in range(len(x)):
        dat_y = torch.zeros(dim_y, dtype=torch.float32, device=sett.device)
        num_x = len(x[c])
        sm    = torch.zeros_like(dat_y)
        for n in range(num_x):
            # Get image data
            dat = x[c][n].dat[None, None, ...]
            # Make output grid
            mat = mat_y.solve(x[c][n].mat)[0]  # mat_x\mat_y
            grid = affine_grid(mat.type(dat.dtype), dim_y)
            # Do resampling
            mn = torch.min(dat)
            mx = torch.max(dat)
            dat = grid_pull(dat, grid[None, ...],
                bound='zero', extrapolate=False, interpolation=1)
            dat[dat < mn] = mn
            dat[dat > mx] = mx
            sm = sm + (dat[0, 0, ...].round() != 0)
            dat_y = dat_y + dat[0, 0, ...]
        sm[sm == 0] = 1
        y[c].dat = dat_y / sm

    return y
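
The averaging above divides the accumulated intensities by the per-voxel count of repeats that actually contributed. A minimal single-channel sketch of that pattern with random stand-in data:

import torch

dat_y = torch.zeros(4, 4, 4)
count = torch.zeros_like(dat_y)
for dat in [torch.rand(4, 4, 4), torch.rand(4, 4, 4)]:  # stand-in resampled repeats
    count = count + (dat.round() != 0)
    dat_y = dat_y + dat
count[count == 0] = 1        # voxels no repeat touched: avoid division by zero
dat_y = dat_y / count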
Example #7
def _resample_inplane(x, sett):
    """Force in-plane resolution of observed data to be greater or equal to recon vx.
    """
    if sett.force_inplane_res and sett.max_iter > 0:
        I = torch.eye(4, device=sett.device, dtype=torch.float64)
        for c in range(len(x)):
            for n in range(len(x[c])):
                # get image data
                dat = x[c][n].dat[None, None, ...]
                mat_x = x[c][n].mat
                dim_x = torch.as_tensor(x[c][n].dim, device=sett.device, dtype=torch.float64)
                vx_x = voxel_size(mat_x)
                # make grid
                D = I.clone()
                for i in range(3):
                    D[i, i] = sett.vx / vx_x[i]
                    if D[i, i] < 1.0:
                        D[i, i] = 1
                if float((I - D).abs().sum()) < 1e-4:
                    continue
                mat_x = mat_x.matmul(D)
                dim_x = D[:3, :3].inverse().mm(dim_x[:, None]).floor().squeeze().cpu().int().tolist()
                grid = affine_grid(D.type(dat.dtype), dim_x)
                # resample
                dat = grid_pull(dat, grid[None, ...], bound='zero', extrapolate=False, interpolation=0)
                # do label
                if x[c][n].label is not None:
                    x[c][n].label[0] = _warp_label(x[c][n].label[0], grid)
                # assign
                x[c][n].dat = dat[0, 0, ...]
                x[c][n].mat = mat_x
                x[c][n].dim = dim_x

    return x
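
The diagonal matrix D built above only ever down-samples: each in-plane axis is scaled so its voxel size reaches the target, but never below a factor of 1. A numeric sketch with made-up voxel sizes:

import torch

vx_target = 1.0                           # hypothetical target voxel size
vx_x = torch.tensor([0.5, 0.5, 3.0])      # hypothetical input voxel sizes
D = torch.eye(4, dtype=torch.float64)
for i in range(3):
    D[i, i] = max(vx_target / float(vx_x[i]), 1.0)  # clamp: no up-sampling
print(torch.diag(D))  # tensor([2., 2., 1., 1.], ...): the thick axis is untouched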
Example #8
    def forward(self, affine, shape=None):
        """

        Parameters
        ----------
        affine : (batch, ndim[+1], ndim+1) tensor
            Affine matrix
        shape : sequence[int], default=self.shape

        Returns
        -------
        grid : (batch, *shape, ndim) tensor
            Dense transformation grid

        """

        nb_dim = affine.shape[-1] - 1
        backend = {'dtype': affine.dtype, 'device': affine.device}
        shape = shape or self.shape

        if self.shift:
            affine_shift = torch.eye(nb_dim + 1, **backend)
            affine_shift[:nb_dim, -1] = torch.as_tensor(shape, **backend)
            affine_shift[:nb_dim, -1].sub_(1).div_(2).neg_()  # in-place: -(shape - 1) / 2
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

        grid = spatial.affine_grid(affine, shape)
        return grid
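
The shift in both forward variants conjugates the affine with a translation that moves the rotation centre to the volume centre. A stand-alone sketch of that conjugation in plain PyTorch:

import torch

def center_conjugate(affine, shape):
    # A <- S^{-1} A S, where S shifts the origin by -(shape - 1) / 2
    nb_dim = affine.shape[-1] - 1
    backend = dict(dtype=affine.dtype, device=affine.device)
    S = torch.eye(nb_dim + 1, **backend)
    S[:nb_dim, -1] = -(torch.as_tensor(shape, **backend) - 1) / 2
    return torch.linalg.solve(S, affine @ S)  # S \ (A @ S)

A = torch.eye(4)
A[:3, 3] = torch.tensor([1., 2., 3.])
print(center_conjugate(A, (32, 32, 32)))  # translations commute: A is unchanged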
Example #9
    def forward(self, source, target, source_affine=None, target_affine=None):
        """

        Parameters
        ----------
        source : (sX, sY, sZ) tensor or str
        target : (tX, tY, tZ) tensor or str
        source_affine : (4, 4) tensor, optional
        target_affine : (4, 4) tensor, optional

        Returns
        -------
        warped : (tX, tY, tZ) tensor
            Source warped to target
        velocity : (vX, vY, vZ, 3) tensor
            Stationary velocity field
        affine : (4, 4) tensor, optional
            Affine of the velocity space

        """
        if self.verbose:
            print('Preprocessing... ', end='', flush=True)
        source, source_affine, source_orig, source_affine_orig \
            = self.load(source, source_affine)
        target, target_affine, target_orig, target_affine_orig \
            = self.load(target, target_affine)
        source = spatial.reslice(source, source_affine, target_affine,
                                 target.shape)
        if self.verbose:
            print('done.', flush=True)
            print('Registering... ', end='', flush=True)
        source = source[None, None]
        target = target[None, None]
        warped, vel, grid = super().forward(source, target)
        if self.verbose:
            print('done.', flush=True)
        del source, target, warped
        vel = vel[0]
        grid = grid[0]
        grid -= spatial.identity_grid(grid.shape[:-1],
                                      dtype=grid.dtype,
                                      device=grid.device)
        right_affine = target_affine.inverse() @ target_affine_orig
        right_affine = spatial.affine_grid(right_affine, target_orig.shape)
        grid = spatial.grid_pull(utils.movedim(grid, -1, 0),
                                 right_affine,
                                 bound='nearest',
                                 extrapolate=True)
        grid = utils.movedim(grid, 0, -1).add_(right_affine)
        left_affine = source_affine_orig.inverse() @ target_affine
        grid = spatial.affine_matvec(left_affine, grid)
        warped = spatial.grid_pull(source_orig, grid)
        return warped, vel, target_affine
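
Subtracting the identity grid converts the dense transformation into a displacement so it can be resampled safely. A stand-alone sketch with a hypothetical identity_grid mirroring spatial.identity_grid:

import torch

def identity_grid(shape, **backend):
    ax = [torch.arange(s, **backend) for s in shape]
    return torch.stack(torch.meshgrid(*ax, indexing='ij'), dim=-1)

grid = identity_grid((4, 4, 4), dtype=torch.float32)      # a no-op transform
disp = grid - identity_grid(grid.shape[:-1], dtype=grid.dtype)
print(disp.abs().max())  # tensor(0.): zero displacement, as expected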
Example #10
def _init_y_label(x, y, sett):
    """Make initial guess of labels.
    """
    dim_y = y[0].dim
    mat_y = y[0].mat
    for c in range(len(x)):
        n = 0
        if x[c][n].label is not None:
            # Make output grid
            mat = mat_y.solve(x[c][n].mat)[0]  # mat_x\mat_y
            grid = affine_grid(mat.type(x[c][n].dat.dtype), dim_y)
            # Do resampling
            y[c].label = _warp_label(x[c][n].label[0], grid)

    return y
Example #11
def _reslice_dat_3d(dat,
                    affine,
                    dim_out,
                    interpolation='linear',
                    bound='zero',
                    extrapolate=False):
    """Reslice 3D image data.

    Parameters
    ----------
    dat : (Xi, Yi, Zi), tensor_like
        Input image data.
    affine : (4, 4), tensor_like
        Affine transformation that maps from voxels in output image to
        voxels in input image.
    dim_out : (Xo, Yo, Zo), list or tuple
        Output image dimensions.
    interpolation : str, default='linear'
        Interpolation order.
    bound : str, default='zero'
        Boundary condition.
    extrapolate : bool, default=False
        Extrapolate out-of-bounds data.

    Returns
    -------
    dat : (dim_out), tensor_like
        Resliced image data.

    """
    if len(dat.shape) != 3:
        raise ValueError('Input error: len(dat.shape) != 3')

    grid = affine_grid(affine, dim_out).type(dat.dtype)
    grid = grid[None, ...]
    dat = dat[None, None, ...]
    dat = grid_pull(dat,
                    grid,
                    bound=bound,
                    interpolation=interpolation,
                    extrapolate=extrapolate)
    dat = dat[0, 0, ...]

    return dat
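
A hedged usage sketch, assuming the nitorch-style affine_grid and grid_pull used inside the function are importable: reslice a random volume into a larger field of view with an identity output-to-input affine:

import torch

dat = torch.rand(32, 32, 32)
M = torch.eye(4)                          # maps output voxels to input voxels
out = _reslice_dat_3d(dat, M, (64, 64, 64))
print(out.shape)                          # torch.Size([64, 64, 64])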
Example #12
    def slice_to(self, stack, cache_result=False, recompute=True):
        aff = self.exp(cache_result=cache_result, recompute=recompute)
        if recompute or not hasattr(self, '_sliced'):
            aff = spatial.affine_matmul(aff, self.affine)
            aff_reorient = spatial.affine_reorient(self.affine, self.shape, stack.layout)
            aff = spatial.affine_lmdiv(aff_reorient, aff)
            aff = spatial.affine_grid(aff, self.shape)
            sliced = spatial.grid_pull(self.dat, aff, bound=self.bound,
                                       extrapolate=self.extrapolate)
            fwhm = [0] * self.dim
            fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
            sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)
            # NOTE: the per-slice loop below was left unfinished in the source
            # (both compositions are missing their second operand):
            # slices = []
            # for stack_slice in stack.slices:
            #     aff = spatial.affine_matmul(stack.affine, ...)
            #     aff = spatial.affine_lmdiv(aff_reorient, ...)
        else:
            sliced = self._sliced
        if cache_result:
            self._sliced = sliced
        return sliced
Example #13
    def propagate_grad(self, g, h, moving, phi, left=None, right=None, inv=False):
        """Convert derivatives wrt warped image in loss space to
        to derivatives wrt parameters
        parameters:
            g (tensor) : gradient wrt warped image
            h (tensor) : hessian wrt warped image
            moving (Image) : moving image
            phi (tensor) : dense (exponentiated) displacement field
            left (matrix) : left affine
            right (matrix) : right affine
            inv (bool) : whether we're in a backward symmetric pass
        returns:
            g (tensor) : pushed gradient
            h (tensor) : pushed hessian
            gmu (tensor) : rotated spatial gradients
        """
        if inv:
            g = g.neg_()

        # build bits of warp
        dim = phi.shape[-1]
        fixed_shape = g.shape[-dim:]
        moving_shape = moving.shape

        # differentiate wrt δ in: Left o Phi o (Id + δ) o Right
        # we'll then propagate them through Phi by scaling and squaring
        if right is not None:
            right = spatial.affine_grid(right, fixed_shape)
        g = regutils.smart_push(g, right, shape=self.shape)
        h = regutils.smart_push(h, right, shape=self.shape)
        del right

        phi_left = spatial.identity_grid(self.shape, **utils.backend(phi))
        phi_left += phi
        if left is not None:
            phi_left = spatial.affine_matvec(left, phi_left)
        mugrad = moving.pull_grad(phi_left, rotate=False)
        del phi_left

        mugrad = _rotate_grad(mugrad, left, phi)

        return g, h, mugrad
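
The returned gradient is later contracted with the rotated spatial gradients, e.g. via regutils.jg. A simplified single-channel stand-in for that contraction (the real helper also handles channel and Hessian layouts):

import torch

def jg_single(g_mu, g):
    # chain rule at each voxel: dL/dphi = (dmu/dphi)^T dL/dmu
    # g_mu: (*spatial, 3) image gradients; g: (*spatial, 1) loss gradient
    return g_mu * g

g = torch.rand(8, 8, 8, 1)
g_mu = torch.rand(8, 8, 8, 3)
print(jg_single(g_mu, g).shape)  # torch.Size([8, 8, 8, 3])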
Example #14
def _crop_y(y, sett):
    """ Crop output images FOV to a fixed dimension

    Args:
        y (_output()): _output data.

    Returns:
        y (_output()): Cropped output data.

    """
    if not sett.crop:
        return y
    device = sett.device
    # Output image information
    mat_y = y[0].mat
    vx_y = voxel_size(mat_y)
    # Define cropped FOV
    mat_mu, dim_mu = _bb_atlas('atlas_t1',
        fov=sett.fov, dtype=torch.float64, device=device)
    # Modulate atlas with voxel size
    mat_vx = torch.diag(torch.cat((
        vx_y, torch.ones(1, dtype=torch.float64, device=device))))
    mat_mu = mat_mu.mm(mat_vx)
    dim_mu = mat_vx[:3, :3].inverse().mm(dim_mu[:, None]).floor().squeeze()
    # Make output grid
    M = mat_mu.solve(mat_y)[0].type(y[0].dat.dtype)
    grid = affine_grid(M, dim_mu)[None, ...]
    # Crop
    for c in range(len(y)):
        y[c].dat = grid_pull(y[c].dat[None, None, ...], grid,
                             bound='zero', extrapolate=False,
                             interpolation=0)[0, 0, ...]
        # Do labels?
        if y[c].label is not None:
            y[c].label = grid_pull(y[c].label[None, None, ...], grid,
                                   bound='zero', extrapolate=False,
                                   interpolation=0)[0, 0, ...]
        y[c].mat = mat_mu
        y[c].dim = tuple(dim_mu.int().tolist())

    return y
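
The modulation step folds the target voxel size into the atlas affine and shrinks the atlas dimensions to match. A stand-alone numeric sketch (the atlas dimensions below are made up for illustration):

import torch

vx_y = torch.tensor([2., 2., 2.], dtype=torch.float64)          # target voxel size
dim_mu = torch.tensor([180., 216., 180.], dtype=torch.float64)  # hypothetical atlas dim
mat_vx = torch.diag(torch.cat((vx_y, torch.ones(1, dtype=torch.float64))))
dim_new = mat_vx[:3, :3].inverse().mm(dim_mu[:, None]).floor().squeeze()
print(dim_new)  # tensor([ 90., 108.,  90.], dtype=torch.float64)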
Example #15
    def compute_grid(self, mat_native, dim_native):
        """Computes resampling grid for pulling/pushing from/to common space.

        Parameters
        ----------
        mat_native : (1, dim + 1, dim + 1) tensor
            Native image affine matrix.
        dim_native : [3, ] sequence
            Native image dimensions.

        Returns
        -------
        grid : (batch, *spatial, dim) tensor
            Resampling grid.

        """
        self.mean_mat = self.mean_mat.type(mat_native.dtype).to(
            mat_native.device)
        mat = mat_native.solve(self.mean_mat)[0]
        grid = spatial.affine_grid(mat, dim_native)

        return grid
Example #16
def smart_grid(aff, shape, inshape=None, force=False):
    """Generate a sampling grid iff it is not the identity.

    Parameters
    ----------
    aff : (D+1, D+1) tensor
        Affine transformation matrix (voxels to voxels)
    shape : (D,) tuple[int]
        Output shape
    inshape : (D,) tuple[int], optional
        Input shape
    force : bool, default=False
        Return a grid even if the transform is the identity

    Returns
    -------
    grid : (*shape, D) tensor or None
        Sampling grid

    """
    backend = dict(dtype=aff.dtype, device=aff.device)
    identity = torch.eye(aff.shape[-1], **backend)
    inshape = inshape or shape
    if not force and torch.allclose(aff, identity) and shape == inshape:
        return None
    return spatial.affine_grid(aff, shape)
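
A hedged usage sketch, assuming spatial.affine_grid is importable as in the function: with an identity affine and matching shapes the helper returns None, which callers can use to skip interpolation entirely:

import torch

aff = torch.eye(4)
grid = smart_grid(aff, (16, 16, 16), inshape=(16, 16, 16))
assert grid is None                # identity transform: no resampling needed
grid = smart_grid(aff, (16, 16, 16), force=True)
print(grid.shape)                  # torch.Size([16, 16, 16, 3])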
Example #17
def fit(x, y, sett):
    """ Fit model.

        This runs the iterative denoising/super-resolution algorithm and,
        at the end, writes the reconstructed images to disk. If the maximum
        number of iterations is set to zero, the initial guesses of the
        reconstructed images (obtained with b-spline interpolation) are
        written to disk and no denoising/super-resolution is applied.

    Returns:
        dat_y (torch.tensor): Reconstructed image data as float32, (dim_y, C).
        mat_y (torch.tensor): Reconstructed affine matrix, (4, 4).
        pth_y ([str, ...]): Paths to reconstructed images.
        R (torch.tensor): Rigid matrices (N, 4, 4).
        label (torch.tensor): Reconstructed label image, (dim_y).
        pth_label (str): Path to reconstructed label image.

    """
    with torch.no_grad():
        # Total number of observations
        N = sum([len(xn) for xn in x])

        # Sanity check scaling parameter
        if not isinstance(sett.reg_scl, torch.Tensor):
            sett.reg_scl = torch.tensor(sett.reg_scl,
                                        dtype=torch.float32,
                                        device=sett.device)
            sett.reg_scl = sett.reg_scl.reshape(1)

        # Defines a coarse-to-fine scaling of regularisation
        sett = _get_sched(N, sett)

        # For visualisation
        fig_ax_nll = None
        fig_ax_jtv = None

        # Scale lambda
        cnt_scl = 0
        for c in range(len(x)):
            y[c].lam = sett.reg_scl[cnt_scl] * y[c].lam0

        if sett.max_iter > 0:
            # Get ADMM step-size
            rho = _step_size(x, y, sett, verbose=True)
            # Get ADMM variables
            z, w = _admm_aux(y, sett)

        # ----------
        # ITERATE:
        # Updates the model in an alternating fashion until a convergence
        # threshold on the model negative log-likelihood is met.
        # ----------
        obj = torch.zeros(sett.max_iter,
                          3,
                          dtype=torch.float64,
                          device=sett.device)
        tmp = torch.zeros_like(
            y[0].dat)  # for holding rhs in y-update, and jtv in u-update
        t_iter = timer() if sett.do_print else 0
        cnt_scl_iter = 0  # To ensure we do, at least, a fixed number of iterations for each scale
        for n_iter in range(sett.max_iter):

            if n_iter == 0:
                t00 = _print_info('fit-start', sett, len(x), N)  # PRINT

            # ----------
            # UPDATE: image
            # ----------
            y, z, w, tmp, obj = _update_admm(x, y, z, w, rho, tmp, obj, n_iter,
                                             sett)

            # Show JTV
            if sett.show_jtv:
                fig_ax_jtv = show_slices(img=tmp,
                                         fig_ax=fig_ax_jtv,
                                         title='JTV',
                                         cmap='coolwarm',
                                         fig_num=98)

            # ----------
            # Check convergence
            # ----------
            if sett.plot_conv:  # Plot algorithm convergence
                fig_ax_nll = plot_convergence(
                    vals=obj[:n_iter + 1, :],
                    fig_ax=fig_ax_nll,
                    fig_num=99,
                    legend=['-ln(p(y|x))', '-ln(p(x|y))', '-ln(p(y))'])
            gain = get_gain(obj[:n_iter + 1, 0], monotonicity='decreasing')
            t_iter = _print_info('fit-ll', sett, n_iter, obj[n_iter, :], gain,
                                 t_iter)
            # Converged?
            if cnt_scl >= (sett.reg_scl.numel() - 1) and cnt_scl_iter > 20 \
                and ((gain.abs() < sett.tolerance) or (n_iter >= (sett.max_iter - 1))):
                countdown0 -= 1
                if countdown0 == 0:
                    _ = _print_info('fit-finish', sett, t00, n_iter)
                    break  # Finished
            else:
                countdown0 = 6

            # ----------
            # UPDATE: even/odd scaling
            # ----------
            if sett.scaling:

                t0 = _print_info('fit-update', sett, 's', n_iter)  # PRINT
                # Do update
                x, _ = _update_scaling(x,
                                       y,
                                       sett,
                                       max_niter_gn=1,
                                       num_linesearch=6,
                                       verbose=0)
                _ = _print_info('fit-done', sett, t0)  # PRINT
                # Print parameter estimates
                _ = _print_info('scl-param', sett, x, t0)

            # ----------
            # UPDATE: rigid_q
            # ----------
            if sett.unified_rigid and n_iter > 0 \
                and (n_iter % sett.rigid_mod) == 0:

                t0 = _print_info('fit-update', sett, 'q', n_iter)  # PRINT
                x, _ = _update_rigid(x,
                                     y,
                                     sett,
                                     mean_correct=False,
                                     max_niter_gn=1,
                                     num_linesearch=6,
                                     verbose=0,
                                     samp=sett.rigid_samp)
                _ = _print_info('fit-done', sett, t0)  # PRINT
                # Print parameter estimates
                _ = _print_info('reg-param', sett, x, t0)

            # ----------
            # Coarse-to-fine scaling of regularisation
            # ----------
            if cnt_scl + 1 < len(sett.reg_scl) and cnt_scl_iter > 16 and\
                    gain.abs() < 1e-3:
                countdown1 -= 1
                if countdown1 == 0:
                    cnt_scl_iter = 0
                    cnt_scl += 1
                    # Coarse-to-fine scaling of lambda
                    for c in range(len(x)):
                        y[c].lam = sett.reg_scl[cnt_scl] * y[c].lam0
                    # Also update ADMM step-size
                    rho = _step_size(x, y, sett)
            else:
                countdown1 = 6

            cnt_scl_iter += 1

        # ----------
        # Some post-processing
        # ----------
        if sett.clean_fov:
            # Zero outside FOV in reconstructed data
            for c in range(len(x)):
                msk_fov = torch.ones(y[c].dim,
                                     dtype=torch.bool,
                                     device=sett.device)
                for n in range(len(x[c])):
                    # Map to voxels in low-res image
                    M = x[c][n].po.rigid.mm(x[c][n].mat).solve(
                        y[c].mat)[0].inverse()
                    grid = affine_grid(M.type(x[c][n].dat.dtype),
                                       y[c].dim)[None, ...]
                    # Mask of low-res image FOV projected into high-res space
                    msk_fov = msk_fov & \
                              (grid[0, ..., 0] >= 1) & (grid[0, ..., 0] <= x[c][n].dim[0]) & \
                              (grid[0, ..., 1] >= 1) & (grid[0, ..., 1] <= x[c][n].dim[1]) & \
                              (grid[0, ..., 2] >= 1) & (grid[0, ..., 2] <= x[c][n].dim[2])
                    # if x[c][n].ct:
                    #     # Resample low-res image into high-res space
                    #     dat_c = grid_pull(x[c][n].dat[None, None, ...],
                    #                       grid, bound=sett.bound,
                    #                       extrapolate=False,
                    #                       interpolation=sett.interpolation)[0, 0, ...]
                    #     # Set voxels inside the FOV that are positive in the
                    #     # low-res data but negative in the high-res, to the
                    #     # their original values
                    #     msk = msk_fov & (dat_c >= 0) & (y[c].dat < 0)
                    #     y[c].dat[msk] = tmp[msk]
                # Zero voxels outside projected FOV
                y[c].dat[~msk_fov] = 0.0

        # # Possibly crop reconstructed data
        # y = _crop_y(y, sett)

        # ----------
        # Get rigid matrices
        # ----------
        R = torch.zeros((N, 4, 4), device=sett.device, dtype=torch.float64)
        cnt = 0
        for c in range(len(x)):
            for n in range(len(x[c])):
                R[cnt, ...] = _expm(x[c][n].rigid_q, sett.rigid_basis)
                cnt += 1

        # ----------
        # Possibly write reconstruction results to disk
        # ----------
        dat_y, pth_y, label, pth_label = _write_data(x, y, sett, jtv=tmp)

        return dat_y, y[0].mat, pth_y, R, label, pth_label
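
The convergence test in the main loop requires a small gain for several consecutive iterations (a countdown) before stopping. A simplified stand-alone version of that pattern:

def converged_at(objs, tol=1e-4, patience=6):
    # index at which `patience` consecutive gains fell below `tol`, else None
    countdown = patience
    for i in range(1, len(objs)):
        gain = abs(objs[i - 1] - objs[i])
        countdown = countdown - 1 if gain < tol else patience
        if countdown == 0:
            return i
    return None

objs = [10.0, 5.0, 4.0] + [4.0] * 10
print(converged_at(objs))  # stops once the objective has plateaued long enough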
Example #18
    def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the affine component (nonlin is not None)"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        logaff0 = logaff
        aff_pos = self.affine.position[0].lower()
        if any(loss.backward for loss in self.losses):
            aff0, iaff0, gaff0, igaff0 = \
                self.affine.exp2(logaff0, grad=True,
                                 cache_result=not in_line_search)
            phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
        else:
            iaff0, igaff0, iphi0 = None, None, None
            aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                          cache_result=not in_line_search)
            phi0 = self.nonlin.exp(cache_result=True, recompute=False)

        has_printed = False
        for loss in self.losses:

            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                phi00, aff00, gaff00 = iphi0, iaff0, igaff0
            else:
                phi00, aff00, gaff00 = phi0, aff0, gaff0

            # ----------------------------------------------------------
            # build left and right affine matrices
            # ----------------------------------------------------------
            aff_right, gaff_right = fixed.affine, None
            if aff_pos in 'fs':
                gaff_right = gaff00 @ aff_right
                gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
                aff_right = aff00 @ aff_right
            aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
            aff_left, gaff_left = self.nonlin.affine, None
            if aff_pos in 'ms':
                gaff_left = gaff00 @ aff_left
                gaff_left = linalg.lmdiv(moving.affine, gaff_left)
                aff_left = aff00 @ aff_left
            aff_left = linalg.lmdiv(moving.affine, aff_left)

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
                right = None
                phi = spatial.add_identity_grid(phi00)
            else:
                right = spatial.affine_grid(aff_right, fixed.shape)
                phi = regutils.smart_pull_grid(phi00, right)
                phi += right
            phi_right = phi
            if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
                left = None
            else:
                left = spatial.affine_grid(aff_left, self.nonlin.shape)
                phi = spatial.affine_matvec(aff_left, phi)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.dat
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                             title=f'(affine) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt Lie parameters
            # ----------------------------------------------------------

            def compose_grad(g, h, g_mu, g_aff):
                """
                g, h : gradient/Hessian of loss wrt moving image
                g_mu : spatial gradients of moving image
                g_aff : gradient of affine matrix wrt Lie parameters
                returns g, h: gradient/Hessian of loss wrt Lie parameters
                """
                # Note that `h` can be `None`, but the functions I
                # use deal with this case correctly.
                dim = g_mu.shape[-1]
                g = jg(g_mu, g)
                h = jhj(g_mu, h)
                g, h = regutils.affine_grid_backward(g, h)
                dim2 = dim * (dim + 1)
                g = g.reshape([*g.shape[:-2], dim2])
                g_aff = g_aff[..., :-1, :]
                g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
                g = linalg.matvec(g_aff, g)
                if h is not None:
                    h = h.reshape([*h.shape[:-4], dim2, dim2])
                    h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                    # h = h.abs().sum(-1).diag_embed()
                return g, h

            if grad or hess:
                g0, g = g, None
                h0, h = h, None
                if aff_pos in 'ms':
                    g_left = regutils.smart_push(g0, phi_right, shape=self.nonlin.shape)
                    h_left = regutils.smart_push(h0, phi_right, shape=self.nonlin.shape)
                    mugrad = moving.pull_grad(left, rotate=False)
                    g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                    g, h = g_left, h_left
                if aff_pos in 'fs':
                    g_right, h_right = g0, h0
                    mugrad = moving.pull_grad(phi, rotate=False)
                    jac = spatial.grid_jacobian(phi0, right, type='disp', extrapolate=False)
                    jac = torch.matmul(aff_left[:-1, :-1], jac)
                    mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                    g_right, h_right = compose_grad(g_right, h_right, mugrad, gaff_right)
                    g = g_right if g is None else g.add_(g_right)
                    h = h_right if h is None else h.add_(h_right)

                if loss.backward:
                    g = g.neg_()
                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # TODO add regularization term
        lla = 0

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += lla
        sumloss += self.llv
        self.loss_value = sumloss.item()
        if self.verbose and (self.verbose > 1 or not in_line_search):
            ll = sumloss.item()
            llv = self.llv
            if in_line_search:
                line = '(search) | '
            else:
                line = '(affine) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
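
Inside compose_grad, the (D, D+1) gradient of the loss wrt the affine matrix is flattened and contracted with the derivative of the matrix wrt each Lie parameter. A shape-only sketch of that projection with random stand-in values:

import torch

D = 3
g_mat = torch.rand(D, D + 1)           # dL/dA (linear part and translation)
g_aff = torch.rand(6, D + 1, D + 1)    # dA/dq for 6 rigid parameters
dim2 = D * (D + 1)
g = g_mat.reshape(dim2)
B = g_aff[:, :-1, :].reshape(6, dim2)  # drop the homogeneous bottom row
print((B @ g).shape)                   # torch.Size([6]): dL/dq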
Example #19
    def do_vel(self, vel, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the nonlinear component"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        if self.affine:
            aff0, iaff0 = self.affine.exp2(cache_result=True, recompute=False)
            aff_pos = self.affine.position[0].lower()
        else:
            aff_pos = 'x'
            aff0 = iaff0 = torch.eye(self.nonlin.dim + 1)
        vel0 = vel
        if any(loss.backward for loss in self.losses):
            phi0, iphi0 = self.nonlin.exp2(vel0,
                                           recompute=True,
                                           cache_result=not in_line_search)
            ivel0 = -vel0
        else:
            phi0 = self.nonlin.exp(vel0,
                                   recompute=True,
                                   cache_result=not in_line_search)
            iphi0 = ivel0 = None
        aff0 = aff0.to(phi0)
        iaff0 = iaff0.to(phi0)

        # ==============================================================
        #                     ACCUMULATE DERIVATIVES
        # ==============================================================

        has_printed = False
        for loss in self.losses:

            # ==========================================================
            #                     ONE LOSS COMPONENT
            # ==========================================================
            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                phi00, aff00, vel00 = iphi0, iaff0, ivel0
            else:
                phi00, aff00, vel00 = phi0, aff0, vel0

            # ----------------------------------------------------------
            # build left and right affine
            # ----------------------------------------------------------
            aff_right = fixed.affine
            if aff_pos in 'fs':  # affine position: fixed or symmetric
                aff_right = aff00 @ aff_right
            aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
            aff_left = self.nonlin.affine
            if aff_pos in 'ms':  # affine position: moving or symmetric
                aff_left = aff00 @ self.nonlin.affine
            aff_left = linalg.lmdiv(moving.affine, aff_left)

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
                aff_right = None
                phi = spatial.add_identity_grid(phi00)
                disp = phi00
            else:
                phi = spatial.affine_grid(aff_right, fixed.shape)
                disp = regutils.smart_pull_grid(phi00, phi)
                phi += disp
            if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
                aff_left = None
            else:
                phi = spatial.affine_matvec(aff_left, phi)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.dat
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.dat, init, preview, disp, dim=fixed.dim,
                             title=f'(nonlin) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt phi
            # ----------------------------------------------------------
            if grad or hess:

                g, h, mugrad = self.nonlin.propagate_grad(
                    g, h, moving, phi00, aff_left, aff_right,
                    inv=loss.backward)
                g = regutils.jg(mugrad, g)
                h = regutils.jhj(mugrad, h)
                if isinstance(self.nonlin, SVFModel):
                    # propagate backward by scaling and squaring
                    g, h = spatial.exp_backward(vel00, g, h,
                                                steps=self.nonlin.steps)

                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # ==============================================================
        #                       REGULARIZATION
        # ==============================================================
        vgrad = self.nonlin.regulariser(vel0)
        llv = 0.5 * vel0.flatten().dot(vgrad.flatten())
        if grad:
            sumgrad += vgrad
        del vgrad

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += llv
        sumloss += self.lla
        self.loss_value = sumloss.item()
        if self.verbose and (self.verbose > 1 or not in_line_search):
            llv = llv.item()
            ll = sumloss.item()
            lla = self.lla
            if in_line_search:
                line = '(search) | '
            else:
                line = '(nonlin) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.llv = llv
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
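
The regularisation term contributes half the inner product of the velocity with its regularised version. A stand-in sketch using the identity as a dummy regulariser in place of self.nonlin.regulariser:

import torch

vel = torch.rand(8, 8, 8, 3)
vgrad = vel.clone()                # stand-in for self.nonlin.regulariser(vel)
llv = 0.5 * vel.flatten().dot(vgrad.flatten())
print(float(llv) >= 0)             # True: the quadratic form is non-negative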
Example #20
def _update_scaling(x, y, sett, max_niter_gn=1, num_linesearch=4, verbose=0):
    """ Updates an even/odd slice scaling parameter using Gauss-Newton
        optimisation.

    Args:
        verbose (int, optional): Verbosity level, defaults to 0.

    Returns:
        sll (torch.tensor): Log-likelihood.

    """
    # Update rigid parameters, for all input images
    sll = torch.tensor(0, device=sett.device, dtype=torch.float64)
    ll = torch.tensor(0, device=sett.device, dtype=torch.float64)
    for c in range(len(x)):  # Loop over channels
        for n_x in range(len(x[c])):  # Loop over repeats
            if x[c][n_x].ct:
                # Do not optimise scaling for CT data
                continue
            # Parameters
            dim_thick = x[c][n_x].po.dim_thick
            tau = x[c][n_x].tau
            scl = x[c][n_x].po.scl
            smo_ker = x[c][n_x].po.smo_ker
            ratio = x[c][n_x].po.ratio
            dim = x[c][n_x].po.dim_yx
            mat_yx = x[c][n_x].po.mat_yx
            mat_y = x[c][n_x].po.mat_y
            rigid = _expm(x[c][n_x].rigid_q, sett.rigid_basis)
            mat = rigid.mm(mat_yx).solve(mat_y)[0]  # mat_y\rigid*mat_yx
            # Observed data
            dat_x = x[c][n_x].dat
            msk = dat_x != 0
            # Get even/odd data
            xo = _even_odd(dat_x, 'odd', dim_thick)
            mo = _even_odd(msk, 'odd', dim_thick)
            xo = xo[mo]
            xe = _even_odd(dat_x, 'even', dim_thick)
            me = _even_odd(msk, 'even', dim_thick)
            xe = xe[me]
            # Get reconstruction (without scaling)
            grid = affine_grid(mat.type(torch.float32), dim, jitter=False)
            dat_y = grid_pull(y[c].dat[None, None, ...],
                              grid[None, ...],
                              bound=sett.bound,
                              interpolation=sett.interpolation,
                              extrapolate=False)
            dat_y = F.conv3d(dat_y, smo_ker, stride=ratio)[0, 0, ...]
            # Apply scaling
            dat_y = _apply_scaling(dat_y, scl, dim_thick)

            for n_gn in range(
                    max_niter_gn):  # Loop over Gauss-Newton iterations
                # Log-likelihood
                ll = 0.5 * tau * torch.sum(
                    (dat_x[msk] - dat_y[msk])**2, dtype=torch.float64)

                if verbose >= 2:  # Show images
                    show_slices(torch.stack((dat_x, dat_y, (dat_x - dat_y)**2),
                                            3),
                                fig_num=666,
                                colorbar=False,
                                flip=False)

                # Get even/odd data
                yo = _even_odd(dat_y, 'odd', dim_thick)
                yo = yo[mo]
                ye = _even_odd(dat_y, 'even', dim_thick)
                ye = ye[me]

                # Gradient
                gr = tau * (torch.sum(ye * (xe - ye), dtype=torch.float64) -
                            torch.sum(yo * (xo - yo), dtype=torch.float64))

                # Hessian
                Hes = tau * (torch.sum(ye**2, dtype=torch.float64) +
                             torch.sum(yo**2, dtype=torch.float64))

                # Compute Gauss-Newton update step
                Update = gr / Hes

                # Do update..
                old_scl = scl.clone()
                old_ll = ll.clone()
                armijo = torch.tensor(1.0,
                                      device=sett.device,
                                      dtype=old_scl.dtype)
                if num_linesearch == 0:
                    # ..without a line-search
                    scl = old_scl - armijo * Update
                    if verbose >= 1:
                        print('c={}, n={}, gn={} | exp(s)={}'.format(
                            c, n_x, n_gn, round(scl.exp(), 5)))
                else:
                    # ..using a line-search
                    for n_ls in range(num_linesearch):
                        # Take step
                        scl = old_scl - armijo * Update
                        # Apply scaling
                        dat_y = _apply_scaling(dat_y, scl - old_scl, dim_thick)
                        # Compute matching term
                        ll = 0.5 * tau * torch.sum(
                            (dat_x[msk] - dat_y[msk])**2, dtype=torch.float64)

                        if verbose >= 2:  # Show images
                            show_slices(torch.stack(
                                (dat_x, dat_y, (dat_x - dat_y)**2), 3),
                                        fig_num=666,
                                        colorbar=False,
                                        flip=False)

                        # Matching improved?
                        if ll < old_ll:
                            # Better fit!
                            if verbose >= 1:
                                print(
                                    'c={}, n={}, gn={}, ls={} | :) ll={:0.2f}, ll-oll={:0.2f} | exp(s)={} armijo={}'
                                    .format(c, n_x, n_gn, n_ls, ll,
                                            ll - old_ll, round(scl.exp(), 5),
                                            round(armijo, 4)))
                            break
                        else:
                            # Reset parameters
                            scl = old_scl
                            ll = old_ll
                            armijo *= 0.5
                            if verbose >= 1 and n_ls == num_linesearch - 1:
                                print(
                                    'c={}, n={}, gn={}, ls={} | :( ll={:0.2f}, ll-oll={:0.2f} | exp(s)={} armijo={}'
                                    .format(c, n_x, n_gn, n_ls,
                                            ll, ll - old_ll,
                                            round(old_scl.exp(), 5),
                                            round(armijo, 4)))
            # Update scaling in projection operator
            x[c][n_x].po.scl = scl
            # Accumulate neg log-lik
            sll += ll

    return x, sll
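
The scaling update is a scalar Gauss-Newton step guarded by an Armijo-style halving line search. A minimal stand-alone version of that pattern (names are illustrative):

def gn_step(objective, grad, hess, s, num_linesearch=4):
    step = grad(s) / hess(s)
    f0, armijo = objective(s), 1.0
    for _ in range(num_linesearch):
        s_new = s - armijo * step
        if objective(s_new) < f0:
            return s_new             # improved fit: accept the step
        armijo *= 0.5                # otherwise halve the step and retry
    return s                         # no improvement found: keep old value

f = lambda s: (s - 2.0) ** 2
print(gn_step(f, lambda s: 2 * (s - 2.0), lambda s: 2.0, 0.0))  # 2.0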
Example #21
    def do_affine_only(self, logaff, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the affine component (nonlin is None)"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        logaff0 = logaff
        aff0, iaff0, gaff0, igaff0 = self.affine.exp2(logaff0, grad=True)

        has_printed = False
        for loss in self.losses:

            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                aff00, gaff00 = iaff0, igaff0
            else:
                aff00, gaff00 = aff0, gaff0

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            aff = aff00 @ fixed.affine
            aff = linalg.lmdiv(moving.affine, aff)
            gaff = gaff00 @ fixed.affine
            gaff = linalg.lmdiv(moving.affine, gaff)
            phi = spatial.affine_grid(aff, fixed.shape)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.preview
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.preview, init, preview, dim=fixed.dim,
                             title=f'(affine) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt Lie parameters
            # ----------------------------------------------------------

            def compose_grad(g, h, g_mu, g_aff):
                """
                g, h : gradient/Hessian of loss wrt moving image
                g_mu : spatial gradients of moving image
                g_aff : gradient of affine matrix wrt Lie parameters
                returns g, h: gradient/Hessian of loss wrt Lie parameters
                """
                # Note that `h` can be `None`, but the functions I
                # use deal with this case correctly.
                dim = g_mu.shape[-1]
                g = jg(g_mu, g)
                h = jhj(g_mu, h)
                g, h = regutils.affine_grid_backward(g, h)
                dim2 = dim * (dim + 1)
                g = g.reshape([*g.shape[:-2], dim2])
                g_aff = g_aff[..., :-1, :]
                g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
                g = linalg.matvec(g_aff, g)
                if h is not None:
                    h = h.reshape([*h.shape[:-4], dim2, dim2])
                    h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                    # h = h.abs().sum(-1).diag_embed()
                return g, h

            # compose with spatial gradients
            if grad or hess:
                mugrad = moving.pull_grad(phi, rotate=False)
                g, h = compose_grad(g, h, mugrad, gaff)

                if loss.backward:
                    g = g.neg_()
                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # TODO add regularization term
        lla = 0

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += lla
        ll = sumloss.item()
        self.loss_value = ll
        if self.verbose and (self.verbose > 1 or not in_line_search):
            if in_line_search:
                line = '(search) | '
            else:
                line = '(affine) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
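
A minimal torch-only sketch of the chain rule implemented by `compose_grad`
above (shapes and names are illustrative assumptions, not the library's API):
the gradient with respect to the flattened affine matrix is projected onto
each Lie generator, and the Hessian is sandwiched the same way.

import torch

dim = 3
F = 6                                      # e.g. a rigid basis (assumed)
dim2 = dim * (dim + 1)                     # flattened affine, bottom row dropped

g_affine = torch.randn(dim2)               # loss gradient wrt flattened affine
H_affine = torch.eye(dim2)                 # stand-in Hessian wrt flattened affine
g_aff = torch.randn(F, dim + 1, dim + 1)   # dA/dq_i for each Lie parameter

# drop the homogeneous bottom row and flatten, as in compose_grad
G = g_aff[:, :-1, :].reshape(F, dim2)
g_q = G @ g_affine                         # gradient wrt Lie parameters, (F,)
H_q = G @ H_affine @ G.T                   # Hessian wrt Lie parameters, (F, F)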
Exemplo n.º 22
0
def _update_rigid_channel(xc,
                          yc,
                          sett,
                          max_niter_gn=1,
                          num_linesearch=4,
                          verbose=0,
                          samp=3,
                          c=1):
    """ Updates the rigid parameters for all images of one channel.

    Args:
        xc (list): Input images (all repeats) of one channel.
        yc: Reconstructed image of that channel.
        sett: Algorithm settings (provides method, rigid_basis, etc.).
        max_niter_gn (int, optional): Max Gauss-Newton iterations, defaults to 1.
        num_linesearch (int, optional): Max line-search iterations, defaults to 4.
        verbose (int, optional): Show registration results, defaults to 0.
            0: No verbose
            1: Print convergence info to console
            2: Plot registration results using matplotlib
        samp (int, optional): Sub-sample data, defaults to 3.
        c (int, optional): Channel index, defaults to 1.

    Returns:
        xc (list): Input images with updated rigid parameters.
        sll (torch.tensor): Negative log-likelihood.

    """
    # Parameters
    device = yc.dat.device
    method = sett.method
    num_q = sett.rigid_basis.shape[0]
    lkp = [[0, 3, 4], [3, 1, 5], [4, 5, 2]]
    one = torch.tensor(1.0, device=device, dtype=torch.float64)

    sll = torch.tensor(0, device=device, dtype=torch.float64)
    for n_x in range(len(xc)):  # Loop over repeats

        # Lowres image data
        dat_x = xc[n_x].dat[None, None, ...]
        # Parameters
        q = xc[n_x].rigid_q
        tau = xc[n_x].tau
        armijo = torch.tensor(1, device=device, dtype=q.dtype)
        po = _proj_info(xc[n_x].po.dim_y,
                        xc[n_x].po.mat_y,
                        xc[n_x].po.dim_x,
                        xc[n_x].po.mat_x,
                        rigid=xc[n_x].po.rigid,
                        prof_ip=sett.profile_ip,
                        prof_tp=sett.profile_tp,
                        gap=sett.gap,
                        device=device,
                        scl=xc[n_x].po.scl,
                        samp=samp)
        # Method
        if method == 'super-resolution':
            dim = po.dim_yx
            mat = po.mat_yx
        elif method == 'denoising':
            dim = po.dim_x
            mat = po.mat_x

        # Do sub-sampling?
        if samp > 0 and po.D_x is not None:
            # Lowres
            grid = affine_grid(po.D_x.type(torch.float32), po.dim_x)
            dat_x = grid_pull(xc[n_x].dat[None, None, ...],
                              grid[None, ...],
                              bound='zero',
                              extrapolate=False,
                              interpolation=0)[0, 0, ...]
            if n_x == 0 and po.D_y is not None:
                # Highres (only for superres)
                grid = affine_grid(po.D_y.type(dtype=torch.float32), po.dim_y)
                dat_y = grid_pull(yc.dat[None, None, ...],
                                  grid[None, ...],
                                  bound='zero',
                                  extrapolate=False,
                                  interpolation=0)
            else:
                dat_y = yc.dat[None, None, ...]
        else:
            dat_x = xc[n_x].dat
            dat_y = yc.dat[None, None, ...]

        # Pre-compute super-resolution Hessian (CtC)?
        CtC = None
        if method == 'super-resolution':
            ones = torch.ones((1, 1) + dim, device=device, dtype=torch.float32)
            CtC = F.conv3d(ones, po.smo_ker, stride=po.ratio)
            CtC = F.conv_transpose3d(CtC, po.smo_ker, stride=po.ratio)[0, 0, ...]

        # Get identity grid
        id_x = identity_grid(dim,
                             dtype=torch.float32,
                             device=device,
                             jitter=False)

        for n_gn in range(max_niter_gn):  # Loop over Gauss-Newton iterations

            # Differentiate Rq w.r.t. q (store in d_rigid_q)
            rigid, d_rigid = _expm(q, sett.rigid_basis, grad_X=True)
            d_rigid = d_rigid.permute(
                (1, 2, 0))  # make compatible with old affine_basis
            d_rigid_q = torch.zeros(4,
                                    4,
                                    num_q,
                                    device=device,
                                    dtype=torch.float64)
            for i in range(num_q):
                d_rigid_q[:, :, i] = torch.linalg.solve(
                    po.mat_y, d_rigid[:, :, i].mm(mat))  # mat_y\d_rigid*mat

            # Compute gradient and Hessian
            gr = torch.zeros(num_q, 1, device=device, dtype=torch.float64)
            Hes = torch.zeros(num_q, num_q, device=device, dtype=torch.float64)

            # Compute matching-term part (log-likelihood)
            ll, gr_m, Hes_m = _rigid_match(dat_x,
                                           dat_y,
                                           po,
                                           tau,
                                           rigid,
                                           sett,
                                           diff=True,
                                           verbose=verbose,
                                           CtC=CtC)

            # Multiply with d_rigid_q (chain-rule)
            dAff = []
            for i in range(num_q):
                dAff.append([])
                for d in range(3):
                    dAff[i].append(d_rigid_q[d, 0, i] * id_x[:, :, :, 0] + \
                                   d_rigid_q[d, 1, i] * id_x[:, :, :, 1] + \
                                   d_rigid_q[d, 2, i] * id_x[:, :, :, 2] + \
                                   d_rigid_q[d, 3, i])

            # Add d_rigid_q to gradient
            for d in range(3):
                for i in range(num_q):
                    gr[i] += torch.sum(gr_m[:, :, :, d] * dAff[i][d],
                                       dtype=torch.float64)

            # Add d_rigid_q to Hessian
            for d1 in range(3):
                for d2 in range(3):
                    for i1 in range(num_q):
                        tmp1 = Hes_m[:, :, :, lkp[d1][d2]] * dAff[i1][d1]
                        for i2 in range(i1, num_q):
                            Hes[i1, i2] += torch.sum(tmp1 * dAff[i2][d2],
                                                     dtype=torch.float64)

            # Fill in missing triangle
            for i1 in range(num_q):
                for i2 in range(i1 + 1, num_q):
                    Hes[i2, i1] = Hes[i1, i2]

            # # Regularise diagonal of Hessian
            # Hes += 1e-5*Hes.diag().max()*torch.eye(num_q, dtype=Hes.dtype, device=device)

            # Compute Gauss-Newton update step
            Update = torch.linalg.solve(Hes, gr)[:, 0]

            # Do update..
            old_ll = ll.clone()
            old_q = q.clone()
            old_rigid = rigid.clone()
            if num_linesearch == 0:
                # ..without a line-search
                q = old_q - armijo * Update
                rigid = _expm(q, sett.rigid_basis)
                if verbose >= 1:
                    print('c={}, n={}, gn={} | q={}'.format(
                        c, n_x, n_gn,
                        [round(v, 7) for v in q.tolist()]))
            else:
                # ..using a line-search
                for n_ls in range(num_linesearch):
                    # Take step
                    q = old_q - armijo * Update
                    # Compute matching term
                    rigid = _expm(q, sett.rigid_basis)
                    ll = _rigid_match(dat_x,
                                      dat_y,
                                      po,
                                      tau,
                                      rigid,
                                      sett,
                                      verbose=verbose)[0]
                    # Matching improved?
                    if ll < old_ll:
                        # Better fit!
                        armijo = torch.min(1.25 * armijo, one)
                        if verbose >= 1:
                            print(
                                'c={}, n={}, gn={}, ls={} | :) ll={:0.2f}, ll-oll={:0.2f} | q={} armijo={}'
                                .format(c, n_x, n_gn, n_ls, ll, ll - old_ll,
                                        [round(v, 7) for v in q.tolist()],
                                        round(armijo.item(), 4)))
                        break
                    else:
                        # Reset parameters
                        ll = old_ll
                        q = old_q
                        rigid = old_rigid
                        armijo *= 0.5
                        if n_ls == num_linesearch - 1 and verbose >= 1:
                            print(
                                'c={}, n={}, gn={}, ls={} | :( ll={:0.2f}, ll-oll={:0.2f} | q={} armijo={}'
                                .format(c, n_x, n_gn, n_ls, ll, ll - old_ll,
                                        [round(v, 7) for v in q.tolist()],
                                        round(armijo.item(), 4)))
        # Assign
        xc[n_x].rigid_q = q
        xc[n_x].po.rigid = rigid
        # Accumulate neg log-lik
        sll += ll

    return xc, sll
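
The loop above is a standard Gauss-Newton scheme with a backtracking
(Armijo-style) line search. A self-contained sketch of one such step,
where `loss_fn` is a hypothetical stand-in for `_rigid_match`:

import torch

def gauss_newton_step(q, gr, Hes, loss_fn, armijo=1.0, max_ls=4):
    """Solve Hes @ step = gr, then shrink the step until the loss improves."""
    step = torch.linalg.solve(Hes, gr[:, None])[:, 0]
    ll_old = loss_fn(q)
    for _ in range(max_ls):
        q_new = q - armijo * step
        ll = loss_fn(q_new)
        if ll < ll_old:                        # better fit: grow the step size
            return q_new, ll, min(1.25 * armijo, 1.0)
        armijo *= 0.5                          # worse fit: halve and retry
    return q, ll_old, armijo                   # no improvement: keep old q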
Exemplo n.º 23
0
def _rigid_match(dat_x,
                 dat_y,
                 po,
                 tau,
                 rigid,
                 sett,
                 CtC=None,
                 diff=False,
                 verbose=0):
    """ Computes the rigid matching term, and its gradient and Hessian (if requested).

    Args:
        dat_x (torch.tensor): Observed data (X0, Y0, Z0).
        dat_y (torch.tensor): Reconstructed data (X1, Y1, Z1).
        po (ProjOp): Projection operator.
        tau (torch.tensor): Noise precision.
        rigid (torch.tensor): Rigid transformation matrix (4, 4).
        sett: Algorithm settings (method, bound, interpolation, ...).
        CtC (torch.tensor, optional): CtC(ones), used for super-res gradient
            calculation. Defaults to None.
        diff (bool, optional): Compute derivatives, defaults to False.
        verbose (int, optional): Show registration results, defaults to 0.
            0: No verbose
            1: Print convergence info to console
            2: Plot registration results using matplotlib

    Returns:
        ll (torch.tensor): Log-likelihood.
        gr (torch.tensor): Gradient, shape dim + (3,).
        Hes (torch.tensor): Hessian, shape dim + (6,).

    """
    # Projection info
    mat_x = po.mat_x
    mat_y = po.mat_y
    mat_yx = po.mat_yx
    dim_x = po.dim_x
    dim_yx = po.dim_yx
    ratio = po.ratio
    smo_ker = po.smo_ker
    dim_thick = po.dim_thick
    scl = po.scl

    # Init output
    ll = None
    gr = None
    Hes = None

    if sett.method == 'super-resolution':
        extrapolate = False
        dim = dim_yx
        mat = mat_yx
    elif sett.method == 'denoising':
        extrapolate = False
        dim = dim_x
        mat = mat_x

    # Get grid
    mat = torch.linalg.solve(mat_y, rigid.mm(mat))  # mat_y\rigid*mat
    grid = affine_grid(mat.type(torch.float32), dim, jitter=False)

    # Warp y and compute spatial derivatives
    dat_yx = grid_pull(dat_y,
                       grid[None, ...],
                       bound=sett.bound,
                       extrapolate=extrapolate,
                       interpolation=sett.interpolation)[0, 0, ...]
    if sett.method == 'super-resolution':
        dat_yx = F.conv3d(dat_yx[None, None, ...], smo_ker, stride=ratio)[0, 0,
                                                                          ...]
        if scl != 0:
            dat_yx = _apply_scaling(dat_yx, scl, dim_thick)
    if diff:
        gr = grid_grad(dat_y,
                       grid[None, ...],
                       bound=sett.bound,
                       extrapolate=extrapolate,
                       interpolation=sett.interpolation)[0, 0, ...]

    if verbose >= 2:  # Show images
        show_slices(torch.stack((dat_x, dat_yx, (dat_x - dat_yx)**2), 3),
                    fig_num=666,
                    colorbar=False,
                    flip=False)

    # Mask of observed voxels (sums are accumulated in double precision)
    msk = (dat_x != 0)

    # Compute matching term
    ll = 0.5 * tau * torch.sum(
        (dat_x[msk] - dat_yx[msk])**2, dtype=torch.float64)

    if diff:
        # Difference
        diff = dat_yx - dat_x
        msk = msk & (dat_yx != 0)
        diff[~msk] = 0
        # Hessian
        Hes = torch.zeros(dim + (6, ),
                          device=dat_x.device,
                          dtype=torch.float32)
        Hes[:, :, :, 0] = gr[:, :, :, 0] * gr[:, :, :, 0]
        Hes[:, :, :, 1] = gr[:, :, :, 1] * gr[:, :, :, 1]
        Hes[:, :, :, 2] = gr[:, :, :, 2] * gr[:, :, :, 2]
        Hes[:, :, :, 3] = gr[:, :, :, 0] * gr[:, :, :, 1]
        Hes[:, :, :, 4] = gr[:, :, :, 0] * gr[:, :, :, 2]
        Hes[:, :, :, 5] = gr[:, :, :, 1] * gr[:, :, :, 2]
        if sett.method == 'super-resolution':
            Hes *= CtC[..., None]
            diff = F.conv_transpose3d(diff[None, None, ...],
                                      smo_ker,
                                      stride=ratio)[0, 0, ...]
        # Gradient
        gr *= diff[..., None]

    return ll, gr, Hes
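
`Hes` above packs only the six unique entries of the per-voxel symmetric
3x3 outer product of the spatial gradient; the `lkp` table in
`_update_rigid_channel` unpacks them. A small self-contained check:

import torch

lkp = [[0, 3, 4], [3, 1, 5], [4, 5, 2]]
g = torch.randn(3)
packed = torch.stack([g[0] * g[0], g[1] * g[1], g[2] * g[2],
                      g[0] * g[1], g[0] * g[2], g[1] * g[2]])
full = packed[torch.as_tensor(lkp)]        # (3, 3) symmetric matrix
assert torch.allclose(full, torch.outer(g, g))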
Exemplo n.º 24
0
def get_oriented_slice(image, dim=-1, index=None, affine=None,
                       space=None, bbox=None, interpolation=1,
                       transpose_sagittal=False, return_index=False,
                       return_mat=False):
    """Sample a slice in a RAS system

    Parameters
    ----------
    image : (..., *shape3)
    dim : int, default=-1
        Index of spatial dimension to sample in the visualization space
        If RAS: -1 = axial / -2 = coronal / -3 = sagittal
    index : int, default=shape//2
        Coordinate (in voxel) of the slice to extract
    affine : (4, 4) tensor, optional
        Orientation matrix of the image
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        Transpose sagittal slices (and flip the in-plane axes accordingly).
    return_index : bool, default=False
        Also return the cursor position within the slice.
    return_mat : bool, default=False
        Also return the orientation matrix of the slice.

    Returns
    -------
    slice : (..., *shape2) tensor
        Slice in the visualisation space.
    index : tuple[float], if `return_index`
        Cursor position within the slice.
    mat : (4, 4) tensor, if `return_mat`
        Orientation matrix of the slice.

    """
    # preproc dim
    if isinstance(dim, str):
        dim = dim.lower()[0]
        if dim == 'a':
            dim = -1
        if dim == 'c':
            dim = -2
        if dim == 's':
            dim = -3

    backend = utils.backend(image)

    # compute default space (mn/mx are in voxels)
    affine, shape = _get_default_native(affine, image.shape[-3:])
    space, mn, mx = _get_default_space(affine, [shape], space, bbox)
    affine, shape = (affine[0], shape[0])

    # compute default cursor (in voxels)
    if index is None:
        index = (mx + mn) / 2
    else:
        index = torch.as_tensor(index)
        index = spatial.affine_matvec(spatial.affine_inv(space), index)

    # include slice to volume matrix
    shape = tuple(((mx-mn) + 1).round().int().tolist())
    if dim == -1:
        # axial
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [0, 0, 1, - index[2]],
                 [0, 0, 0, 1]]
        shift = utils.as_tensor(shift, **backend)
        shape = shape[:-1]
        index = (index[0] - mn[0] + 1, index[1] - mn[1] + 1)
    elif dim == -2:
        # coronal
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - index[1]],
                 [0, 0, 0, 1]]
        shift = utils.as_tensor(shift, **backend)
        shape = (shape[0], shape[2])
        index = (index[0] - mn[0] + 1, index[2] - mn[2] + 1)
    elif dim == -3:
        # sagittal
        if not transpose_sagittal:
            shift = [[0, 0, 1, - mn[2] + 1],
                     [0, 1, 0, - mn[1] + 1],
                     [1, 0, 0, - index[0]],
                     [0, 0, 0, 1]]
            shift = utils.as_tensor(shift, **backend)
            shape = (shape[2], shape[1])
            index = (index[2] - mn[2] + 1, index[1] - mn[1] + 1)
        else:
            shift = [[0, -1, 0,   mx[1] + 1],
                     [0,  0, 1, - mn[2] + 1],
                     [1,  0, 0, - index[0]],
                     [0,  0, 0, 1]]
            shift = utils.as_tensor(shift, **backend)
            shape = (shape[1], shape[2])
            index = (mx[1] + 1 - index[1], index[2] - mn[2] + 1)
    else:
        raise ValueError(f'Unknown dimension {dim}')

    # sample
    space = spatial.affine_rmdiv(space, shift)
    affine = spatial.affine_lmdiv(affine, space)
    affine = affine.to(**backend)
    grid = spatial.affine_grid(affine, [*shape, 1])
    *channel, s0, s1, s2 = image.shape
    imshape = (s0, s1, s2)
    image = image.reshape([1, -1, *imshape])
    image = spatial.grid_pull(image, grid[None], interpolation=interpolation,
                              bound='dct2', extrapolate=False)
    image = image.reshape([*channel, *shape])
    return ((image, index, space) if return_index and return_mat else
            (image, index) if return_index else
            (image, space) if return_mat else image)
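
A small numeric illustration of the axial `shift` matrix built above (the
values are made up): it maps homogeneous 3D visualisation coordinates to
in-plane slice coordinates, translating by the bounding-box minimum and
zeroing the out-of-plane coordinate on the selected slice.

import torch

mn = torch.tensor([10., 12., 8.])               # bounding-box minimum (assumed)
index_z = 32.                                   # slice position (assumed)
shift = torch.tensor([[1., 0., 0., -mn[0] + 1],
                      [0., 1., 0., -mn[1] + 1],
                      [0., 0., 1., -index_z],
                      [0., 0., 0., 1.]])
point_vis = torch.tensor([15., 19., 32., 1.])   # a point on the chosen slice
point_slice = shift @ point_vis                 # -> tensor([6., 8., 0., 1.])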
Exemplo n.º 25
0
def write_data(options):

    backend = dict(dtype=torch.float32, device=options.device)

    # Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            trf.dat = spatial.exp(trf.dat[None],
                                  displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=mat,
                                                    interpolation=trf.spline)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.spline
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    if options.target:
        # If target is provided, we build a dense transformation field
        grid = build_from_target(options.target)
        oaffine = options.target.affine
        if options.output_unit[0] == 'v':
            grid = spatial.affine_matvec(spatial.affine_inv(oaffine), grid)
            grid = grid - spatial.identity_grid(grid.shape[:-1],
                                                **utils.backend(grid))
        else:
            grid = grid - spatial.affine_grid(
                oaffine.to(**utils.backend(grid)), grid.shape[:-1])
        io.volumes.savef(grid,
                         options.output.format(ext='.nii.gz'),
                         affine=oaffine)
    else:
        if len(options.transformations) > 1:
            raise RuntimeError('Something weird happened: '
                               'multiple transforms and no target')
        io.transforms.savef(options.transformations[0].affine,
                            options.output.format(ext='.lta'))
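
The dense grid written above is converted to a displacement field whose
unit depends on `options.output_unit`. A hedged sketch of that conversion
as a standalone helper (`to_displacement` is hypothetical; the nitorch
calls and the import path are the ones assumed throughout this file):

from nitorch import spatial   # assumed import path

def to_displacement(grid, target_affine, unit='voxel'):
    """Turn a dense (world-coordinate) grid into a displacement field."""
    if unit[0] == 'v':
        # map world coordinates back to target voxels, subtract the identity
        grid = spatial.affine_matvec(spatial.affine_inv(target_affine), grid)
        return grid - spatial.identity_grid(grid.shape[:-1],
                                            dtype=grid.dtype,
                                            device=grid.device)
    # millimetric: subtract the affine grid of the target space
    return grid - spatial.affine_grid(target_affine, grid.shape[:-1])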
Exemplo n.º 26
0
    def forward():
        """Forward pass up to the loss"""

        loss = 0

        # affine matrix
        A = None
        for trf in options.transformations:
            trf.update()
            if isinstance(trf, struct.Linear):
                q = trf.optdat.to(**backend)
                # print(q.tolist())
                B = trf.basis.to(**backend)
                A = linalg.expm(q, B)
                if torch.is_tensor(trf.shift):
                    # include shift
                    shift = trf.shift.to(**backend)
                    eye = torch.eye(options.dim, **backend)
                    A = A.clone()  # needed because expm is a custom autograd.Function
                    A[:-1, -1] += torch.matmul(A[:-1, :-1] - eye, shift)
                for loss1 in trf.losses:
                    loss += loss1.call(q)
                break

        # non-linear displacement field
        d = None
        d_aff = None
        for trf in options.transformations:
            if not trf.isfree():
                continue
            if isinstance(trf, struct.FFD):
                d = trf.dat.to(**backend)
                d = ffd_exp(d, trf.shape, returns='disp')
                for loss1 in trf.losses:
                    loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break
            elif isinstance(trf, struct.Diffeo):
                d = trf.dat.to(**backend)
                if not trf.smalldef:
                    # penalty on velocity fields
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d = spatial.exp(d[None], displacement=True)[0]
                if trf.smalldef:
                    # penalty on exponentiated transform
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break

        # loop over image pairs
        for match in options.losses:
            if not match.fixed:
                continue
            nb_levels = len(match.fixed.dat)
            prm = dict(interpolation=match.moving.interpolation,
                       bound=match.moving.bound,
                       extrapolate=match.moving.extrapolate)
            # loop over pyramid levels
            for moving, fixed in zip(match.moving.dat, match.fixed.dat):
                moving_dat, moving_aff = moving
                fixed_dat, fixed_aff = fixed

                moving_dat = moving_dat.to(**backend)
                moving_aff = moving_aff.to(**backend)
                fixed_dat = fixed_dat.to(**backend)
                fixed_aff = fixed_aff.to(**backend)

                # affine-corrected moving space
                if A is not None:
                    Ms = affine_matmul(A, moving_aff)
                else:
                    Ms = moving_aff

                if d is not None:
                    # fixed to param
                    Mt = affine_lmdiv(d_aff, fixed_aff)
                    if samespace(Mt, d.shape[:-1], fixed_dat.shape[1:]):
                        g = smalldef(d)
                    else:
                        g = affine_grid(Mt, fixed_dat.shape[1:])
                        g = g + pull_grid(d, g)
                    # param to moving
                    Ms = affine_lmdiv(Ms, d_aff)
                    g = affine_matvec(Ms, g)
                else:
                    # fixed to moving
                    Mt = fixed_aff
                    Ms = affine_lmdiv(Ms, Mt)
                    g = affine_grid(Ms, fixed_dat.shape[1:])

                # pull moving image
                warped_dat = pull(moving_dat, g, **prm)
                loss += match.call(warped_dat, fixed_dat) / float(nb_levels)

                # import matplotlib.pyplot as plt
                # plt.subplot(1, 2, 1)
                # plt.imshow(fixed_dat[0, :, :, fixed_dat.shape[-1]//2].detach())
                # plt.axis('off')
                # plt.subplot(1, 2, 2)
                # plt.imshow(warped_dat[0, :, :, warped_dat.shape[-1]//2].detach())
                # plt.axis('off')
                # plt.show()

        return loss
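
The `shift` handling above recentres the exponentiated affine so that
rotations and zooms act about a reference point instead of the origin
(conjugation by a translation). A torch-only sketch of the same trick:

import torch

def recentre(A, shift):
    """Return T(-shift) o A o T(+shift) for a homogeneous affine A."""
    A = A.clone()   # don't modify the output of expm in place
    eye = torch.eye(A.shape[-1] - 1, dtype=A.dtype, device=A.device)
    A[:-1, -1] += (A[:-1, :-1] - eye) @ shift
    return A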
Exemplo n.º 27
0
def reslice(moving, fname, like, inv=False, lin=None, nonlin=None,
            interpolation=1, bound='dct2', extrapolate=False, device=None,
            verbose=True):
    """Apply the linear and non-linear components of the transform and
    reslice the image to the target space.

    Notes
    -----
    .. The shape and general orientation of the moving image is kept untouched.
    .. The linear transform is composed with the original orientation matrix.
    .. The non-linear component is "wrapped" in the input space, where it
       is applied.
    .. This function writes a new file (it does not modify the input files
       in place).

    Parameters
    ----------
    moving : ImageFile
        An object describing a moving image.
    fname : list of str
        Output filename for each input file of the moving image
        (since images can be encoded over multiple volumes)
    like : ImageFile
        An object describing the target space
    inv : bool, default=False
        True if we are warping the fixed image to the moving space.
        In that case, `moving` should be a `FixedImageFile` and `like` a
        `MovingImageFile`. Else, it should be a `MovingImageFile` and `like`
        a `FixedImageFile`.
    lin : (4, 4) tensor, optional
        Linear (or rather affine) transformation
    nonlin : dict, optional
        Non-linear displacement field, with keys:
        disp : (..., 3) tensor
            Displacement field (in voxels)
        affine : (4, 4) tensor
            Orientation matrix of the displacement field
    interpolation : int, default=1
    bound : str, default='dct2'
    extrapolate : bool, default = False
    device : default='cpu'

    """
    nonlin = nonlin or dict(disp=None, affine=None)
    prm = dict(interpolation=interpolation, bound=bound, extrapolate=extrapolate)

    moving_affine = moving.affine.to(device)
    fixed_affine = like.affine.to(device)

    if inv:
        # affine-corrected fixed space
        if lin is not None:
            fix2lin = affine_matmul(lin, fixed_affine)
        else:
            fix2lin = fixed_affine

        if nonlin['disp'] is not None:
            # fixed voxels to param voxels (warps param to fixed)
            fix2nlin = affine_lmdiv(nonlin['affine'].to(device), fix2lin)
            if samespace(fix2nlin, nonlin['disp'].shape[:-1], like.shape):
                g = smalldef(nonlin['disp'].to(device))
            else:
                g = affine_grid(fix2nlin, like.shape)
                g += pull_grid(nonlin['disp'].to(device), g)
            # param to moving
            nlin2mov = affine_lmdiv(moving_affine, nonlin['affine'].to(device))
            g = affine_matvec(nlin2mov, g)
        else:
            g = affine_lmdiv(moving_affine, fix2lin)
            g = affine_grid(g, like.shape)

    else:
        # affine-corrected moving space
        if lin is not None:
            mov2nlin = affine_matmul(lin, moving_affine)
        else:
            mov2nlin = moving_affine

        if nonlin['disp'] is not None:
            # fixed voxels to param voxels (warps param to fixed)
            fix2nlin = affine_lmdiv(nonlin['affine'].to(device), fixed_affine)
            if samespace(fix2nlin, nonlin['disp'].shape[:-1], like.shape):
                g = smalldef(nonlin['disp'].to(device))
            else:
                g = affine_grid(fix2nlin, like.shape)
                g += pull_grid(nonlin['disp'].to(device), g)
            # param voxels to moving voxels (warps moving to fixed)
            nlin2mov = affine_lmdiv(mov2nlin, nonlin['affine'].to(device))
            g = affine_matvec(nlin2mov, g)
        else:
            g = affine_lmdiv(mov2nlin, fixed_affine)
            g = affine_grid(g, like.shape)

    if moving.type == 'labels':
        prm['interpolation'] = 0
    for file, ofname in zip(moving.files, fname):
        if verbose:
            print(f'Resliced:   {file.fname}\n'
                  f'         -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, device=device)
        dat = dat.reshape([*file.shape, file.channels])
        if g is not None:
            dat = utils.movedim(dat, -1, 0)
            dat = pull(dat, g, **prm)
            dat = utils.movedim(dat, 0, -1)
        io.savef(dat.cpu(), ofname, like=file.fname, affine=like.affine.cpu())
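
The reslicing loop above samples channel-last data by temporarily moving
the channel dimension first. A hedged sketch of that pattern using the
same nitorch calls (`pull_channels_last` is a hypothetical helper; the
import paths are assumptions):

from nitorch import spatial
from nitorch.core import utils

def pull_channels_last(dat, grid, **prm):
    dat = utils.movedim(dat, -1, 0)                       # (C, *spatial)
    dat = spatial.grid_pull(dat[None], grid[None], **prm)[0]
    return utils.movedim(dat, 0, -1)                      # (*spatial, C)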
Exemplo n.º 28
0
def fit_affine_tpm(dat,
                   tpm,
                   affine=None,
                   affine_tpm=None,
                   weights=None,
                   basis='affine',
                   fwhm=None,
                   joint=False,
                   prm=None,
                   max_iter_gn=100,
                   max_iter_em=32,
                   max_line_search=6,
                   progressive=False,
                   verbose=1):
    """

    Parameters
    ----------
    dat : (B, J|1, *spatial) tensor
    tpm : (B|1, K, *spatial) tensor
    affine : (4, 4) tensor
    affine_tpm : (4, 4) tensor
    weights : (B, 1, *spatial) tensor
    basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
    fwhm : float, default=J/32
    joint : bool, default=False
    max_iter_gn : int, default=100
    max_iter_em : int, default=32
    max_line_search : int, default=6
    progressive : bool, default=False

    Returns
    -------
    mi : (B,) tensor
    aff : (B, 4, 4) tensor
    prm : (B, F) tensor

    """
    dim = dat.dim() - 2

    # ------------------------------------------------------------------
    #       RECURSIVE PROGRESSIVE FIT
    # ------------------------------------------------------------------
    if progressive:
        nb_se = dim * (dim + 1) // 2
        nb_aff = dim * (dim + 1)
        basis_recursion = {'Aff+': 'CSO', 'CSO': 'SE', 'SE': 'T'}
        basis_nb_feat = {'Aff+': nb_aff, 'CSO': nb_se + 1, 'SE': nb_se}
        basis = convert_basis(basis)
        next_basis = basis_recursion.get(basis, None)
        if next_basis:
            *_, prm = fit_affine_tpm(dat,
                                     tpm,
                                     affine,
                                     affine_tpm,
                                     weights,
                                     basis=next_basis,
                                     fwhm=fwhm,
                                     joint=joint,
                                     prm=prm,
                                     max_iter_gn=max_iter_gn,
                                     max_iter_em=max_iter_em,
                                     max_line_search=max_line_search)
            B = len(dat)
            F = basis_nb_feat[basis]
            prm0 = prm
            prm = prm0.new_zeros([1 if joint else B, F])
            if basis == 'SE':
                prm[:, :dim] = prm0[:, :dim]
            else:
                nb_se = dim * (dim + 1) // 2
                prm[:, :nb_se] = prm0[:, :nb_se]
                if basis == 'Aff+':
                    # spread the isotropic zoom over the `dim` zoom parameters
                    # (keepdim slice so the assignment broadcasts correctly)
                    prm[:, nb_se:nb_se + dim] = \
                        prm0[:, nb_se:nb_se + 1] * (dim ** (-0.5))

    basis_name = basis

    # ------------------------------------------------------------------
    #       PREPARE
    # ------------------------------------------------------------------

    B = len(dat)
    if affine is None:
        affine = spatial.affine_default(dat.shape[-dim:])
    if affine_tpm is None:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:])
    affine = affine.to(**utils.backend(tpm))
    affine_tpm = affine_tpm.to(**utils.backend(tpm))
    shape = dat.shape[-dim:]

    tpm = tpm.to(dat.device)
    basis = make_basis(basis, dim, **utils.backend(tpm))
    F = len(basis)

    if prm is None:
        prm = tpm.new_zeros([1 if joint else B, F])
    aff, gaff = linalg._expm(prm, basis, grad_X=True)

    em_opt = dict(fwhm=fwhm,
                  max_iter=max_iter_em,
                  weights=weights,
                  verbose=verbose - 2)
    drv_opt = dict(weights=weights)
    pull_opt = dict(bound='replicate', extrapolate=True)

    # ------------------------------------------------------------------
    #       OPTIMIZE
    # ------------------------------------------------------------------
    prior = None
    mi = torch.as_tensor(-float('inf'))
    delta = torch.zeros_like(prm)
    for n_iter in range(max_iter_gn):

        # --------------------------------------------------------------
        #       LINE SEARCH
        # --------------------------------------------------------------
        prior0, prm0, mi0 = prior, prm, mi
        armijo = 1
        success = False
        for n_ls in range(max_line_search):

            # --- take a step ------------------------------------------
            prm = prm0 - armijo * delta

            # --- build transformation field ---------------------------
            aff, gaff = linalg._expm(prm, basis, grad_X=True)
            phi = lmdiv(affine_tpm, mm(aff, affine))
            phi = spatial.affine_grid(phi, shape)

            # --- warp TPM ---------------------------------------------
            mov = spatial.grid_pull(tpm, phi, **pull_opt)

            # --- mutual info ------------------------------------------
            mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
            mi = mi / Nm

            success = mi.sum() > mi0.sum()
            if verbose >= 2:
                end = '\n' if verbose >= 3 else '\r'
                happy = ':D' if success else ':('
                print(f'(search) | {n_ls:02d} | {mi.mean():12.6g} | {happy}',
                      end=end)
            if success:
                break
            armijo *= 0.5
        # if verbose == 2:
        #     print('')

        # --------------------------------------------------------------
        #       DID IT WORK?
        # --------------------------------------------------------------

        if not success:
            prior, prm, mi = prior0, prm0, mi0
            break

        # DEBUG
        # plot_registration(dat, mov, f'{basis_name} | {n_iter}')

        space = ' ' * max(0, 6 - len(basis_name))
        if verbose >= 1:
            end = '\n' if verbose >= 2 else '\r'
            print(
                f'({basis_name[:6]}){space} | {n_iter:02d} | {mi.mean():12.6g}',
                end=end)

        if mi.mean() - mi0.mean() < 1e-5:
            break

        # --------------------------------------------------------------
        #       GAUSS-NEWTON
        # --------------------------------------------------------------

        # --- derivatives ----------------------------------------------
        g, h = derivatives_intensity(mov, dat, prior, **drv_opt)

        # --- chain rule -----------------------------------------------
        gmov = spatial.grid_grad(tpm, phi, **pull_opt)
        if joint and len(mov) == 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)
        else:
            gmov = gmov.expand([B, *gmov.shape[1:]])
        gaff = lmdiv(affine_tpm, mm(gaff, affine))
        g, h = chain_rule(g, h, gmov, gaff, maj=False)
        del gmov

        if joint and len(g) > 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)

        # --- Gauss-Newton ---------------------------------------------
        delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)

    if verbose == 1:
        print('')
    return mi, aff, prm
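
The parameters fitted above encode the transform through a Lie-group
exponential: the affine is the matrix exponential of a weighted sum of
basis generators (nitorch's `linalg._expm` additionally returns gradients
with respect to the weights). A torch-only illustration with a single,
assumed rotation generator:

import torch

B = torch.zeros(1, 4, 4)
B[0, 0, 1], B[0, 1, 0] = -1., 1.           # in-plane rotation generator
q = torch.tensor([0.1])                    # Lie parameter (radians)
A = torch.matrix_exp((q[:, None, None] * B).sum(0))
# A[:2, :2] is the 2D rotation by 0.1 rad embedded in a 4x4 affine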
Exemplo n.º 29
0
    def compose(self,
                orient_in,
                deformation,
                orient_mean,
                affine=None,
                orient_out=None,
                shape_out=None):
        """Composes a deformation defined in a mean space to an image space.

        Parameters
        ----------
        orient_in : (4, 4) tensor
            Orientation of the input image
        deformation : (*shape_mean, 3) tensor
            Random deformation
        orient_mean : (4, 4) tensor
            Orientation of the mean space (where the deformation is)
        affine : (4, 4) tensor, default=identity
            Random affine
        orient_out : (4, 4) tensor, default=orient_in
            Orientation of the output image
        shape_out : sequence[int], default=shape_mean
            Shape of the output image

        Returns
        -------
        grid : (*shape_out, 3)
            Voxel-to-voxel transform

        """
        if orient_out is None:
            orient_out = orient_in
        if shape_out is None:
            shape_out = deformation.shape[:-1]
        if affine is None:
            affine = torch.eye(4,
                               4,
                               device=orient_in.device,
                               dtype=orient_in.dtype)
        shape_mean = deformation.shape[:-1]

        orient_in, affine, deformation, orient_mean, orient_out \
            = utils.to_max_backend(orient_in, affine, deformation, orient_mean, orient_out)
        backend = utils.backend(deformation)
        eye = torch.eye(4, **backend)

        # Compose deformation on the right
        right_affine = spatial.affine_lmdiv(orient_mean, orient_out)
        if not (shape_mean == shape_out and torch.allclose(right_affine, eye)):
            # the mean space and native space are not the same
            # we must compose the diffeo with a dense affine transform
            # we write the diffeo as an identity plus a displacement
            # (id + disp)(aff) = aff + disp(aff)
            # -------
            # to displacement
            deformation = deformation - spatial.identity_grid(
                deformation.shape[:-1], **backend)
            trf = spatial.affine_grid(right_affine, shape_out)
            deformation = spatial.grid_pull(utils.movedim(deformation, -1,
                                                          0)[None],
                                            trf[None],
                                            bound='dft',
                                            extrapolate=True)
            deformation = utils.movedim(deformation[0], 0, -1)
            trf = trf + deformation  # add displacement
        else:
            # mean and output spaces coincide: the deformation is the grid
            trf = deformation

        # Compose deformation on the left
        #   the output of the diffeo(right) are mean_space voxels
        #   we must compose on the left with `in\(aff(mean))`
        # -------
        left_affine = spatial.affine_matmul(spatial.affine_inv(orient_in),
                                            affine)
        left_affine = spatial.affine_matmul(left_affine, orient_mean)
        trf = spatial.affine_matvec(left_affine, trf)

        return trf
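
A hedged sketch of the right-composition step above as a standalone
helper (hypothetical, but using the same nitorch calls): a warp written
as identity-plus-displacement composed with an affine reduces to the
affine grid plus the displacement sampled at that grid.

from nitorch import spatial
from nitorch.core import utils   # assumed import paths

def compose_disp_with_affine(disp, right_affine, shape_out, bound='dft'):
    """(id + disp)(aff) = aff + disp(aff)"""
    trf = spatial.affine_grid(right_affine, shape_out)
    d = spatial.grid_pull(utils.movedim(disp, -1, 0)[None], trf[None],
                          bound=bound, extrapolate=True)
    return trf + utils.movedim(d[0], 0, -1)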
Exemplo n.º 30
0
def _proj_apply(operator,
                dat,
                po,
                method='super-resolution',
                bound='zero',
                interpolation='linear'):
    """ Applies operator A, At  or AtA (for denoising or super-resolution).

    Args:
        operator (string): Either 'A', 'At', 'AtA' or 'none'.
        dat (torch.tensor()): Image data (1, 1, X_in, Y_in, Z_in).
        po (_proj_op()): Encodes projection operator, has the following fields:
            po.mat_x: Low-res affine matrix.
            po.mat_y: High-res affine matrix.
            po.mat_yx: Intermediate affine matrix.
            po.dim_x: Low-res image dimensions.
            po.dim_y: High-res image dimensions.
            po.dim_yx: Intermediate image dimensions.
            po.ratio: The ratio (low-res voxel_size)/(high-res voxel_size).
            po.smo_ker: Smoothing kernel (slice-profile).
        method (string): Either 'denoising' or 'super-resolution' (default).
        bound (str, optional): Bound for nitorch push/pull, defaults to 'zero'.
        interpolation (str or int, optional): Interpolation order, defaults to 'linear'.

    Returns:
        dat (torch.tensor()): Projected image data (1, 1, X_out, Y_out, Z_out).

    """
    # Sanity check
    if operator not in ['A', 'At', 'AtA', 'none']:
        raise ValueError('Undefined operator')
    if method not in ['denoising', 'super-resolution']:
        raise ValueError('Undefined method')
    if operator == 'none':
        # No projection
        return dat
    # Get data type and device
    dtype = dat.dtype
    device = dat.device
    # Parse required projection info
    mat_x = po.mat_x
    mat_y = po.mat_y
    mat_yx = po.mat_yx
    rigid = po.rigid
    dim_x = po.dim_x
    dim_y = po.dim_y
    dim_yx = po.dim_yx
    ratio = po.ratio
    smo_ker = po.smo_ker
    scl = po.scl
    dim_thick = po.dim_thick
    if method == 'super-resolution':
        dim = dim_yx
        mat = torch.linalg.solve(mat_y, rigid.mm(mat_yx))  # mat_y\rigid*mat_yx
    elif method == 'denoising':
        dim = dim_x
        mat = torch.linalg.solve(mat_y, rigid.mm(mat_x))  # mat_y\rigid*mat_x
    # Smoothing operator
    if len(ratio) == 3:  # 3D
        conv = lambda x: F.conv3d(x, smo_ker, stride=ratio)
        conv_transpose = lambda x: F.conv_transpose3d(x, smo_ker, stride=ratio)
    else:  # 2D
        conv = lambda x: F.conv2d(x, smo_ker, stride=ratio)
        conv_transpose = lambda x: F.conv_transpose2d(x, smo_ker, stride=ratio)
    # Get grid
    grid = affine_grid(mat.type(dat.dtype), dim, jitter=False)[None, ...]
    # Apply projection
    if method == 'super-resolution':
        extrapolate = False
        if operator == 'A':
            dat = grid_pull(dat,
                            grid,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
            dat = conv(dat)
            if scl != 0:
                dat = _apply_scaling(dat, scl, dim_thick)
        elif operator == 'At':
            if scl != 0:
                dat = _apply_scaling(dat, scl, dim_thick)
            dat = conv_transpose(dat)
            dat = grid_push(dat,
                            grid,
                            shape=dim_y,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
        elif operator == 'AtA':
            dat = grid_pull(dat,
                            grid,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
            dat = conv(dat)
            if scl != 0:
                dat = _apply_scaling(dat, 2 * scl, dim_thick)
            dat = conv_transpose(dat)
            dat = grid_push(dat,
                            grid,
                            shape=dim_y,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
    elif method == 'denoising':
        extrapolate = False
        if operator == 'A':
            dat = grid_pull(dat,
                            grid,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
        elif operator == 'At':
            dat = grid_push(dat,
                            grid,
                            shape=dim_y,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
        elif operator == 'AtA':
            dat = grid_pull(dat,
                            grid,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)
            dat = grid_push(dat,
                            grid,
                            shape=dim_y,
                            bound=bound,
                            extrapolate=extrapolate,
                            interpolation=interpolation)

    return dat
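
Since `At` is intended to be the adjoint of `A`, a quick sanity check is
that <A x, y> == <x, At y> for random tensors. A hedged sketch, assuming
`po` is a valid projection-operator object whose `dim_x`/`dim_y` are
tuples of ints:

import torch

def check_adjoint(po, atol=1e-4):
    x = torch.randn((1, 1) + tuple(po.dim_y))
    y = torch.randn((1, 1) + tuple(po.dim_x))
    lhs = (_proj_apply('A', x, po) * y).sum()
    rhs = (x * _proj_apply('At', y, po)).sum()
    assert torch.allclose(lhs, rhs, atol=atol)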