Example #1
    def forward(self, affine, **overload):
        """

        Parameters
        ----------
        affine : (batch, ndim[+1], ndim+1) tensor
            Affine matrix
        overload : dict
            All parameters of the module can be overridden at call time.

        Returns
        -------
        grid : (batch, *shape, ndim) tensor
            Dense transformation grid

        """

        nb_dim = affine.shape[-1] - 1
        info = {'dtype': affine.dtype, 'device': affine.device}
        shape = make_list(overload.get('shape', self.shape), nb_dim)
        shift = overload.get('shift', self.shift)

        if shift:
            affine_shift = torch.cat((
                torch.eye(nb_dim, **info),
                -torch.as_tensor(shape, **info)[:, None]/2),
                dim=1)
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

        grid = spatial.affine_grid(affine, shape)
        return grid
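
The `shift` branch conjugates the affine with a translation to the volume centre, so rotations and zooms act about the middle of the field of view rather than the first voxel. Below is a minimal standalone sketch of that conjugation on a plain square homogeneous matrix; the `center_affine` helper and the (shape - 1)/2 centre convention are illustrative assumptions, whereas the library helpers above operate on batched (ndim[+1], ndim+1) matrices.

import torch

def center_affine(affine, shape):
    # Conjugate a (ndim+1, ndim+1) homogeneous affine so that it acts about
    # the volume centre: S^-1 @ A @ S, where S translates the origin there.
    ndim = affine.shape[-1] - 1
    shift = torch.eye(ndim + 1, dtype=affine.dtype, device=affine.device)
    center = (torch.as_tensor(shape, dtype=affine.dtype,
                              device=affine.device) - 1) / 2
    shift[:ndim, -1] = -center
    return torch.linalg.solve(shift, affine @ shift)

# usage sketch: recentre an affine for a 32x32x32 volume
A = torch.eye(4)
print(center_affine(A, (32, 32, 32)))
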
Example #2
    def forward(self, affine, shape=None):
        """

        Parameters
        ----------
        affine : (batch, ndim[+1], ndim+1) tensor
            Affine matrix
        shape : sequence[int], default=self.shape
            Spatial shape of the output grid

        Returns
        -------
        grid : (batch, *shape, ndim) tensor
            Dense transformation grid

        """

        nb_dim = affine.shape[-1] - 1
        backend = {'dtype': affine.dtype, 'device': affine.device}
        shape = shape or self.shape

        if self.shift:
            affine_shift = torch.eye(nb_dim + 1, **backend)
            affine_shift[:nb_dim, -1] = torch.as_tensor(shape, **backend)
            affine_shift[:nb_dim, -1].sub_(1).div_(2).neg_()
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

        grid = spatial.affine_grid(affine, shape)
        return grid
Example #3
 def pull(q, vel):
     grid = spatial.exp(vel)
     aff = core.linalg.expm(q, basis)
     aff = spatial.affine_matmul(aff, target_aff)
     aff = spatial.affine_lmdiv(source_aff, aff)
     grid = spatial.affine_matvec(aff, grid)
     moved = spatial.grid_pull(source, grid, **pull_opt)
     return moved
Example #4
 def pull(q, grid):
     aff = core.linalg.expm(q, basis)
     aff = spatial.affine_matmul(aff, target_aff)
     aff = spatial.affine_lmdiv(source_aff, aff)
     expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))
     grid = spatial.affine_matvec(aff[expd], grid)
     moved = spatial.grid_pull(source, grid, **pull_opt)
     return moved
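
The `expd` index is a tuple of slices and `None`s that inserts `dim` singleton axes right after the batch axis, so the batched affine broadcasts against the dense (batch, *spatial, ndim) grid inside `affine_matvec`. A quick standalone check of the indexing trick (the sizes here are made up for illustration):

import torch

dim = 3
aff = torch.eye(dim, dim + 1).expand(2, dim, dim + 1)   # (batch=2, dim, dim+1)
expd = (slice(None),) + (None,) * dim + (slice(None), slice(None))
print(aff[expd].shape)                                  # torch.Size([2, 1, 1, 1, 3, 4])
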
Example #5
def _warp_image1(image,
                 target,
                 shape=None,
                 affine=None,
                 nonlin=None,
                 backward=False,
                 reslice=False):
    """Returns the warped image, with channel dimension last"""
    # build transform
    aff_right = target
    aff_left = spatial.affine_inv(image.affine)
    aff = None
    if affine:
        # exp = affine.iexp if backward else affine.exp
        exp = affine.exp
        aff = exp(recompute=False, cache_result=True)
        if backward:
            aff = spatial.affine_inv(aff)
    if nonlin:
        if affine:
            if affine.position[0].lower() in ('ms' if backward else 'fs'):
                aff_right = spatial.affine_matmul(aff, aff_right)
            if affine.position[0].lower() in ('fs' if backward else 'ms'):
                aff_left = spatial.affine_matmul(aff_left, aff)
        exp = nonlin.iexp if backward else nonlin.exp
        phi = exp(recompute=False, cache_result=True)
        aff_left = spatial.affine_matmul(aff_left, nonlin.affine)
        aff_right = spatial.affine_lmdiv(nonlin.affine, aff_right)
        if _almost_identity(aff_right) and nonlin.shape == shape:
            phi = nonlin.add_identity(phi)
        else:
            tmp = spatial.affine_grid(aff_right, shape)
            phi = regutils.smart_pull_grid(phi, tmp).add_(tmp)
            del tmp
        if not _almost_identity(aff_left):
            phi = spatial.affine_matvec(aff_left, phi)
    else:
        # no nonlin: single affine even if position == 'symmetric'
        if reslice:
            aff = spatial.affine_matmul(aff, aff_right)
            aff = spatial.affine_matmul(aff_left, aff)
            phi = spatial.affine_grid(aff, shape)
        else:
            phi = None

    # warp image
    if phi is not None:
        warped = image.pull(phi)
    else:
        warped = image.dat

    # move the channel dimension last
    if len(warped) == 1:
        warped = warped[0]
    else:
        warped = utils.movedim(warped, 0, -1)
    return warped
Example #6
 def slice_to(self, stack, cache_result=False, recompute=True):
     aff = self.exp(cache_result=cache_result, recompute=recompute)
     if recompute or not hasattr(self, '_sliced'):
         aff = spatial.affine_matmul(aff, self.affine)
         aff_reorient = spatial.affine_reorient(self.affine, self.shape, stack.layout)
         aff = spatial.affine_lmdiv(aff_reorient, aff)
         aff = spatial.affine_grid(aff, self.shape)
         sliced = spatial.grid_pull(self.dat, aff, bound=self.bound,
                                    extrapolate=self.extrapolate)
         fwhm = [0] * self.dim
         fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
         sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)
         slices = []
         for stack_slice in stack.slices:
             aff = spatial.affine_matmul(stack.affine, )
             aff = spatial.affine_lmdiv(aff_reorient, )
     else:
         sliced = self._sliced
     if cache_result:
         self._sliced = sliced
     return sliced
Example #7
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch shape.

        Other Parameters
        ----------------
        shape : sequence[int], optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        grid : (batch, *shape, 3) tensor
            Resampling grid

        """
        shape = overload.get('shape', self.grid.velocity.field.shape)
        dtype = overload.get('dtype', self.grid.velocity.field.dtype)
        device = overload.get('device', self.grid.velocity.field.device)
        backend = dict(dtype=dtype, device=device)

        if self.grid.velocity.field.amplitude == 0:
            grid = identity_grid(shape, **backend)
        else:
            grid = self.grid(batch, shape=shape, **backend)
        dtype = grid.dtype
        device = grid.device
        backend = dict(dtype=dtype, device=device)

        shape = grid.shape[1:-1]
        dim = len(shape)
        aff = self.affine(batch, dim=dim, **backend)

        # shift center of rotation
        aff_shift = torch.cat((
            torch.eye(dim, **backend),
            torch.as_tensor(shape, **backend)[:, None].sub_(1).div_(-2)),
            dim=1)
        aff_shift = as_euclidean(aff_shift)

        aff = affine_matmul(aff, aff_shift)
        aff = affine_lmdiv(aff_shift, aff)

        # compose
        aff = utils.unsqueeze(aff, dim=-3, ndim=dim)
        lin = aff[..., :dim, :dim]
        off = aff[..., :dim, -1]
        grid = linalg.matvec(lin, grid) + off

        return grid
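
The compose step (`unsqueeze` over the spatial axes, then `matvec(lin, grid) + off`) applies each affine to every voxel coordinate of the grid as y = L x + t. A plain-torch sketch of the same broadcasted operation, with made-up sizes:

import torch

batch, dim, shape = 2, 3, (8, 8, 8)
aff = torch.eye(dim, dim + 1).expand(batch, dim, dim + 1)   # (batch, dim, dim+1)
grid = torch.randn(batch, *shape, dim)                      # (batch, *shape, dim)

# insert singleton spatial axes, then apply L @ x + t voxel-wise
lin = aff[:, None, None, None, :dim, :dim]                  # (batch, 1, 1, 1, dim, dim)
off = aff[:, None, None, None, :dim, -1]                    # (batch, 1, 1, 1, dim)
out = (lin @ grid.unsqueeze(-1)).squeeze(-1) + off
print(out.shape)                                            # torch.Size([2, 8, 8, 8, 3])
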
Example #8
    def exp(self, velocity, affine=None, displacement=False):
        """Generate a deformation grid from tangent parameters.

        Parameters
        ----------
        velocity : (batch, *spatial, nb_dim)
            Stationary velocity field
        affine : (batch, nb_prm)
            Affine parameters
        displacement : bool, default=False
            Return a displacement field (voxel to shift) rather than
            a transformation field (voxel to voxel).

        Returns
        -------
        grid : (batch, *spatial, nb_dim)
            Deformation grid (transformation or displacement).

        """
        info = {'dtype': velocity.dtype, 'device': velocity.device}

        # generate grid
        shape = velocity.shape[1:-1]
        velocity_small = self.resize(velocity, type='displacement')
        grid = self.velexp(velocity_small)
        grid = self.resize(grid, shape=shape, type='grid')

        if affine is not None:
            # exponentiate
            affine_prm = affine
            affine = []
            for prm in affine_prm:
                affine.append(self.affexp(prm))
            affine = torch.stack(affine, dim=0)

            # shift center of rotation
            affine_shift = torch.cat(
                (torch.eye(self.dim, **info),
                 -torch.as_tensor(shape, **info)[:, None] / 2),
                dim=1)
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

            # compose
            affine = unsqueeze(affine, dim=-3, ndim=self.dim)
            lin = affine[..., :self.dim, :self.dim]
            off = affine[..., :self.dim, -1]
            grid = matvec(lin, grid) + off

        if displacement:
            grid = grid - spatial.identity_grid(grid.shape[1:-1], **info)

        return grid
Example #9
def _crop_to_param(aff0, aff, shape):
    dim = aff0.shape[-1] - 1
    shape = shape[:dim]
    layout0 = spatial.affine_to_layout(aff0)
    layout = spatial.affine_to_layout(aff)
    if (layout0 != layout).any():
        raise ValueError('Input and Ref do not have the same layout: '
                         f'{spatial.volume_layout_to_name(layout0)} vs '
                         f'{spatial.volume_layout_to_name(layout)}.')
    size = shape
    layout = None
    unit = 'vox'

    center = torch.as_tensor(shape, dtype=torch.float).sub_(1).mul_(0.5)
    like_aff = spatial.affine_lmdiv(aff0, aff)
    center = spatial.affine_matvec(like_aff, center)

    return size, center, unit, layout
Example #10
    def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        # select correct gradient mode
        if grad:
            logaff.requires_grad_()
            if logaff.grad is not None:
                logaff.grad.zero_()
        if grad and not torch.is_grad_enabled():
            with torch.enable_grad():
                return self(logaff, grad, hess, in_line_search=in_line_search)
        elif not grad and torch.is_grad_enabled():
            with torch.no_grad():
                return self(logaff, grad, hess, in_line_search=in_line_search)

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                   and (self.n_iter - 1) % logplot == 0

        # jitter
        # idj = spatial.identity_grid(self.fixed.shape[-self.dim:], jitter=True,
        #                             **utils.backend(self.fixed))
        # fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        # del idj
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape, **utils.backend(logaff))
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        llx = self.loss.loss(warped, fixed)
        del warped

        # print objective
        lll = llx
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [lll]
        if grad:
            lll.backward()
            grad = logaff.grad.clone()
            out.append(grad)
        logaff.requires_grad_(False)
        return tuple(out) if len(out) > 1 else out[0]
Example #11
def write_data(options):

    backend = dict(dtype=torch.float32, device=options.device)

    # Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            trf.dat = spatial.exp(trf.dat[None],
                                  displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=mat,
                                                    interpolation=trf.spline)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.spline
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    if options.target:
        # If target is provided, we build a dense transformation field
        grid = build_from_target(options.target)
        oaffine = options.target.affine
        if options.output_unit[0] == 'v':
            grid = spatial.affine_matvec(spatial.affine_inv(oaffine), grid)
            grid = grid - spatial.identity_grid(grid.shape[:-1],
                                                **utils.backend(grid))
        else:
            grid = grid - spatial.affine_grid(
                oaffine.to(**utils.backend(grid)), grid.shape[:-1])
        io.volumes.savef(grid,
                         options.output.format(ext='.nii.gz'),
                         affine=oaffine)
    else:
        if len(options.transformations) > 1:
            raise RuntimeError('Something weird happened: '
                               'multiple transforms and no target')
        io.transforms.savef(options.transformations[0].affine,
                            options.output.format(ext='.lta'))
Example #12
    def do_vel(self, vel, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the nonlinear component"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        if self.affine:
            aff0, iaff0 = self.affine.exp2(cache_result=True, recompute=False)
            aff_pos = self.affine.position[0].lower()
        else:
            aff_pos = 'x'
            aff0 = iaff0 = torch.eye(self.nonlin.dim + 1)
        vel0 = vel
        if any(loss.backward for loss in self.losses):
            phi0, iphi0 = self.nonlin.exp2(vel0,
                                           recompute=True,
                                           cache_result=not in_line_search)
            ivel0 = -vel0
        else:
            phi0 = self.nonlin.exp(vel0,
                                   recompute=True,
                                   cache_result=not in_line_search)
            iphi0 = ivel0 = None
        aff0 = aff0.to(phi0)
        iaff0 = iaff0.to(phi0)

        # ==============================================================
        #                     ACCUMULATE DERIVATIVES
        # ==============================================================

        has_printed = False
        for loss in self.losses:

            # ==========================================================
            #                     ONE LOSS COMPONENT
            # ==========================================================
            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                phi00, aff00, vel00 = iphi0, iaff0, ivel0
            else:
                phi00, aff00, vel00 = phi0, aff0, vel0

            # ----------------------------------------------------------
            # build left and right affine
            # ----------------------------------------------------------
            aff_right = fixed.affine
            if aff_pos in 'fs':  # affine position: fixed or symmetric
                aff_right = aff00 @ aff_right
            aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
            aff_left = self.nonlin.affine
            if aff_pos in 'ms':  # affine position: moving or symmetric
                aff_left = aff00 @ self.nonlin.affine
            aff_left = linalg.lmdiv(moving.affine, aff_left)

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
                aff_right = None
                phi = spatial.add_identity_grid(phi00)
                disp = phi00
            else:
                phi = spatial.affine_grid(aff_right, fixed.shape)
                disp = regutils.smart_pull_grid(phi00, phi)
                phi += disp
            if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
                aff_left = None
            else:
                phi = spatial.affine_matvec(aff_left, phi)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.dat
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.dat, init, preview, disp, dim=fixed.dim,
                             title=f'(nonlin) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt phi
            # ----------------------------------------------------------
            if grad or hess:

                g, h, mugrad = self.nonlin.propagate_grad(
                    g, h, moving, phi00, aff_left, aff_right,
                    inv=loss.backward)
                g = regutils.jg(mugrad, g)
                h = regutils.jhj(mugrad, h)
                if isinstance(self.nonlin, SVFModel):
                    # propagate backward by scaling and squaring
                    g, h = spatial.exp_backward(vel00, g, h,
                                                steps=self.nonlin.steps)

                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # ==============================================================
        #                       REGULARIZATION
        # ==============================================================
        vgrad = self.nonlin.regulariser(vel0)
        llv = 0.5 * vel0.flatten().dot(vgrad.flatten())
        if grad:
            sumgrad += vgrad
        del vgrad

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += llv
        sumloss += self.lla
        self.loss_value = sumloss.item()
        if self.verbose and (self.verbose > 1 or not in_line_search):
            llv = llv.item()
            ll = sumloss.item()
            lla = self.lla
            if in_line_search:
                line = '(search) | '
            else:
                line = '(nonlin) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.llv = llv
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
Example #13
def _warp_image(option,
                affine=None,
                nonlin=None,
                dim=None,
                device=None,
                odir=None):
    """Warp and save the moving and fixed images from a loss object"""

    if not (option.mov.output or option.mov.resliced or option.fix.output
            or option.fix.resliced):
        return

    fix, fix_affine = _map_image(option.fix.files, dim=dim)
    mov, mov_affine = _map_image(option.mov.files, dim=dim)
    fix_affine = fix_affine.float()
    mov_affine = mov_affine.float()
    dim = dim or (fix.dim - 1)

    if option.fix.world:  # overwrite orientation matrix
        fix_affine = io.transforms.map(option.fix.world).fdata().squeeze()
    for transform in (option.fix.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        fix_affine = spatial.affine_lmdiv(transform, fix_affine)

    if option.mov.world:  # overwrite orientation matrix
        mov_affine = io.transforms.map(option.mov.world).fdata().squeeze()
    for transform in (option.mov.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        mov_affine = spatial.affine_lmdiv(transform, mov_affine)

    # moving
    if option.mov.output or option.mov.resliced:
        ifname = option.mov.files[0]
        idir, base, ext = py.fileparts(ifname)
        odir_mov = odir or idir or '.'

        image = objects.Image(mov.fdata(rand=True, device=device),
                              dim=dim,
                              affine=mov_affine,
                              bound=option.mov.bound,
                              extrapolate=option.mov.extrapolate)

        if option.mov.output:
            target_affine = mov_affine
            target_shape = image.shape
            if affine and affine.position[0].lower() in 'ms':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_lmdiv(aff, target_affine)

            fname = option.mov.output.format(dir=odir_mov,
                                             base=base,
                                             sep=os.path.sep,
                                             ext=ext)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

        if option.mov.resliced:
            target_affine = fix_affine
            target_shape = fix.shape[1:]

            fname = option.mov.resliced.format(dir=odir_mov,
                                               base=base,
                                               sep=os.path.sep,
                                               ext=ext)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin,
                                  reslice=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

    # fixed
    if option.fix.output or option.fix.resliced:
        ifname = option.fix.files[0]
        idir, base, ext = py.fileparts(ifname)
        odir_fix = odir or idir or '.'

        image = objects.Image(fix.fdata(rand=True, device=device),
                              dim=dim,
                              affine=fix_affine,
                              bound=option.fix.bound,
                              extrapolate=option.fix.extrapolate)

        if option.fix.output:
            target_affine = fix_affine
            target_shape = image.shape
            if affine and affine.position[0].lower() in 'fs':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_matmul(aff, target_affine)

            fname = option.fix.output.format(dir=odir_fix,
                                             base=base,
                                             sep=os.path.sep,
                                             ext=ext)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin,
                                  backward=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

        if option.fix.resliced:
            target_affine = mov_affine
            target_shape = mov.shape[1:]

            fname = option.fix.resliced.format(dir=odir_fix,
                                               base=base,
                                               sep=os.path.sep,
                                               ext=ext)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin,
                                  backward=True,
                                  reslice=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped
Example #14
    def do_affine_only(self, logaff, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the affine component (nonlin is None)"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        logaff0 = logaff
        aff0, iaff0, gaff0, igaff0 = self.affine.exp2(logaff0, grad=True)

        has_printed = False
        for loss in self.losses:

            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                aff00, gaff00 = iaff0, igaff0
            else:
                aff00, gaff00 = aff0, gaff0

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            aff = aff00 @ fixed.affine
            aff = linalg.lmdiv(moving.affine, aff)
            gaff = gaff00 @ fixed.affine
            gaff = linalg.lmdiv(moving.affine, gaff)
            phi = spatial.affine_grid(aff, fixed.shape)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.preview
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.preview, init, preview, dim=fixed.dim,
                             title=f'(affine) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt Lie parameters
            # ----------------------------------------------------------

            def compose_grad(g, h, g_mu, g_aff):
                """
                g, h : gradient/Hessian of loss wrt moving image
                g_mu : spatial gradients of moving image
                g_aff : gradient of affine matrix wrt Lie parameters
                returns g, h: gradient/Hessian of loss wrt Lie parameters
                """
                # Note that `h` can be `None`, but the functions I
                # use deal with this case correctly.
                dim = g_mu.shape[-1]
                g = jg(g_mu, g)
                h = jhj(g_mu, h)
                g, h = regutils.affine_grid_backward(g, h)
                dim2 = dim * (dim + 1)
                g = g.reshape([*g.shape[:-2], dim2])
                g_aff = g_aff[..., :-1, :]
                g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
                g = linalg.matvec(g_aff, g)
                if h is not None:
                    h = h.reshape([*h.shape[:-4], dim2, dim2])
                    h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                    # h = h.abs().sum(-1).diag_embed()
                return g, h

            # compose with spatial gradients
            if grad or hess:
                mugrad = moving.pull_grad(phi, rotate=False)
                g, h = compose_grad(g, h, mugrad, gaff)

                if loss.backward:
                    g = g.neg_()
                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # TODO add regularization term
        lla = 0

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += lla
        ll = sumloss.item()
        self.loss_value = ll
        if self.verbose and (self.verbose > 1 or not in_line_search):
            if in_line_search:
                line = '(search) | '
            else:
                line = '(affine) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
Example #15
    def compose(self,
                orient_in,
                deformation,
                orient_mean,
                affine=None,
                orient_out=None,
                shape_out=None):
        """Composes a deformation defined in a mean space to an image space.

        Parameters
        ----------
        orient_in : (4, 4) tensor
            Orientation of the input image
        deformation : (*shape_mean, 3) tensor
            Random deformation
        orient_mean : (4, 4) tensor
            Orientation of the mean space (where the deformation is)
        affine : (4, 4) tensor, default=identity
            Random affine
        orient_out : (4, 4) tensor, default=orient_in
            Orientation of the output image
        shape_out : sequence[int], default=shape_mean
            Shape of the output image

        Returns
        -------
        grid : (*shape_out, 3)
            Voxel-to-voxel transform

        """
        if orient_out is None:
            orient_out = orient_in
        if shape_out is None:
            shape_out = deformation.shape[:-1]
        if affine is None:
            affine = torch.eye(4,
                               4,
                               device=orient_in.device,
                               dtype=orient_in.dtype)
        shape_mean = deformation.shape[:-1]

        orient_in, affine, deformation, orient_mean, orient_out \
            = utils.to_max_backend(orient_in, affine, deformation, orient_mean, orient_out)
        backend = utils.backend(deformation)
        eye = torch.eye(4, **backend)

        # Compose deformation on the right
        right_affine = spatial.affine_lmdiv(orient_mean, orient_out)
        if not (shape_mean == shape_out and right_affine.allclose(eye)):
            # the mean space and native space are not the same
            # we must compose the diffeo with a dense affine transform
            # we write the diffeo as an identity plus a displacement
            # (id + disp)(aff) = aff + disp(aff)
            # -------
            # to displacement
            deformation = deformation - spatial.identity_grid(
                deformation.shape[:-1], **backend)
            trf = spatial.affine_grid(right_affine, shape_out)
            deformation = utils.movedim(deformation, -1, 0)[None]
            deformation = spatial.grid_pull(deformation, trf[None],
                                            bound='dft', extrapolate=True)
            deformation = utils.movedim(deformation[0], 0, -1)
            trf = trf + deformation  # add displacement
        else:
            # mean space and output space coincide: the input deformation
            # is already the full transformation grid
            trf = deformation

        # Compose deformation on the left
        #   the output of the diffeo(right) are mean_space voxels
        #   we must compose on the left with `in\(aff(mean))`
        # -------
        left_affine = spatial.affine_matmul(spatial.affine_inv(orient_in),
                                            affine)
        left_affine = spatial.affine_matmul(left_affine, orient_mean)
        trf = spatial.affine_matvec(left_affine, trf)

        return trf
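
The right-composition branch above relies on writing the deformation as identity plus displacement, so that (id + disp) o aff = aff + disp o aff: the displacement is sampled at the affine grid and then added to it. A toy 1-D sketch of that identity, with plain callables standing in for the dense fields (all names here are illustrative):

import torch

def aff(x):                                # stands in for affine_grid(right_affine, ...)
    return 2.0 * x + 1.0

def disp(x):                               # stands in for the displacement field
    return 0.1 * torch.sin(x)

def phi(x):                                # deformation = identity + displacement
    return x + disp(x)

x = torch.linspace(0.0, 1.0, 5)
composed = phi(aff(x))                     # (id + disp) o aff
via_disp = aff(x) + disp(aff(x))           # aff + disp o aff  (what the code computes)
print(torch.allclose(composed, via_disp))  # True
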
Example #16
    def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the affine component (nonlin is not None)"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        logaff0 = logaff
        aff_pos = self.affine.position[0].lower()
        if any(loss.backward for loss in self.losses):
            aff0, iaff0, gaff0, igaff0 = \
                self.affine.exp2(logaff0, grad=True,
                                 cache_result=not in_line_search)
            phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
        else:
            iaff0, igaff0, iphi0 = None, None, None
            aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                          cache_result=not in_line_search)
            phi0 = self.nonlin.exp(cache_result=True, recompute=False)

        has_printed = False
        for loss in self.losses:

            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                phi00, aff00, gaff00 = iphi0, iaff0, igaff0
            else:
                phi00, aff00, gaff00 = phi0, aff0, gaff0

            # ----------------------------------------------------------
            # build left and right affine matrices
            # ----------------------------------------------------------
            aff_right, gaff_right = fixed.affine, None
            if aff_pos in 'fs':
                gaff_right = gaff00 @ aff_right
                gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
                aff_right = aff00 @ aff_right
            aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
            aff_left, gaff_left = self.nonlin.affine, None
            if aff_pos in 'ms':
                gaff_left = gaff00 @ aff_left
                gaff_left = linalg.lmdiv(moving.affine, gaff_left)
                aff_left = aff00 @ aff_left
            aff_left = linalg.lmdiv(moving.affine, aff_left)

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
                right = None
                phi = spatial.add_identity_grid(phi00)
            else:
                right = spatial.affine_grid(aff_right, fixed.shape)
                phi = regutils.smart_pull_grid(phi00, right)
                phi += right
            phi_right = phi
            if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
                left = None
            else:
                left = spatial.affine_grid(aff_left, self.nonlin.shape)
                phi = spatial.affine_matvec(aff_left, phi)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.dat
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                             title=f'(affine) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt Lie parameters
            # ----------------------------------------------------------

            def compose_grad(g, h, g_mu, g_aff):
                """
                g, h : gradient/Hessian of loss wrt moving image
                g_mu : spatial gradients of moving image
                g_aff : gradient of affine matrix wrt Lie parameters
                returns g, h: gradient/Hessian of loss wrt Lie parameters
                """
                # Note that `h` can be `None`, but the functions I
                # use deal with this case correctly.
                dim = g_mu.shape[-1]
                g = jg(g_mu, g)
                h = jhj(g_mu, h)
                g, h = regutils.affine_grid_backward(g, h)
                dim2 = dim * (dim + 1)
                g = g.reshape([*g.shape[:-2], dim2])
                g_aff = g_aff[..., :-1, :]
                g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
                g = linalg.matvec(g_aff, g)
                if h is not None:
                    h = h.reshape([*h.shape[:-4], dim2, dim2])
                    h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                    # h = h.abs().sum(-1).diag_embed()
                return g, h

            if grad or hess:
                g0, g = g, None
                h0, h = h, None
                if aff_pos in 'ms':
                    g_left = regutils.smart_push(g0, phi_right, shape=self.nonlin.shape)
                    h_left = regutils.smart_push(h0, phi_right, shape=self.nonlin.shape)
                    mugrad = moving.pull_grad(left, rotate=False)
                    g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                    g, h = g_left, h_left
                if aff_pos in 'fs':
                    g_right, h_right = g0, h0
                    mugrad = moving.pull_grad(phi, rotate=False)
                    jac = spatial.grid_jacobian(phi0, right, type='disp', extrapolate=False)
                    jac = torch.matmul(aff_left[:-1, :-1], jac)
                    mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                    g_right, h_right = compose_grad(g_right, h_right, mugrad, gaff_right)
                    g = g_right if g is None else g.add_(g_right)
                    h = h_right if h is None else h.add_(h_right)

                if loss.backward:
                    g = g.neg_()
                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # TODO add regularization term
        lla = 0

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += lla
        sumloss += self.llv
        self.loss_value = sumloss.item()
        if self.verbose and (self.verbose > 1 or not in_line_search):
            ll = sumloss.item()
            llv = self.llv
            if in_line_search:
                line = '(search) | '
            else:
                line = '(affine) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
Example #17
def write_data(options):

    backend = dict(dtype=torch.float32, device=options.device)

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.json:
                with open(trf.json) as f:
                    prm = json.load(f)
                prm['voxel_size'] = spatial.voxel_size(trf.affine)
                trf.dat = spatial.shoot(trf.dat[None],
                                        displacement=True,
                                        return_inverse=trf.inv)
                if trf.inv:
                    trf.dat = trf.dat[-1]
            else:
                trf.dat = spatial.exp(trf.dat[None],
                                      displacement=True,
                                      inverse=trf.inv)
            trf.dat = trf.dat[0]  # drop batch dimension
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    if (options.transformations
            and isinstance(options.transformations[0], struct.Linear)):
        trf = options.transformations[0]
        for file in options.files:
            mat = file.affine.to(**backend)
            aff = trf.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(aff, mat)
        options.transformations = options.transformations[1:]

    def build_from_target(affine, shape):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(affine.to(**backend), shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=mat,
                                                    interpolation=trf.order)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.order
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If target is provided, we can build most of the transform once
    #    and just multiply it with an input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target.affine, options.target.shape)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt_pull0 = dict(interpolation=options.interpolation,
                     bound=options.bound,
                     extrapolate=options.extrapolate)
    opt_coeff = dict(interpolation=options.interpolation,
                     bound=options.bound,
                     dim=3,
                     inplace=True)
    output = py.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        is_label = (isinstance(options.interpolation, str)
                    and options.interpolation == 'l')
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing:   {file.fname}\n' f'          -> {ofname}')
        if is_label:
            backend_int = dict(dtype=torch.long, device=backend['device'])
            dat = io.volumes.load(file.fname, **backend_int)
            opt_pull = dict(opt_pull0)
            opt_pull['interpolation'] = 1
        else:
            dat = io.volumes.loadf(file.fname,
                                   rand=options.interpolation > 0,
                                   **backend)
            opt_pull = opt_pull0
        dat = dat.reshape([*file.shape, file.channels])
        dat = utils.movedim(dat, -1, 0)

        if not options.target:
            oaffine = file.affine
            oshape = file.shape
            if options.voxel_size:
                ovx = utils.make_vector(options.voxel_size,
                                        3,
                                        dtype=oaffine.dtype)
                factor = spatial.voxel_size(oaffine) / ovx
                oaffine, oshape = spatial.affine_resize(oaffine,
                                                        oshape,
                                                        factor=factor,
                                                        anchor='f')
            grid = build_from_target(oaffine, oshape)
        mat = file.affine.to(**backend)
        imat = spatial.affine_inv(mat)
        if options.prefilter and not is_label:
            dat = spatial.spline_coeff_nd(dat, **opt_coeff)
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt_pull)
        dat = utils.movedim(dat, 0, -1)

        if is_label:
            io.volumes.save(dat, ofname, like=file.fname, affine=oaffine)
        else:
            io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)
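
A minimal sketch (plain PyTorch only, with made-up numbers) of how the resize factor used in the voxel-size branch above follows from the input orientation matrix and the requested output voxel size:

    import torch

    # voxel size of an orientation matrix = column norms of its 3x3 block
    affine_in = torch.diag(torch.tensor([.8, .8, .8, 1.]))   # 0.8 mm isotropic input
    vx_in = affine_in[:3, :3].pow(2).sum(dim=0).sqrt()       # 0.8 along each axis

    vx_out = torch.tensor([2., 2., 2.])                      # requested output voxel size
    factor = vx_in / vx_out                                  # ratio handed to affine_resize
    print(factor)   # tensor([0.4000, 0.4000, 0.4000]) -> fewer voxels along each axis
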
Example #18
0
    def forward(self, image, **overload):
        """

        Parameters
        ----------
        image : (batch, channel, *shape) tensor
            Input image
        overload : dict
            All parameters defined at build time can be overridden at call time

        Returns
        -------
        warped : (batch, channel, *shape) tensor
            Deformed image
        grid : (batch, *shape, 3) tensor
            Resampling grid

        """

        image = torch.as_tensor(image)
        dim = image.dim() - 2
        batch, channel, *shape = image.shape
        info = {'dtype': image.dtype, 'device': image.device}

        # get arguments
        opt_grid = {
            'dim': dim,
            'shape': shape,
            'amplitude': overload.get('vel_amplitude', self.grid.amplitude),
            'fwhm': overload.get('vel_fwhm', self.grid.fwhm),
            'bound': overload.get('vel_bound', self.grid.bound),
            'interpolation': overload.get('interpolation',
                                          self.grid.interpolation),
            'dtype': overload.get('dtype', self.grid.dtype),
            'device': overload.get('device', self.grid.device),
        }
        opt_affine = {
            'dim': dim,
            'translation': overload.get('translation',
                                        self.affine.translation),
            'rotation': overload.get('rotation', self.affine.rotation),
            'zoom': overload.get('zoom', self.affine.zoom),
            'shear': overload.get('shear', self.affine.shear),
            'dtype': overload.get('dtype', self.affine.dtype),
            'device': overload.get('device', self.affine.device),
        }
        opt_pull = {
            'bound': overload.get('image_bound', self.pull.bound),
            'interpolation': overload.get('interpolation',
                                          self.pull.interpolation),
        }

        grid = self.grid(batch, **opt_grid)
        aff = self.affine(batch, **opt_affine)

        # shift center of rotation
        aff_shift = torch.cat(
            (torch.eye(dim, **info),
             -torch.as_tensor(opt_grid['shape'], **info)[:, None] / 2),
            dim=1)
        aff = affine_matmul(aff, aff_shift)
        aff = affine_lmdiv(aff_shift, aff)

        # compose
        aff = unsqueeze(aff, dim=-3, ndim=dim)
        lin = aff[..., :dim, :dim]
        off = aff[..., :dim, -1]
        grid = matvec(lin, grid) + off

        # pull
        warped = self.pull(image, grid, **opt_pull)

        return warped, grid
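
A minimal 2D sketch, in plain PyTorch with full homogeneous matrices, of the centre-of-rotation shift composed above with affine_matmul/affine_lmdiv; the shape and rotation angle are illustrative:

    import math
    import torch

    dim, shape = 2, (8., 8.)
    centre = torch.tensor(shape) / 2

    T = torch.eye(dim + 1)                     # voxel -> centred coordinates
    T[:dim, -1] = -centre

    theta = math.pi / 6                        # pure rotation about the origin
    A = torch.eye(dim + 1)
    A[0, 0] = A[1, 1] = math.cos(theta)
    A[0, 1], A[1, 0] = -math.sin(theta), math.sin(theta)

    # T^{-1} @ A @ T, the same composition as the matmul/lmdiv pair above
    A_centred = torch.linalg.solve(T, A @ T)

    c_h = torch.cat([centre, torch.ones(1)])   # the FOV centre is a fixed point
    print(torch.allclose(A_centred @ c_h, c_h))  # True
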
Example #19
0
def reslice(moving, fname, like, inv=False, lin=None, nonlin=None,
            interpolation=1, bound='dct2', extrapolate=False, device=None,
            verbose=True):
    """Apply the linear and non-linear components of the transform and
    reslice the image to the target space.

    Notes
    -----
    .. The shape and general orientation of the moving image are kept untouched.
    .. The linear transform is composed with the original orientation matrix.
    .. The non-linear component is "wrapped" in the input space, where it
       is applied.
    .. This function writes a new file (it does not modify the input files
       in place).

    Parameters
    ----------
    moving : ImageFile
        An object describing a moving image.
    fname : list of str
        Output filename for each input file of the moving image
        (since images can be encoded over multiple volumes)
    like : ImageFile
        An object describing the target space
    inv : bool, default=False
        True if we are warping the fixed image to the moving space.
        In that case, `moving` should be a `FixedImageFile` and `like` a
        `MovingImageFile`. Otherwise, `moving` should be a `MovingImageFile`
        and `like` a `FixedImageFile`.
    lin : (4, 4) tensor, optional
        Linear (or rather affine) transformation
    nonlin : dict, optional
        Non-linear displacement field, with keys:
        disp : (..., 3) tensor
            Displacement field (in voxels)
        affine : (4, 4) tensor
            Orientation matrix of the displacement field
    interpolation : int, default=1
    bound : str, default='dct2'
    extrapolate : bool, default=False
    device : default='cpu'

    """
    nonlin = nonlin or dict(disp=None, affine=None)
    prm = dict(interpolation=interpolation, bound=bound, extrapolate=extrapolate)

    moving_affine = moving.affine.to(device)
    fixed_affine = like.affine.to(device)

    if inv:
        # affine-corrected fixed space
        if lin is not None:
            fix2lin = affine_matmul(lin, fixed_affine)
        else:
            fix2lin = fixed_affine

        if nonlin['disp'] is not None:
            # fixed voxels to param voxels (warps param to fixed)
            fix2nlin = affine_lmdiv(nonlin['affine'].to(device), fix2lin)
            if samespace(fix2nlin, nonlin['disp'].shape[:-1], like.shape):
                g = smalldef(nonlin['disp'].to(device))
            else:
                g = affine_grid(fix2nlin, like.shape)
                g += pull_grid(nonlin['disp'].to(device), g)
            # param to moving
            nlin2mov = affine_lmdiv(moving_affine, nonlin['affine'].to(device))
            g = affine_matvec(nlin2mov, g)
        else:
            g = affine_lmdiv(moving_affine, fix2lin)
            g = affine_grid(g, like.shape)

    else:
        # affine-corrected moving space
        if lin is not None:
            mov2nlin = affine_matmul(lin, moving_affine)
        else:
            mov2nlin = moving_affine

        if nonlin['disp'] is not None:
            # fixed voxels to param voxels (warps param to fixed)
            fix2nlin = affine_lmdiv(nonlin['affine'].to(device), fixed_affine)
            if samespace(fix2nlin, nonlin['disp'].shape[:-1], like.shape):
                g = smalldef(nonlin['disp'].to(device))
            else:
                g = affine_grid(fix2nlin, like.shape)
                g += pull_grid(nonlin['disp'].to(device), g)
            # param voxels to moving voxels (warps moving to fixed)
            nlin2mov = affine_lmdiv(mov2nlin, nonlin['affine'].to(device))
            g = affine_matvec(nlin2mov, g)
        else:
            g = affine_lmdiv(mov2nlin, fixed_affine)
            g = affine_grid(g, like.shape)

    if moving.type == 'labels':
        prm['interpolation'] = 0
    for file, ofname in zip(moving.files, fname):
        if verbose:
            print(f'Resliced:   {file.fname}\n'
                  f'         -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, device=device)
        dat = dat.reshape([*file.shape, file.channels])
        if g is not None:
            dat = utils.movedim(dat, -1, 0)
            dat = pull(dat, g, **prm)
            dat = utils.movedim(dat, 0, -1)
        io.savef(dat.cpu(), ofname, like=file.fname, affine=like.affine.cpu())
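
For the affine-only path (no non-linear field), the grid above reduces to a single matrix mapping fixed voxels to moving voxels. A small sketch with made-up 4x4 matrices standing in for `moving.affine`, `like.affine` and `lin`:

    import torch

    moving_affine = torch.diag(torch.tensor([2., 2., 2., 1.]))    # 2 mm moving image
    fixed_affine = torch.eye(4)                                   # 1 mm target space
    lin = torch.eye(4)
    lin[:3, -1] = torch.tensor([5., 0., 0.])                      # 5 mm translation

    # fixed voxels -> world mm -> moving voxels (through the affine-corrected moving space)
    fix2mov = torch.linalg.solve(lin @ moving_affine, fixed_affine)

    vox = torch.tensor([10., 10., 10., 1.])                       # a voxel of the target
    print(fix2mov @ vox)   # tensor([2.5000, 5.0000, 5.0000, 1.0000])
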
Example #20
0
    def forward():
        """Forward pass up to the loss"""

        loss = 0

        # affine matrix
        A = None
        for trf in options.transformations:
            trf.update()
            if isinstance(trf, struct.Linear):
                q = trf.optdat.to(**backend)
                B = trf.basis.to(**backend)
                A = linalg.expm(q, B)
                if torch.is_tensor(trf.shift):
                    # include shift
                    shift = trf.shift.to(**backend)
                    eye = torch.eye(options.dim, **backend)
                    A = A.clone()  # needed because expm is a custom autograd.Function
                    A[:-1, -1] += torch.matmul(A[:-1, :-1] - eye, shift)
                for loss1 in trf.losses:
                    loss += loss1.call(q)
                break

        # non-linear displacement field
        d = None
        d_aff = None
        for trf in options.transformations:
            if not trf.isfree():
                continue
            if isinstance(trf, struct.FFD):
                d = trf.dat.to(**backend)
                d = ffd_exp(d, trf.shape, returns='disp')
                for loss1 in trf.losses:
                    loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break
            elif isinstance(trf, struct.Diffeo):
                d = trf.dat.to(**backend)
                if not trf.smalldef:
                    # penalty on velocity fields
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d = spatial.exp(d[None], displacement=True)[0]
                if trf.smalldef:
                    # penalty on exponentiated transform
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break

        # loop over image pairs
        for match in options.losses:
            if not match.fixed:
                continue
            nb_levels = len(match.fixed.dat)
            prm = dict(interpolation=match.moving.interpolation,
                       bound=match.moving.bound,
                       extrapolate=match.moving.extrapolate)
            # loop over pyramid levels
            for moving, fixed in zip(match.moving.dat, match.fixed.dat):
                moving_dat, moving_aff = moving
                fixed_dat, fixed_aff = fixed

                moving_dat = moving_dat.to(**backend)
                moving_aff = moving_aff.to(**backend)
                fixed_dat = fixed_dat.to(**backend)
                fixed_aff = fixed_aff.to(**backend)

                # affine-corrected moving space
                if A is not None:
                    Ms = affine_matmul(A, moving_aff)
                else:
                    Ms = moving_aff

                if d is not None:
                    # fixed to param
                    Mt = affine_lmdiv(d_aff, fixed_aff)
                    if samespace(Mt, d.shape[:-1], fixed_dat.shape[1:]):
                        g = smalldef(d)
                    else:
                        g = affine_grid(Mt, fixed_dat.shape[1:])
                        g = g + pull_grid(d, g)
                    # param to moving
                    Ms = affine_lmdiv(Ms, d_aff)
                    g = affine_matvec(Ms, g)
                else:
                    # fixed to moving
                    Mt = fixed_aff
                    Ms = affine_lmdiv(Ms, Mt)
                    g = affine_grid(Ms, fixed_dat.shape[1:])

                # pull moving image
                warped_dat = pull(moving_dat, g, **prm)
                loss += match.call(warped_dat, fixed_dat) / float(nb_levels)

        return loss
Example #21
0
def write_data(options):

    backend = dict(dtype=torch.float32, device=options.device)

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            trf.dat = spatial.exp(trf.dat[None],
                                  displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    if (options.transformations
            and isinstance(options.transformations[0], struct.Linear)):
        trf = options.transformations[0]
        for file in options.files:
            mat = file.affine.to(**backend)
            aff = trf.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(aff, mat)
        options.transformations = options.transformations[1:]

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=mat,
                                                    interpolation=trf.order)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.order
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If a target is provided, we can build most of the transform once
    #    and just multiply it with an input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt = dict(interpolation=options.interpolation,
               bound=options.bound,
               extrapolate=options.extrapolate)
    output = utils.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing:   {file.fname}\n' f'          -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, **backend)
        dat = dat.reshape([*file.shape, file.channels])
        dat = utils.movedim(dat, -1, 0)

        if not options.target:
            grid = build_from_target(file)
            oaffine = file.affine
        mat = file.affine.to(**backend)
        imat = spatial.affine_inv(mat)
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt)
        dat = utils.movedim(dat, 0, -1)

        io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)
Example #22
0
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., nb) tensor, optional, Gradient wrt the Lie parameters
        h : (..., nb, nb) tensor, optional, Hessian wrt the Lie parameters
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This function performs the forward pass, and computes
        # derivatives along the way.

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
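
A small sketch of the caveat flagged in the "/!\" comment above: the derivative of the exponentiated affine with respect to a Lie parameter has a zero bottom row, so it cannot be treated as a homogeneous matrix (single-generator case, plain PyTorch):

    import torch

    B = torch.zeros(4, 4)
    B[0, -1] = 1.                       # generator of a translation along x
    q = 2.5

    aff = torch.matrix_exp(q * B)       # proper affine: bottom row is (0, 0, 0, 1)
    daff = B @ aff                      # d/dq expm(q*B): bottom row is all zeros

    print(aff[-1])    # tensor([0., 0., 0., 1.])
    print(daff[-1])   # tensor([0., 0., 0., 0.])
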
Example #23
0
def _make_image(option, dim=None, device=None):
    """
    Load an image and build a Gaussian pyramid (if required).
    Returns: ImagePyramid
    """
    dat, mask, affine = _load_image(option.files,
                                    dim=dim,
                                    device=device,
                                    label=option.label)
    dim = dat.dim() - 1
    if option.mask:
        mask1 = mask
        mask, _, _ = _load_image([option.mask],
                                 dim=dim,
                                 device=device,
                                 label=option.label)
        if mask.shape[-dim:] != dat.shape[-dim:]:
            raise ValueError('Mask should have the same shape as the image. '
                             f'Got {mask.shape[-dim:]} and {dat.shape[-dim:]}')
        if mask1 is not None:
            mask = mask * mask1
        del mask1
    if option.world:  # overwrite orientation matrix
        affine = io.transforms.map(option.world).fdata().squeeze()
    for transform in (option.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        affine = spatial.affine_lmdiv(transform, affine)
    if not option.discretize and any(option.rescale):
        dat = _rescale_image(dat, option.rescale)
    if option.pad:
        pad = option.pad
        if isinstance(pad[-1], str):
            *pad, unit = pad
        else:
            unit = 'vox'
        if unit == 'mm':
            voxel_size = spatial.voxel_size(affine)
            pad = torch.as_tensor(pad, **utils.backend(voxel_size))
            pad = pad / voxel_size
            pad = pad.floor().int().tolist()
        else:
            pad = [int(p) for p in pad]
        pad = py.make_list(pad, dim)
        if any(pad):
            affine, _ = spatial.affine_pad(affine,
                                           dat.shape[-dim:],
                                           pad,
                                           side='both')
            dat = utils.pad(dat, pad, side='both', mode=option.bound)
            if mask is not None:
                mask = utils.pad(mask, pad, side='both', mode=option.bound)
    if option.fwhm:
        fwhm = option.fwhm
        if isinstance(fwhm[-1], str):
            *fwhm, unit = fwhm
        else:
            unit = 'vox'
        if unit == 'mm':
            voxel_size = spatial.voxel_size(affine)
            fwhm = torch.as_tensor(fwhm, **utils.backend(voxel_size))
            fwhm = fwhm / voxel_size
        dat = spatial.smooth(dat, dim=dim, fwhm=fwhm, bound=option.bound)
    image = objects.ImagePyramid(dat,
                                 levels=option.pyramid,
                                 affine=affine,
                                 dim=dim,
                                 bound=option.bound,
                                 mask=mask,
                                 extrapolate=option.extrapolate,
                                 method=option.pyramid_method)
    if getattr(option, 'soft_quantize', False) and len(image[0].dat) == 1:
        for level in image:
            level.preview = level.dat
            level.dat = _soft_quantize_image(level.dat, option.soft_quantize)
    elif not option.label and option.discretize:
        for level in image:
            level.preview = level.dat
            level.dat = _discretize_image(level.dat, option.discretize)
    return image
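
A minimal sketch, with a made-up voxel size, of the millimetre-to-voxel conversions applied above to `pad` and `fwhm`:

    import torch

    voxel_size = torch.tensor([2., 2., 3.])                  # e.g. from spatial.voxel_size(affine)

    pad_mm = torch.tensor([5., 5., 5.])
    pad_vox = (pad_mm / voxel_size).floor().int().tolist()   # [2, 2, 1]

    fwhm_mm = torch.tensor([6., 6., 6.])
    fwhm_vox = fwhm_mm / voxel_size                          # tensor([3., 3., 2.])
    print(pad_vox, fwhm_vox)
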
Example #24
0
def get_oriented_slice(image, dim=-1, index=None, affine=None,
                       space=None, bbox=None, interpolation=1,
                       transpose_sagittal=False, return_index=False,
                       return_mat=False):
    """Sample a slice in a RAS system

    Parameters
    ----------
    image : (..., *shape3)
    dim : int, default=-1
        Index of spatial dimension to sample in the visualization space
        If RAS: -1 = axial / -2 = coronal / -3 = sagittal
    index : int, default=shape//2
        Coordinate (in voxel) of the slice to extract
    affine : (4, 4) tensor, optional
        Orientation matrix of the image
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        Use the transposed in-plane layout for sagittal slices.
    return_index : bool, default=False
        Also return the cursor position within the extracted slice.
    return_mat : bool, default=False
        Also return the orientation matrix of the extracted slice.

    Returns
    -------
    slice : (..., *shape2) tensor
        Slice in the visualisation space.

    """
    # preproc dim
    if isinstance(dim, str):
        dim = dim.lower()[0]
        if dim == 'a':
            dim = -1
        if dim == 'c':
            dim = -2
        if dim == 's':
            dim = -3

    backend = utils.backend(image)

    # compute default space (mn/mx are in voxels)
    affine, shape = _get_default_native(affine, image.shape[-3:])
    space, mn, mx = _get_default_space(affine, [shape], space, bbox)
    affine, shape = (affine[0], shape[0])

    # compute default cursor (in voxels)
    if index is None:
        index = (mx + mn) / 2
    else:
        index = torch.as_tensor(index)
        index = spatial.affine_matvec(spatial.affine_inv(space), index)

    # include slice to volume matrix
    shape = tuple(((mx-mn) + 1).round().int().tolist())
    if dim == -1:
        # axial
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [0, 0, 1, - index[2]],
                 [0, 0, 0, 1]]
        shift = utils.as_tensor(shift, **backend)
        shape = shape[:-1]
        index = (index[0] - mn[0] + 1, index[1] - mn[1] + 1)
    elif dim == -2:
        # coronal
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - index[1]],
                 [0, 0, 0, 1]]
        shift = utils.as_tensor(shift, **backend)
        shape = (shape[0], shape[2])
        index = (index[0] - mn[0] + 1, index[2] - mn[2] + 1)
    elif dim == -3:
        # sagittal
        if not transpose_sagittal:
            shift = [[0, 0, 1, - mn[2] + 1],
                     [0, 1, 0, - mn[1] + 1],
                     [1, 0, 0, - index[0]],
                     [0, 0, 0, 1]]
            shift = utils.as_tensor(shift, **backend)
            shape = (shape[2], shape[1])
            index = (index[2] - mn[2] + 1, index[1] - mn[1] + 1)
        else:
            shift = [[0, -1, 0,   mx[1] + 1],
                     [0,  0, 1, - mn[2] + 1],
                     [1,  0, 0, - index[0]],
                     [0,  0, 0, 1]]
            shift = utils.as_tensor(shift, **backend)
            shape = (shape[1], shape[2])
            index = (mx[1] + 1 - index[1], index[2] - mn[2] + 1)
    else:
        raise ValueError(f'Unknown dimension {dim}')

    # sample
    space = spatial.affine_rmdiv(space, shift)
    affine = spatial.affine_lmdiv(affine, space)
    affine = affine.to(**backend)
    grid = spatial.affine_grid(affine, [*shape, 1])
    *channel, s0, s1, s2 = image.shape
    imshape = (s0, s1, s2)
    image = image.reshape([1, -1, *imshape])
    image = spatial.grid_pull(image, grid[None], interpolation=interpolation,
                              bound='dct2', extrapolate=False)
    image = image.reshape([*channel, *shape])
    return ((image, index, space) if return_index and return_mat else
            (image, index) if return_index else
            (image, space) if return_mat else image)
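
A small sketch, with an arbitrary 2 mm matrix, of the cursor conversion above: a point given in millimetric visualisation coordinates is mapped to voxel coordinates with the inverse of `space`:

    import torch

    space = torch.diag(torch.tensor([2., 2., 2., 1.]))     # 2 mm matrix with a shifted origin
    space[:3, -1] = torch.tensor([-10., -10., -10.])

    point_mm = torch.tensor([4., 0., -2.])                 # cursor in visualisation mm
    point_h = torch.cat([point_mm, torch.ones(1)])         # homogeneous coordinates
    index_vox = torch.linalg.solve(space, point_h)[:3]     # inverse(space) applied to the point
    print(index_vox)   # tensor([7., 5., 4.])
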