Example 1
 def exp2(self, jacobian=False, add_identity=False, alpha=None):
     """Exponentiate both forward and inverse transforms"""
     v = self.fdata()
     if alpha is not None:
         v = v * alpha
     grid, igrid = spatial.shoot(v,
                                 self.kernel,
                                 steps=self.steps,
                                 factor=self.factor,
                                 voxel_size=self.voxel_size,
                                 **self.reg_prm,
                                 return_inverse=True,
                                 displacement=True)
     if jacobian:
         jac = spatial.grid_jacobian(grid, type='displacement')
         ijac = spatial.grid_jacobian(igrid, type='displacement')
         if add_identity:
             grid = self.add_identity(grid)
             igrid = self.add_identity(igrid)
         return grid, igrid, jac, ijac
     else:
         if add_identity:
             grid = self.add_identity(grid)
             igrid = self.add_identity(igrid)
         return grid, igrid
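
This variant integrates an initial velocity field by geodesic shooting and returns matched forward and inverse displacement fields in one pass. A minimal usage sketch (the object name `velocity` is hypothetical; only the call pattern follows from the definition above):

    # forward/inverse displacements plus their Jacobians
    phi, iphi, jac, ijac = velocity.exp2(jacobian=True)
    # scaled exponential, e.g. a half step during a line search
    phi_half, iphi_half = velocity.exp2(alpha=0.5)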
Example 2
 def exp2(self, v=None, jacobian=False, add_identity=False,
          cache_result=False, recompute=True):
     """Exponentiate both forward and inverse transforms"""
     if v is None:
         v = self.dat.dat
     if recompute or self._cache is None or self._icache is None:
         grid, igrid = spatial.shoot(v, self.kernel, steps=self.steps,
                                     factor=self.factor / py.prod(self.shape),
                                     voxel_size=self.voxel_size, **self.penalty,
                                     return_inverse=True, displacement=True)
     else:
         grid, igrid = self._cache, self._icache
     if cache_result:
         self._cache = grid
         self._icache = igrid
     if jacobian:
         jac = spatial.grid_jacobian(grid, type='displacement')
         ijac = spatial.grid_jacobian(igrid, type='displacement')
         if add_identity:
             grid = self.add_identity(grid)
             igrid = self.add_identity(igrid)
         return grid, igrid, jac, ijac
     else:
         if add_identity:
             grid = self.add_identity(grid)
             igrid = self.add_identity(igrid)
         return grid, igrid
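
This variant adds memoization on top of shooting, which is the expensive step. A hedged sketch of the intended call pattern (`model` is a placeholder for an instance of the enclosing class):

    # first evaluation: shoot once and memoize both fields
    phi, iphi = model.exp2(cache_result=True)
    # subsequent evaluations (e.g. while line-searching another block
    # of parameters) reuse the cached fields instead of re-shooting
    phi, iphi = model.exp2(recompute=False)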
Example 3
 def exp2(self, jacobian=False, add_identity=False, alpha=None):
     """Exponentiate both forward and inverse transforms"""
     grid = self.fdata()
     if alpha is not None:
         grid = grid * alpha
     # small-displacement model: the inverse is the negated displacement
     igrid = -grid
     if jacobian:
         jac = spatial.grid_jacobian(grid, type='displacement')
         ijac = spatial.grid_jacobian(igrid, type='displacement')
     if add_identity:
         grid = self.add_identity(grid)
         igrid = self.add_identity(igrid)
     return (grid, igrid, jac, ijac) if jacobian else (grid, igrid)
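
Negating the displacement only inverts the transform to first order: composing the two leaves a residual d(x) - d(x + d(x)), which is quadratic in the magnitude of d when d is smooth and small. A self-contained check of that claim in plain PyTorch (independent of the class above; grid conventions follow `torch.nn.functional.grid_sample`):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    # smooth, small 2D displacement in normalized [-1, 1] coordinates
    d = 0.05 * torch.randn(1, 2, 4, 4)
    d = F.interpolate(d, size=(64, 64), mode='bilinear', align_corners=True)

    ys, xs = torch.meshgrid(torch.linspace(-1, 1, 64),
                            torch.linspace(-1, 1, 64), indexing='ij')
    identity = torch.stack([xs, ys], dim=-1)[None]   # (1, 64, 64, 2)

    disp = d.permute(0, 2, 3, 1)                     # channels-last view
    # sample d at the displaced locations x + d(x)
    d_warped = F.grid_sample(d, identity + disp,
                             align_corners=True).permute(0, 2, 3, 1)
    residual = disp - d_warped                       # quadratic in |d|
    print(float(residual.abs().max()), float(disp.abs().max()))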
Example 4
 def exp2(self, v=None, jacobian=False, add_identity=False,
          cache_result=False, recompute=True):
     """Exponentiate both forward and inverse transforms"""
     if v is None:
         v = self.dat.dat
     grid = v
     if recompute or self._icache is None:
         igrid = spatial.grid_inv(v, type='disp', **self.penalty)
     else:
         igrid = self._icache
     if cache_result:
         self._icache = igrid
     if jacobian:
         jac = spatial.grid_jacobian(grid, type='displacement')
         ijac = spatial.grid_jacobian(igrid, type='displacement')
     if add_identity:
         grid = self.add_identity(grid)
         igrid = self.add_identity(igrid)
     return (grid, igrid, jac, ijac) if jacobian else (grid, igrid)
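
In this dense model the forward field is the stored parameters themselves, so only the numerical inverse from `spatial.grid_inv` is worth caching. One subtlety worth noting: the Jacobians are taken of the displacements, before the identity is added. A call sketch (`dense` is a placeholder instance):

    phi, iphi, jac, ijac = dense.exp2(jacobian=True, add_identity=True)
    # phi/iphi are coordinate grids here, but jac/ijac are the
    # (..., dim, dim) Jacobians of the underlying displacements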
Example 5
 def iexp(self, jacobian=False, add_identity=False, alpha=None):
     """Exponentiate inverse transform"""
     grid = -self.fdata()
     if alpha is not None:
         grid = grid * alpha
     if jacobian:
         jac = spatial.grid_jacobian(grid, type='displacement')
     if add_identity:
         grid = self.add_identity(grid)
     return (grid, jac) if jacobian else grid
Example 6
 def exp(self, v=None, jacobian=False, add_identity=False,
         cache_result=False, recompute=True):
     """Exponentiate forward transform"""
     if v is None:
         v = self.dat.dat
     grid = v
     if jacobian:
         jac = spatial.grid_jacobian(grid, type='displacement')
     if add_identity:
         grid = self.add_identity(grid)
     return (grid, jac) if jacobian else grid
Example 7
 def iexp(self, v=None, jacobian=False, add_identity=False,
          cache_result=False, recompute=True):
     """Exponentiate inverse transform"""
     if v is None:
         v = self.dat.dat
     if recompute or self._icache is None:
         grid = spatial.grid_inv(v, type='disp', **self.penalty)
     else:
         grid = self._icache
     if cache_result:
         self._icache = grid
     if jacobian:
         jac = spatial.grid_jacobian(grid, type='displacement')
     if add_identity:
         grid = self.add_identity(grid)
     return (grid, jac) if jacobian else grid
Example 8
 def exp(self, jacobian=False, add_identity=False, alpha=None):
     """Exponentiate forward transform"""
     v = self.fdata()
     if alpha is not None:
         v = v * alpha
     grid = spatial.shoot(v,
                          self.kernel,
                          steps=self.steps,
                          factor=self.factor,
                          voxel_size=self.voxel_size,
                          **self.reg_prm,
                          displacement=True)
     if jacobian:
         jac = spatial.grid_jacobian(grid, type='displacement')
     if add_identity:
         grid = self.add_identity(grid)
     return (grid, jac) if jacobian else grid
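
Compared with `exp2` in Example 1, this forward-only variant runs a single shooting integration and skips `return_inverse`, which roughly halves the cost when the inverse field is not needed. A call sketch (`velocity` is again a placeholder):

    phi = velocity.exp()                               # forward displacement only
    phi, jac = velocity.exp(jacobian=True, alpha=0.5)  # scaled step + Jacobian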
Example 9
def _rotate_grad(grad, aff=None, dense=None):
    """Rotate grad by the jacobian of `aff o dense`.
    grad : (..., dim) tensor       Spatial gradients
    aff : (dim+1, dim+1) tensor    Affine matrix
    dense : (..., dim) tensor      Dense vox2vox displacement field
    returns : (..., dim) tensor    Rotated gradients.
    """
    if aff is None and dense is None:
        return grad
    dim = grad.shape[-1]
    if dense is not None:
        jac = spatial.grid_jacobian(dense, type='disp')
        if aff is not None:
            jac = torch.matmul(aff[:dim, :dim], jac)
    else:
        jac = aff[:dim, :dim]
    grad = linalg.matvec(jac.transpose(-1, -2), grad)
    return grad
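
The transpose-Jacobian product is the spatial chain rule: for a warped image f o phi, grad(f o phi)(x) = J_phi(x)^T (grad f)(phi(x)). A minimal call sketch with placeholder tensors (shapes follow the docstring; `torch`, `spatial`, and `linalg` are assumed imported as in the module above):

    dim = 3
    grad = torch.randn(32, 32, 32, dim)     # gradients sampled at phi(x)
    aff = torch.eye(dim + 1)                # vox2vox affine (here: identity)
    dense = torch.zeros(32, 32, 32, dim)    # dense displacement (here: zero)
    rotated = _rotate_grad(grad, aff=aff, dense=dense)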
Example 10
    def pull_grad(self, grid, rotate=False):
        """Sample the image gradients at dense coordinates.

        Parameters
        ----------
        grid : (*spatial, dim) tensor or None
            Dense transformation field.
        rotate : bool, default=False
            Rotate the gradients using the Jacobian of the transformation.

        Returns
        -------
        grad : ([C], *spatial, dim) tensor

        """
        if grid is None:
            return self.grad()
        grad = spatial.grid_grad(self.dat, grid, bound=self.bound,
                                 extrapolate=self.extrapolate)
        if rotate:
            jac = spatial.grid_jacobian(grid)
            jac = jac.transpose(-1, -2)
            grad = linalg.matvec(jac, grad)
        return grad
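
`pull_grad` combines sampling and the chain rule from Example 9 in one call: with `rotate=True`, the sampled gradients are pre-multiplied by the transpose Jacobian of `grid`, so they are expressed in the output (fixed-image) coordinates. A hedged usage sketch (`moving` and `phi` are placeholders):

    # gradients of the moving image at the warped coordinates,
    # rotated into fixed space, ready for the registration chain rule
    g = moving.pull_grad(phi, rotate=True)   # ([C], *spatial, dim)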
Example 11
    def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the affine component (nonlin is not None)"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        logaff0 = logaff
        aff_pos = self.affine.position[0].lower()
        if any(loss.backward for loss in self.losses):
            aff0, iaff0, gaff0, igaff0 = \
                self.affine.exp2(logaff0, grad=True,
                                 cache_result=not in_line_search)
            phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
        else:
            iaff0, igaff0, iphi0 = None, None, None
            aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                          cache_result=not in_line_search)
            phi0 = self.nonlin.exp(cache_result=True, recompute=False)

        has_printed = False
        for loss in self.losses:

            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                phi00, aff00, gaff00 = iphi0, iaff0, igaff0
            else:
                phi00, aff00, gaff00 = phi0, aff0, gaff0

            # ----------------------------------------------------------
            # build left and right affine matrices
            # ----------------------------------------------------------
            aff_right, gaff_right = fixed.affine, None
            if aff_pos in 'fs':
                gaff_right = gaff00 @ aff_right
                gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
                aff_right = aff00 @ aff_right
            aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
            aff_left, gaff_left = self.nonlin.affine, None
            if aff_pos in 'ms':
                gaff_left = gaff00 @ aff_left
                gaff_left = linalg.lmdiv(moving.affine, gaff_left)
                aff_left = aff00 @ aff_left
            aff_left = linalg.lmdiv(moving.affine, aff_left)

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
                right = None
                phi = spatial.add_identity_grid(phi00)
            else:
                right = spatial.affine_grid(aff_right, fixed.shape)
                phi = regutils.smart_pull_grid(phi00, right)
                phi += right
            phi_right = phi
            if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
                left = None
            else:
                left = spatial.affine_grid(aff_left, self.nonlin.shape)
                phi = spatial.affine_matvec(aff_left, phi)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.dat
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                             title=f'(affine) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt Lie parameters
            # ----------------------------------------------------------

            def compose_grad(g, h, g_mu, g_aff):
                """
                g, h : gradient/Hessian of loss wrt moving image
                g_mu : spatial gradients of moving image
                g_aff : gradient of affine matrix wrt Lie parameters
                returns g, h: gradient/Hessian of loss wrt Lie parameters
                """
                # Note that `h` can be `None`, but the functions I
                # use deal with this case correctly.
                dim = g_mu.shape[-1]
                g = jg(g_mu, g)
                h = jhj(g_mu, h)
                g, h = regutils.affine_grid_backward(g, h)
                dim2 = dim * (dim + 1)
                g = g.reshape([*g.shape[:-2], dim2])
                g_aff = g_aff[..., :-1, :]
                g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
                g = linalg.matvec(g_aff, g)
                if h is not None:
                    h = h.reshape([*h.shape[:-4], dim2, dim2])
                    h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                    # h = h.abs().sum(-1).diag_embed()
                return g, h

            if grad or hess:
                g0, g = g, None
                h0, h = h, None
                if aff_pos in 'ms':
                    g_left = regutils.smart_push(g0, phi_right, shape=self.nonlin.shape)
                    h_left = regutils.smart_push(h0, phi_right, shape=self.nonlin.shape)
                    mugrad = moving.pull_grad(left, rotate=False)
                    g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                    g, h = g_left, h_left
                if aff_pos in 'fs':
                    g_right, h_right = g0, h0
                    mugrad = moving.pull_grad(phi, rotate=False)
                    jac = spatial.grid_jacobian(phi0, right, type='disp', extrapolate=False)
                    jac = torch.matmul(aff_left[:-1, :-1], jac)
                    mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                    g_right, h_right = compose_grad(g_right, h_right, mugrad, gaff_right)
                    g = g_right if g is None else g.add_(g_right)
                    h = h_right if h is None else h.add_(h_right)

                if loss.backward:
                    g = g.neg_()
                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # TODO add regularization term
        lla = 0

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += lla
        sumloss += self.llv
        self.loss_value = sumloss.item()
        if self.verbose and (self.verbose > 1 or not in_line_search):
            ll = sumloss.item()
            llv = self.llv
            if in_line_search:
                line = '(search) | '
            else:
                line = '(affine) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
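
The tuple returned here plugs directly into a Newton-type update on the affine Lie parameters. A hedged sketch of the surrounding step (the solver line is illustrative, not the library's actual driver; `linalg.lmdiv(h, b)` solves the system `h x = b` as elsewhere in this code):

    loss, g, h = self.do_affine(logaff, grad=True, hess=True)
    step = linalg.lmdiv(h, g[..., None])[..., 0]   # Gauss-Newton direction
    for alpha in (1.0, 0.5, 0.25, 0.125):          # crude backtracking
        new_loss = self.do_affine(logaff - alpha * step,
                                  in_line_search=True)
        if new_loss < loss:
            logaff = logaff - alpha * step
            break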