Example #1
    def responsibilities(image, means, precisions, proportions):
        # aliases
        x = image
        m = means
        A = precisions
        p = proportions
        nb_dim = image.dim() - 2
        del image, means, precisions, proportions

        # voxel-wise term
        x = channel2last(x).unsqueeze(-2)  # [B, ...,  1, C]
        p = unsqueeze(p, dim=1, ndim=nb_dim)  # [B, ones, K]
        m = unsqueeze(m, dim=1, ndim=nb_dim)  # [B, ones, K, C]
        A = unsqueeze(A, dim=1, ndim=nb_dim)  # [B, ones, K, C, C]
        x = x - m
        z = matvec(A, x)
        z = (z * x).sum(dim=-1)  # [B, ..., K]
        z = -0.5 * z

        # constant term
        twopi = torch.as_tensor(2 * pi, dtype=A.dtype, device=A.device)
        nrm = torch.logdet(A) - A.shape[-1] * twopi.log()
        nrm = 0.5 * nrm + p.log()
        z = z + nrm

        # softmax
        z = last2channel(z)
        logz = torch.nn.functional.log_softmax(z, dim=1)
        z = torch.nn.functional.softmax(z, dim=1)

        return z, logz
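
A minimal standalone sketch of the same computation on flat data (N voxels, C channels, K classes), using plain PyTorch only; the name `gmm_responsibilities` is illustrative, not part of the library above:

    import math
    import torch

    def gmm_responsibilities(x, means, precisions, proportions):
        # x: (N, C), means: (K, C), precisions: (K, C, C), proportions: (K,)
        d = x[:, None, :] - means                              # (N, K, C)
        maha = torch.einsum('nkc,kcd,nkd->nk', d, precisions, d)
        nrm = 0.5 * (torch.logdet(precisions)
                     - x.shape[-1] * math.log(2 * math.pi))
        logp = nrm - 0.5 * maha + proportions.log()            # (N, K)
        return logp.softmax(-1), logp.log_softmax(-1)

    z, logz = gmm_responsibilities(torch.randn(100, 3),
                                   torch.randn(4, 3),
                                   torch.eye(3).expand(4, 3, 3),
                                   torch.full([4], 0.25))
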
Example #2
    def forward(self, x, v=None):

        dim = x.dim() - 2
        if dim not in (2, 3):
            raise ValueError(f'{type(self).__name__} only implemented '
                             f'in 2D or 3D.')

        radii = self.radii.to(**utils.backend(x))
        pradii = self.pradii.to(**utils.backend(x)).log()

        # compute joint log-likelihood `ln p(x, radius | v)`
        loss = x.new_zeros([len(radii), *x.shape])
        for i, (p, r) in enumerate(zip(pradii, radii)):
            # compute unsorted eigenvalues
            e = spatial.hessian_eig(x, r, dim=dim, sort=None)
            # soft sort
            P = math.softsort(e.abs(), tau=self.tau_sort, descending=True)
            e = linalg.matvec(P, e)
            e = utils.movedim(e, -1, 0)
            # compute penalties
            loss[i] = -self.tau_large * e[1:].sum(0)  # white ridges
            e = e.square().clamp_min_(1e-32).log()
            if dim == 3:
                loss[i] += self.tau_ratio1 * (e[1] - e[2])  # tubes
            loss[i] += self.tau_ratio0 * (e[1] - e[0])  # not plates
            loss[i] += p  # radius prior

        # compute (stable) log-sum-exp (== model evidence `ln p(x | v)`)
        loss = math.logsumexp(loss, dim=0)

        # weight by probability to be a vessel and return `E_v[ln p(x | v)]`
        if v is None:
            v = x
        return -(loss * v).sum() / (v.sum() + 1e-3)
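
In formula form (restating the comments above): the loop accumulates a per-radius joint log-likelihood, `logsumexp` marginalizes the radius, and the result is averaged over (soft) vessel voxels:

\[
\ln p(x \mid v) = \ln \sum_{i} \exp\bigl(\ln p(r_i) + \ln p(x \mid r_i, v)\bigr),
\qquad
\text{loss} = -\frac{\sum_x v(x)\, \ln p(x \mid v)}{\sum_x v(x) + 10^{-3}} .
\]
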
Example #3
def affine_grid(mat, shape):
    """Create a dense transformation grid from an affine matrix.

    Parameters
    ----------
    mat : (..., D[+1], D[+1]) tensor
        Affine matrix (or matrices).
    shape : (D,) sequence[int]
        Shape of the grid, with length D.

    Returns
    -------
    grid : (..., *shape, D) tensor
        Dense transformation grid

    """
    mat = torch.as_tensor(mat)
    shape = list(shape)
    nb_dim = mat.shape[-1] - 1
    if nb_dim != len(shape):
        raise ValueError('Dimension of the affine matrix ({}) and shape ({}) '
                         'are not the same.'.format(nb_dim, len(shape)))
    if mat.shape[-2] not in (nb_dim, nb_dim + 1):
        raise ValueError(
            'First argument should be matrices of shape '
            '(..., {0}, {1}) or (..., {1}, {1}) but got {2}.'.format(
                nb_dim, nb_dim + 1, mat.shape))
    batch_shape = mat.shape[:-2]
    grid = identity_grid(shape, mat.dtype, mat.device)
    grid = utils.unsqueeze(grid, dim=0, ndim=len(batch_shape))
    mat = utils.unsqueeze(mat, dim=-3, ndim=nb_dim)
    lin = mat[..., :nb_dim, :nb_dim]
    off = mat[..., :nb_dim, -1]
    grid = linalg.matvec(lin, grid) + off
    return grid
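
As a usage sketch, this is what `affine_grid` computes in 2D with plain PyTorch; `identity_grid_2d` below is a local stand-in for the library's `identity_grid` (assumed to return voxel coordinates stacked along the last dimension):

    import torch

    def identity_grid_2d(shape, **backend):
        coords = [torch.arange(s, **backend) for s in shape]
        return torch.stack(torch.meshgrid(*coords, indexing='ij'), -1)

    mat = torch.tensor([[0., -1., 4.],
                        [1.,  0., 0.],
                        [0.,  0., 1.]])                  # 90-degree rotation + shift
    grid = identity_grid_2d([5, 5], dtype=mat.dtype)     # (5, 5, 2)
    lin, off = mat[:2, :2], mat[:2, -1]
    grid = torch.einsum('ij,hwj->hwi', lin, grid) + off  # == matvec(lin, grid) + off
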
Example #4
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch shape.

        Other Parameters
        ----------------
        shape : sequence[int], optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        grid : (batch, *shape, 3) tensor
            Resampling grid

        """
        shape = overload.get('shape', self.grid.velocity.field.shape)
        dtype = overload.get('dtype', self.grid.velocity.field.dtype)
        device = overload.get('device', self.grid.velocity.field.device)
        backend = dict(dtype=dtype, device=device)

        if self.grid.velocity.field.amplitude == 0:
            grid = identity_grid(shape, **backend)
        else:
            grid = self.grid(batch, shape=shape, **backend)
        dtype = grid.dtype
        device = grid.device
        backend = dict(dtype=dtype, device=device)

        shape = grid.shape[1:-1]
        dim = len(shape)
        aff = self.affine(batch, dim=dim, **backend)

        # shift center of rotation
        aff_shift = torch.cat((
            torch.eye(dim, **backend),
            torch.as_tensor(shape, **backend)[:, None].sub_(1).div_(-2)),
            dim=1)
        aff_shift = as_euclidean(aff_shift)

        aff = affine_matmul(aff, aff_shift)
        aff = affine_lmdiv(aff_shift, aff)

        # compose
        aff = utils.unsqueeze(aff, dim=-3, ndim=dim)
        lin = aff[..., :dim, :dim]
        off = aff[..., :dim, -1]
        grid = linalg.matvec(lin, grid) + off

        return grid
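
The two `affine_matmul`/`affine_lmdiv` calls conjugate the sampled affine with a translation so that rotations and zooms are applied about the center of the field of view. In matrix form (my summary, with \( s \) the spatial shape):

\[
A' = T^{-1} A\, T,
\qquad
T = \begin{bmatrix} I & -\tfrac{s - 1}{2} \\ 0 & 1 \end{bmatrix} .
\]
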
Example #5
    def exp(self, velocity, affine=None, displacement=False):
        """Generate a deformation grid from tangent parameters.

        Parameters
        ----------
        velocity : (batch, *spatial, nb_dim)
            Stationary velocity field
        affine : (batch, nb_prm)
            Affine parameters
        displacement : bool, default=False
            Return a displacement field (voxel to shift) rather than
            a transformation field (voxel to voxel).

        Returns
        -------
        grid : (batch, *spatial, nb_dim)
            Deformation grid (transformation or displacement).

        """
        info = {'dtype': velocity.dtype, 'device': velocity.device}

        # generate grid
        shape = velocity.shape[1:-1]
        velocity_small = self.resize(velocity, type='displacement')
        grid = self.velexp(velocity_small)
        grid = self.resize(grid, shape=shape, type='grid')

        if affine is not None:
            # exponentiate
            affine_prm = affine
            affine = []
            for prm in affine_prm:
                affine.append(self.affexp(prm))
            affine = torch.stack(affine, dim=0)

            # shift center of rotation
            affine_shift = torch.cat(
                (torch.eye(self.dim, **info),
                 -torch.as_tensor(shape, **info)[:, None] / 2),
                dim=1)
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

            # compose
            affine = unsqueeze(affine, dim=-3, ndim=self.dim)
            lin = affine[..., :self.dim, :self.dim]
            off = affine[..., :self.dim, -1]
            grid = matvec(lin, grid) + off

        if displacement:
            grid = grid - spatial.identity_grid(grid.shape[1:-1], **info)

        return grid
Example #6
def se_sample_svd(shape, sigma, lam, mu=None, repeats=1, **backend):
    """Sample random fields with a squared exponential kernel.

    This function computes the square root of the covariance matrix by SVD.

    Parameters
    ----------
    shape : sequence[int]
        Shape of the image / volume.
    sigma : float
        SE amplitude.
    lam : float
        SE length-scale.
    mu : () or (*shape) tensor_like
        SE mean
    repeats : int, default=1
        Number of sampled fields.

    Returns
    -------
    field : (repeats, *shape) tensor
        Sampled random fields.

    """
    # Build SE covariance matrix
    e = dist_map(shape, **backend)
    backend = utils.backend(e)
    e.mul_(-0.5 / (lam**2)).exp_().mul_(sigma**2)

    # SVD of covariance
    u, s, _ = torch.svd(e)
    s = s.sqrt_()

    # Sample white noise and apply transform
    full_shape = (repeats, *shape)
    field = torch.randn(full_shape, **backend).reshape([repeats, -1])
    field = linalg.matvec(u, field.mul_(s))
    field = field.reshape(full_shape)

    # Add mean
    if mu is not None:
        mu = torch.as_tensor(mu, **backend)
        field += mu

    return field
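
The same recipe in one dimension, fully standalone (names are illustrative): build the SE covariance on a line, take its matrix square root by SVD, and color white noise with it.

    import torch

    n, sigma, lam, repeats = 64, 1.0, 5.0, 8
    x = torch.arange(n, dtype=torch.double)
    d2 = (x[:, None] - x[None, :]) ** 2                 # squared distance map
    cov = sigma ** 2 * torch.exp(-0.5 * d2 / lam ** 2)  # SE covariance
    u, s, _ = torch.svd(cov)                            # cov = u @ diag(s) @ u.T
    field = torch.randn(repeats, n, dtype=torch.double) @ (u * s.sqrt()).T
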
Example #7
def cc_sample(shape, sigma, alpha, mu=None, repeats=1, **backend):
    """Sample random fields with a constant correlation.

    This function computes the square root of the covariance matrix by SVD.

    Parameters
    ----------
    shape : sequence[int]
        Shape of the image / volume.
    sigma : float
        Variance.
    alpha : float
        Correlation.
    mu : () or (*shape) tensor_like
        Mean
    repeats : int, default=1
        Number of sampled fields.

    Returns
    -------
    field : (repeats, *shape) tensor
        Sampled random fields.

    """
    # Build constant-correlation covariance matrix
    n = py.prod(shape)
    e = torch.full([n, n], alpha, **backend)
    e.diagonal(0, -1, -2).add_(1 - alpha)
    backend = utils.backend(e)

    # SVD of covariance
    u, s, _ = torch.svd(e)
    s = s.sqrt_()

    # Sample white noise and apply transform
    full_shape = (repeats, *shape)
    field = torch.randn(full_shape, **backend).reshape([repeats, -1])
    field = linalg.matvec(u, field.mul_(s))
    field.mul_(sigma)
    field = field.reshape(full_shape)

    # Add mean
    if mu is not None:
        mu = torch.as_tensor(mu, **backend)
        field += mu

    return field
Example #8
def _rotate_grad(grad, aff=None, dense=None):
    """Rotate grad by the jacobian of `aff o dense`.
    grad : (..., dim) tensor       Spatial gradients
    aff : (dim+1, dim+1) tensor    Affine matrix
    dense : (..., dim) tensor      Dense vox2vox displacement field
    returns : (..., dim) tensor    Rotated gradients.
    """
    if aff is None and dense is None:
        return grad
    dim = grad.shape[-1]
    if dense is not None:
        jac = spatial.grid_jacobian(dense, type='disp')
        if aff is not None:
            jac = torch.matmul(aff[:dim, :dim], jac)
    else:
        jac = aff[:dim, :dim]
    grad = linalg.matvec(jac.transpose(-1, -2), grad)
    return grad
Example #9
    def nll(image, resp, means, precisions):
        # aliases
        x = image
        z = resp
        m = means
        A = precisions
        nb_dim = image.dim() - 2
        del image, resp, means, precisions

        x = channel2last(x).unsqueeze(-2)  # [B, ...,  1, C]
        z = channel2last(z)  # [B, ..., K]
        m = unsqueeze(m, dim=1, ndim=nb_dim)  # [B, ones, K, C]
        A = unsqueeze(A, dim=1, ndim=nb_dim)  # [B, ones, K, C, C]
        x = x - m
        loss = matvec(A, x)
        loss = (loss * x).sum(dim=-1)  # [B, ..., K]
        loss = (loss * z).sum(dim=-1)  # [B, ...]
        loss = loss * 0.5
        return loss
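
Up to the normalizing constants (which do not depend on the responsibilities and are dropped here), this is the expected Gaussian negative log-likelihood per voxel (my summary of the code above):

\[
\mathcal{L} = \frac{1}{2} \sum_{k} z_k\, (x - m_k)^{\top} A_k\, (x - m_k) .
\]
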
Example #10
def jg(jac, grad, dim=None):
    """Jacobian-gradient product: J*g

    Parameters
    ----------
    jac : (..., K, *spatial, D)
    grad : (..., K, *spatial)

    Returns
    -------
    new_grad : (..., *spatial, D)

    """
    if grad is None:
        return None
    dim = dim or (grad.dim() - 1)
    grad = utils.movedim(grad, -dim - 1, -1)
    jac = utils.movedim(jac, -dim - 2, -1)
    grad = linalg.matvec(jac, grad)
    return grad
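
For the default `dim` (no batch dimensions), the two `movedim` calls plus `matvec` should amount to a single contraction over the leading K axis; a sanity sketch with plain einsum:

    import torch

    K, H, W, D = 2, 4, 4, 2
    jac = torch.randn(K, H, W, D)   # (K, *spatial, D)
    grad = torch.randn(K, H, W)     # (K, *spatial)
    new_grad = torch.einsum('khwd,khw->hwd', jac, grad)  # (*spatial, D)
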
Example #11
 def compose_grad(g, h, g_mu, g_aff):
     """
     g, h : gradient/Hessian of loss wrt moving image
     g_mu : spatial gradients of moving image
     g_aff : gradient of affine matrix wrt Lie parameters
     returns g, h: gradient/Hessian of loss wrt Lie parameters
     """
     # Note that `h` can be `None`, but the functions I
     # use deal with this case correctly.
     dim = g_mu.shape[-1]
     g = jg(g_mu, g)
     h = jhj(g_mu, h)
     g, h = regutils.affine_grid_backward(g, h)
     dim2 = dim * (dim + 1)
     g = g.reshape([*g.shape[:-2], dim2])
     g_aff = g_aff[..., :-1, :]
     g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
     g = linalg.matvec(g_aff, g)
     if h is not None:
         h = h.reshape([*h.shape[:-4], dim2, dim2])
         h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
         # h = h.abs().sum(-1).diag_embed()
     return g, h
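
A rough statement of the chain rule being applied (my reading of the code, not from its documentation): with a warped image \( \mu(A(q)\,\tilde{x}) \), homogeneous voxel coordinates \( \tilde{x} = (x, 1) \), and loss \( \mathcal{L} \),

\[
\frac{\partial \mathcal{L}}{\partial q_i}
= \sum_{x} \frac{\partial \mathcal{L}}{\partial \mu}\,
  \nabla\mu^{\top}\, \frac{\partial A}{\partial q_i}\, \tilde{x} ,
\]

where `jg`/`jhj` apply \( \nabla\mu \), `affine_grid_backward` accumulates the outer product with \( \tilde{x} \), and the final `matvec` contracts with \( \partial A / \partial q_i \) flattened to length dim*(dim+1).
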
Example #12
    def pull_grad(self, grid, rotate=False):
        """Sample the image gradients at dense coordinates.

        Parameters
        ----------
        grid : (*spatial, dim) tensor or None
            Dense transformation field.
        rotate : bool, default=False
            Rotate the gradients using the Jacobian of the transformation.

        Returns
        -------
        grad : ([C], *spatial, dim) tensor

        """
        if grid is None:
            return self.grad()
        grad = spatial.grid_grad(self.dat, grid, bound=self.bound,
                                 extrapolate=self.extrapolate)
        if rotate:
            jac = spatial.grid_jacobian(grid)
            jac = jac.transpose(-1, -2)
            grad = linalg.matvec(jac, grad)
        return grad
Example #13
def exp_backward(vel,
                 *grad_and_hess,
                 inverse=False,
                 steps=8,
                 interpolation='linear',
                 bound='dft',
                 rotate_grad=False):
    """Backward pass of SVF exponentiation.

    This should be much more memory-efficient than the autograd pass
    as we don't have to store intermediate grids.

    I am using DARTEL's derivatives (from the code, not the paper).
    From what I get, it corresponds to pushing forward the gradient
    (computed in observation space) recursively while squaring the
    (inverse) transform.
    Remember that the push forward of g by phi is
                    |iphi| iphi' * g(iphi)
    where iphi is the inverse of phi. We could also have implemented
    this operation as: inverse(phi)' * push(g, phi), since
    push(g, phi) \approx |iphi| g(iphi). It has the advantage of using
    push rather than pull, which might better preserve positive-definiteness
    of the Hessian, but requires the inversion of (potentially ill-behaved)
    Jacobian matrices.

    Note that gradients must first be rotated using the Jacobian of
    the exponentiated transform so that the denominator refers to the
    initial velocity (we want dL/dV0, not dL/dPsi).
    THIS IS NOT DONE INSIDE THIS FUNCTION YET (see _dartel).

    Parameters
    ----------
    vel : (..., *spatial, dim) tensor
        Velocity
    grad : (..., *spatial, dim) tensor
        Gradient with respect to the output grid
    hess : (..., *spatial, dim*(dim+1)//2) tensor, optional
        Symmetric hessian with respect to the output grid.
    inverse : bool, default=False
        Whether the grid is an inverse
    steps : int, default=8
        Number of scaling and squaring steps
    interpolation : str or int, default='linear'
    bound : str, default='dft'
    rotate_grad : bool, default=False
        If True, rotate the gradients using the Jacobian of exp(vel).

    Returns
    -------
    grad : (..., *spatial, dim) tensor
        Gradient with respect to the SVF
    hess : (..., *spatial, dim*(dim+1)//2) tensor, optional
        Approximate (block diagonal) Hessian with respect to the SVF

    """
    has_hess = len(grad_and_hess) > 1
    grad, *hess = grad_and_hess
    hess = hess[0] if hess else None
    del grad_and_hess

    opt = dict(bound=bound, interpolation=interpolation)
    dim = vel.shape[-1]
    shape = vel.shape[-dim - 1:-1]
    id = identity_grid(shape, **utils.backend(vel))
    vel = vel.clone()

    if rotate_grad:
        # It forces us to perform a forward exponentiation, which
        # is a bit annoying...
        # Maybe save the Jacobian after the forward pass? But it takes space
        _, jac = exp_forward(vel,
                             jacobian=True,
                             steps=steps,
                             displacement=True,
                             **opt,
                             _anagrad=True)
        jac = jac.transpose(-1, -2)
        grad = linalg.matvec(jac, grad)
        if hess is not None:
            hess = _jhj(jac, hess)
        del jac

    vel /= (-1 if not inverse else 1) * (2**steps)
    jac = grid_jacobian(vel, bound=bound, type='disp')
    for _ in range(steps):
        det = jac.det()
        jac = jac.transpose(-1, -2)
        grad0 = grad
        grad = _pull_vel(grad, id + vel, **opt)  # \
        grad = linalg.matvec(jac, grad)  # | push forward
        grad *= det[..., None]  # /
        grad += grad0  # add all scales (SVF)
        if hess is not None:
            hess0 = hess
            hess = _pull_vel(hess, id + vel, **opt)
            hess = _jhj(jac, hess)
            hess *= det[..., None]
            hess += hess0
        # squaring
        jac = jac.transpose(-1, -2)
        jac = _composition_jac(jac, vel, type='disp', identity=id, **opt)
        vel += _pull_vel(vel, id + vel, **opt)

    if inverse:
        grad.neg_()

    grad /= (2**steps)
    if hess is not None:
        hess /= (2**steps)

    return (grad, hess) if has_hess else grad
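
For context, the forward pass that this routine differentiates is standard scaling and squaring of a stationary velocity field (my summary, not from the source):

\[
\exp(v) \approx \varphi_K ,
\qquad
\varphi_0 = \mathrm{id} + \frac{v}{2^K} ,
\qquad
\varphi_{k+1} = \varphi_k \circ \varphi_k ,
\]

and, as I read it, the loop above pushes the gradient back through the \( K \) squaring steps, adding one contribution per scale (the `grad += grad0` line).
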
Example #14
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Example #15
def shoot(vel,
          greens=None,
          absolute=_default_absolute,
          membrane=_default_membrane,
          bending=_default_bending,
          lame=_default_lame,
          factor=1,
          voxel_size=1,
          return_inverse=False,
          displacement=False,
          steps=8,
          fast=True,
          verbose=False):
    """Exponentiate a velocity field by geodesic shooting.

    Notes
    -----
    .. This function generates the *inverse* deformation, if we follow
       LDDMM conventions. This allows the velocity to be defined in the
       space of the moving image.
    .. If the greens function is provided, the penalty parameters are
       not used.

    Parameters
    ----------
    vel : ([batch], *spatial, dim) tensor
        Initial velocity in moving space.
    greens : (*spatial, [dim, dim]) tensor, optional
        Greens function generated by `greens`.
    absolute : float, default=0.0001
        Penalty on absolute displacements.
    membrane : float, default=0.001
        Penalty on the membrane energy.
    bending : float, default=0.2
        Penalty on the bending energy.
    lame : float or (float, float), default=(0.05, 0.2)
        Linear elastic penalty.
    voxel_size : [sequence of] float, default=1
        Voxel size. Used by `greens_apply` even when `greens` is provided.
    return_inverse : bool, default=False
        Return the inverse on top of the forward transform.
    displacement : bool, default=False
        Return a displacement field instead of a transformation field.
    steps : int, default=8
        Number of integration steps.
        If None, use an educated guess based on the magnitude of `vel`.
    fast : bool, default=True
        If True, use a faster integration scheme, which may induce
        some numerical error (the energy is not exactly preserved
        along time). Else, use the slower but more precise scheme.

    Returns
    -------
    grid : ([batch], *spatial, dim) tensor
        Transformation from fixed to moving space.
        (It is used to warp a moving image to a fixed one).

    igrid : ([batch], *spatial, dim) tensor, if return_inverse
        Inverse transformation, from moving to fixed space.
        (It is used to warp a fixed image to a moving one).

    """
    # Authors
    # -------
    # .. John Ashburner <*****@*****.**> : original Matlab code
    # .. Yael Balbastre <*****@*****.**> : Python port
    #
    # License
    # -------
    # The original Matlab code is (C) 2012-2019 WCHN / John Ashburner
    # and was distributed as part of [SPM](https://www.fil.ion.ucl.ac.uk/spm)
    # under the GNU General Public Licence (version >= 2).

    vel = torch.as_tensor(vel)
    backend = utils.backend(vel)
    dim = vel.shape[-1]
    spatial = vel.shape[-dim - 1:-1]

    prm = dict(absolute=absolute,
               membrane=membrane,
               bending=bending,
               lame=lame,
               voxel_size=voxel_size,
               factor=factor)
    pull_prm = dict(bound='dft', interpolation=1, extrapolate=True)
    if greens is None:
        greens = _greens(spatial, **prm, **backend)
    greens = torch.as_tensor(greens, **backend)

    if not steps:
        # Number of time steps from an educated guess about how far to move
        with torch.no_grad():
            steps = vel.square().sum(
                dim=-1).max().sqrt().floor().int().item() + 1

    id = identity_grid(spatial, **backend)
    mom = mom0 = regulariser_grid(vel, **prm, bound='dft')
    vel = vel / steps
    disp = -vel
    if return_inverse or not fast:
        idisp = vel.clone()

    for i in range(1, abs(steps)):
        if fast:
            # JA: the update of u_t is not exactly as described in the paper,
            # but describing this might be a bit tricky. The approach here
            # was the most stable one I could find - although it does lose some
            # energy as < v_t, u_t> decreases over time steps.
            jac = _jacobian(-vel)
            mom = linalg.matvec(jac.transpose(-1, -2), mom)
            mom = _push_grid(mom, id + vel, **pull_prm)
        else:
            jac = _jacobian(idisp).inverse()
            mom = linalg.matvec(jac.transpose(-1, -2), mom0)
            mom = _push_grid(mom, id + idisp, **pull_prm)

        # Convolve with Greens function of L
        # v_t \gets L^g u_t
        vel = greens_apply(mom, greens, factor=factor, voxel_size=voxel_size)
        vel = vel.div_(steps)
        if verbose:
            print(f'{0.5*steps*(vel*mom).sum().item()/py.prod(spatial):6g}',
                  end='\n' if not (i % 5) else ' ',
                  flush=True)

        # $\psi \gets \psi \circ (id - \tfrac{1}{T} v)$
        # JA: I found that simply using
        # $\psi \gets \psi - \tfrac{1}{T} (D \psi) v$ was not so stable.
        disp = _pull_grid(disp, id - vel, **pull_prm).sub_(vel)
        if return_inverse or not fast:
            idisp += _pull_grid(vel, id + idisp, **pull_prm)

    if verbose:
        print('')
    if not displacement:
        disp += id
        if return_inverse:
            idisp += id
    return (disp, idisp) if return_inverse else disp
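
The integration loop follows the standard geodesic shooting equations (my summary; which of the forward/inverse maps plays the role of \( \varphi_t \) depends on this code's conventions): the initial momentum is transported by the evolving deformation, and the velocity is recovered by convolving with the Greens function \( K \) of the regularizer \( L \),

\[
u_t = |D\varphi_t|\,(D\varphi_t)^{\top}\,(u_0 \circ \varphi_t),
\qquad
v_t = K * u_t .
\]
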
Example #16
    def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the affine component (nonlin is not None)"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        logaff0 = logaff
        aff_pos = self.affine.position[0].lower()
        if any(loss.backward for loss in self.losses):
            aff0, iaff0, gaff0, igaff0 = \
                self.affine.exp2(logaff0, grad=True,
                                 cache_result=not in_line_search)
            phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
        else:
            iaff0, igaff0, iphi0 = None, None, None
            aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                          cache_result=not in_line_search)
            phi0 = self.nonlin.exp(cache_result=True, recompute=False)

        has_printed = False
        for loss in self.losses:

            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                phi00, aff00, gaff00 = iphi0, iaff0, igaff0
            else:
                phi00, aff00, gaff00 = phi0, aff0, gaff0

            # ----------------------------------------------------------
            # build left and right affine matrices
            # ----------------------------------------------------------
            aff_right, gaff_right = fixed.affine, None
            if aff_pos in 'fs':
                gaff_right = gaff00 @ aff_right
                gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
                aff_right = aff00 @ aff_right
            aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
            aff_left, gaff_left = self.nonlin.affine, None
            if aff_pos in 'ms':
                gaff_left = gaff00 @ aff_left
                gaff_left = linalg.lmdiv(moving.affine, gaff_left)
                aff_left = aff00 @ aff_left
            aff_left = linalg.lmdiv(moving.affine, aff_left)

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
                right = None
                phi = spatial.add_identity_grid(phi00)
            else:
                right = spatial.affine_grid(aff_right, fixed.shape)
                phi = regutils.smart_pull_grid(phi00, right)
                phi += right
            phi_right = phi
            if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
                left = None
            else:
                left = spatial.affine_grid(aff_left, self.nonlin.shape)
                phi = spatial.affine_matvec(aff_left, phi)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.dat
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                             title=f'(affine) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt Lie parameters
            # ----------------------------------------------------------

            def compose_grad(g, h, g_mu, g_aff):
                """
                g, h : gradient/Hessian of loss wrt moving image
                g_mu : spatial gradients of moving image
                g_aff : gradient of affine matrix wrt Lie parameters
                returns g, h: gradient/Hessian of loss wrt Lie parameters
                """
                # Note that `h` can be `None`, but the functions I
                # use deal with this case correctly.
                dim = g_mu.shape[-1]
                g = jg(g_mu, g)
                h = jhj(g_mu, h)
                g, h = regutils.affine_grid_backward(g, h)
                dim2 = dim * (dim + 1)
                g = g.reshape([*g.shape[:-2], dim2])
                g_aff = g_aff[..., :-1, :]
                g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
                g = linalg.matvec(g_aff, g)
                if h is not None:
                    h = h.reshape([*h.shape[:-4], dim2, dim2])
                    h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                    # h = h.abs().sum(-1).diag_embed()
                return g, h

            if grad or hess:
                g0, g = g, None
                h0, h = h, None
                if aff_pos in 'ms':
                    g_left = regutils.smart_push(g0, phi_right, shape=self.nonlin.shape)
                    h_left = regutils.smart_push(h0, phi_right, shape=self.nonlin.shape)
                    mugrad = moving.pull_grad(left, rotate=False)
                    g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                    g, h = g_left, h_left
                if aff_pos in 'fs':
                    g_right, h_right = g0, h0
                    mugrad = moving.pull_grad(phi, rotate=False)
                    jac = spatial.grid_jacobian(phi0, right, type='disp', extrapolate=False)
                    jac = torch.matmul(aff_left[:-1, :-1], jac)
                    mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                    g_right, h_right = compose_grad(g_right, h_right, mugrad, gaff_right)
                    g = g_right if g is None else g.add_(g_right)
                    h = h_right if h is None else h.add_(h_right)

                if loss.backward:
                    g = g.neg_()
                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # TODO add regularization term
        lla = 0

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += lla
        sumloss += self.llv
        self.loss_value = sumloss.item()
        if self.verbose and (self.verbose > 1 or not in_line_search):
            ll = sumloss.item()
            llv = self.llv
            if in_line_search:
                line = '(search) | '
            else:
                line = '(affine) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
Example #17
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch size
        overload : dict

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """

        # get arguments
        shape = overload.get('shape', self.shape)
        mean = overload.get('mean', self.mean)
        voxel_size = overload.get('voxel_size', self.voxel_size)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)
        backend = dict(dtype=dtype, device=device)

        # sample if parameters are callable
        nb_dim = len(shape)
        voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
        voxel_size = voxel_size.tolist()
        lame = py.make_list(self.lame, 2)

        if (hasattr(self, '_greens')
                and self._voxel_size == voxel_size
                and self._shape == shape):
            greens = self._greens.to(dtype=dtype, device=device)
        else:
            greens = spatial.greens(
                shape,
                absolute=self.absolute,
                membrane=self.membrane,
                bending=self.bending,
                lame=self.lame,
                voxel_size=voxel_size,
                device=device,
                dtype=dtype)
            if any(lame):
                greens, scale, _ = torch.svd(greens)
                scale = scale.sqrt_()
                greens *= scale.unsqueeze(-1)
            else:
                greens = greens.sqrt_()

            if self.cache_greens:
                self._greens = greens
                self._voxel_size = voxel_size
                self._shape = shape

        sample = torch.randn([2, batch, *shape, nb_dim], **backend)

        # multiply by square root of greens
        if greens.dim() > nb_dim:  # lame
            sample = linalg.matvec(greens, sample)
        else:
            sample = sample * greens.unsqueeze(-1)
            voxel_size = utils.make_vector(voxel_size, nb_dim, **backend)
            sample = sample / voxel_size.sqrt()
        sample = fft.complex(sample[0], sample[1])

        # inverse Fourier transform
        dims = list(range(-nb_dim-1, -1))
        sample = fft.real(fft.ifftn(sample, dim=dims))
        sample *= py.prod(shape)

        # add mean
        sample += mean

        return sample
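
The general recipe here is to color complex white noise by a spectral square root and inverse-transform back; a standalone sketch with a hand-rolled Gaussian spectrum in place of the library's Greens function:

    import torch

    shape = (64, 64)
    freq = torch.stack(torch.meshgrid(
        *[torch.fft.fftfreq(s) for s in shape], indexing='ij'), -1)
    spectrum = torch.exp(-(freq ** 2).sum(-1) / (2 * 0.05 ** 2))
    noise = torch.randn(*shape, dtype=torch.cfloat)
    field = torch.fft.ifftn(noise * spectrum.sqrt()).real
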
Example #18
def shim(fmap,
         max_order=2,
         mask=None,
         isocenter=None,
         dim=None,
         returns='corrected'):
    """Subtract a linear combination of spherical harmonics that minimize gradients

    Parameters
    ----------
    fmap : (..., *spatial) tensor
        Field map
    max_order : int, default=2
        Maximum order of the spherical harmonics
    mask : tensor, optional
        Mask of voxels to include (typically brain mask)
    isocenter : [sequence of] float, default=shape/2
        Coordinate of isocenter, in voxels
    dim : int, default=fmap.dim()
        Number of spatial dimensions
    returns : combination of {'corrected', 'correction', 'parameters'}, default='corrected'
        Components to return

    Returns
    -------
    corrected : (..., *spatial) tensor, if 'corrected' in `returns`
        Corrected field map (with spherical harmonics subtracted)
    correction : (..., *spatial) tensor, if 'correction' in `returns`
        Linear combination of spherical harmonics.
    parameters : (..., k) tensor, if 'parameters' in `returns`
        Parameters of the linear combination

    """
    fmap = torch.as_tensor(fmap)
    dim = dim or fmap.dim()
    shape = fmap.shape[-dim:]
    batch = fmap.shape[:-dim]
    backend = utils.backend(fmap)
    dims = list(range(-dim, 0))

    if mask is not None:
        mask = ~mask  # make it a mask of background voxels

    # compute gradients
    gmap = diff(fmap, dim=dims, side='f', bound='dct2')
    if mask is not None:
        gmap[..., mask, :] = 0
    gmap = gmap.reshape([*batch, -1])

    # compute basis of spherical harmonics
    basis = []
    for i in range(1, max_order + 1):
        b = spherical_harmonics(shape, i, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = diff(b, dim=dims, side='f', bound='dct2')
        if mask is not None:
            b[..., mask, :] = 0
        b = b.reshape([b.shape[0], *batch, -1])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*batch, vox*dim, k)

    # solve system
    prm = linalg.lmdiv(basis, gmap[..., None], method='pinv')[..., 0]
    # > (*batch, k)

    # rebuild basis (without taking gradients)
    basis = []
    for i in range(1, max_order + 1):
        b = spherical_harmonics(shape, i, isocenter, **backend)
        b = utils.movedim(b, -1, 0)
        b = b.reshape([b.shape[0], *batch, *shape])
        basis.append(b)
    basis = torch.cat(basis, 0)
    basis = utils.movedim(basis, 0, -1)  # (*batch, *shape, k)

    comb = linalg.matvec(basis.unsqueeze(-2), utils.unsqueeze(prm, -2, dim))
    comb = comb[..., 0]
    fmap = fmap - comb

    returns = returns.split('+')
    out = []
    for ret in returns:
        if ret == 'corrected':
            out.append(fmap)
        elif ret == 'correction':
            out.append(comb)
        elif ret[0] == 'p':
            out.append(prm)
    return out[0] if len(out) == 1 else tuple(out)
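
The fit itself is ordinary least squares on the gradients, solved with a pseudo-inverse: with \( B \) the matrix of flattened spherical-harmonic gradients and \( g \) the flattened field-map gradient,

\[
\hat{p} = \operatorname*{arg\,min}_{p}\, \lVert B p - g \rVert^2 = B^{+} g ,
\]

after which the harmonic combination (rebuilt without taking gradients) is evaluated and subtracted from the field map.
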
Example #19
def write_outputs(z, prm, options):

    # prepare filenames
    ref_native = options.input[0]
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    format_dict = get_format_dict(ref_native, options.output)

    # move channels to back
    backend = utils.backend(z)
    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp
            or options.all_nat or options.all_mni or options.all_wrp):
        dat, _, affine = get_data(options.input, options.mask, None, 3,
                                  **backend)

    # --- native space -------------------------------------------------

    if options.prob_nat or options.all_nat:
        fname = options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('prob.nat     ->', fname)
        io.savef(torch.movedim(z, 0, -1),
                 fname,
                 like=ref_native,
                 dtype='float32')

    if options.labels_nat or options.all_nat:
        fname = options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('labels.nat   ->', fname)
        io.save(z.argmax(0), fname, like=ref_native, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        bias = prm['bias']
        fname = options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}'
        if len(options.input) == 1:
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('bias.nat     ->', fname)
            io.savef(torch.movedim(bias, 0, -1),
                     fname,
                     like=ref_native,
                     dtype='float32')
        else:
            for c, (bias1, ref1) in enumerate(zip(bias, options.input)):
                format_dict1 = get_format_dict(ref1, options.output)
                fname = fname.format(**format_dict1)
                if options.verbose > 0:
                    print(f'bias.nat.{c+1}   ->', fname)
                io.savef(bias1, fname, like=ref1, dtype='float32')
        del bias

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        fname = options.nobias_nat or '{dir}{sep}{base}.nobias.nat{ext}'
        if len(options.input) == 1:
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('nobias.nat   ->', fname)
            io.savef(torch.movedim(nobias, 0, -1), fname, like=ref_native)
        else:
            for c, (nobias1, ref1) in enumerate(zip(nobias, options.input)):
                format_dict1 = get_format_dict(ref1, options.output)
                fname = fname.format(**format_dict1)
                if options.verbose > 0:
                    print(f'nobias.nat.{c+1} ->', fname)
                io.savef(nobias1, fname, like=ref1)
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        warp = prm['warp']
        fname = options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('warp.nat     ->', fname)
        io.savef(warp, fname, like=ref_native, dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(mni_affine,
                                                      mni_shape,
                                                      scl,
                                                      anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        if options.prob_mni:
            fname = options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('prob.mni     ->', fname)
            io.savef(torch.movedim(z_mni, 0, -1),
                     fname,
                     like=ref_native,
                     affine=mni_affine,
                     dtype='float32')
        if options.labels_mni:
            fname = options.labels_mni or '{dir}{sep}{base}.labels.mni{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('labels.mni   ->', fname)
            io.save(z_mni.argmax(0),
                    fname,
                    like=ref_native,
                    affine=mni_affine,
                    dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni
                         or options.all_mni):
        bias = spatial.reslice(prm['bias'],
                               dat_affine,
                               mni_affine,
                               mni_shape,
                               interpolation=3,
                               prefilter=False,
                               bound='dct2')

        if options.bias_mni or options.all_mni:
            fname = options.bias_mni or '{dir}{sep}{base}.bias.mni{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('bias.mni     ->', fname)
                io.savef(torch.movedim(bias, 0, -1),
                         fname,
                         like=ref_native,
                         affine=mni_affine,
                         dtype='float32')
            else:
                for c, (bias1, ref1) in enumerate(zip(bias, options.input)):
                    format_dict1 = get_format_dict(ref1, options.output)
                    fname = fname.format(**format_dict1)
                    if options.verbose > 0:
                        print(f'bias.mni.{c+1}   ->', fname)
                    io.savef(bias1,
                             fname,
                             like=ref1,
                             affine=mni_affine,
                             dtype='float32')

        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            fname = options.nobias_mni or '{dir}{sep}{base}.nobias.mni{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('nobias.mni   ->', fname)
                io.savef(torch.movedim(nobias, 0, -1),
                         fname,
                         like=ref_native,
                         affine=mni_affine)
            else:
                for c, (nobias1, ref1) in enumerate(zip(nobias, options.input)):
                    format_dict1 = get_format_dict(ref1, options.output)
                    fname = fname.format(**format_dict1)
                    if options.verbose > 0:
                        print(f'nobias.mni.{c+1} ->', fname)
                    io.savef(nobias1, fname, like=ref1, affine=mni_affine)
            del nobias

        del bias

    need_iwarp = (options.warp_mni or options.prob_wrp or options.labels_wrp
                  or options.bias_wrp or options.nobias_wrp or options.all_mni
                  or options.all_wrp)
    need_iwarp = need_iwarp and options.warp
    if not need_iwarp:
        return

    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp,
                            dat_affine,
                            mni_affine,
                            mni_shape,
                            interpolation=2,
                            bound='dft',
                            extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if (options.warp_mni or options.all_mni) and options.warp:
        fname = options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('warp.mni     ->', fname)
        io.savef(iwarp,
                 fname,
                 like=ref_native,
                 affine=mni_affine,
                 dtype='float32')

    # --- Warped space -------------------------------------------------
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_mni = spatial.grid_pull(z, iwarp)
        if options.prob_wrp or options.all_wrp:
            fname = options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('prob.wrp     ->', fname)
            io.savef(torch.movedim(z_mni, 0, -1),
                     fname,
                     like=ref_native,
                     affine=mni_affine,
                     dtype='float32')
        if options.labels_wrp or options.all_wrp:
            fname = options.labels_wrp or '{dir}{sep}{base}.labels.wrp{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('labels.wrp   ->', fname)
            io.save(z_mni.argmax(0),
                    fname,
                    like=ref_native,
                    affine=mni_affine,
                    dtype='int16')
        del z_mni

    if options.bias and (options.bias_wrp or options.nobias_wrp
                         or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'],
                                 iwarp,
                                 interpolation=3,
                                 prefilter=False,
                                 bound='dct2')
        if options.bias_wrp or options.all_wrp:
            fname = options.bias_wrp or '{dir}{sep}{base}.bias.wrp{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('bias.wrp     ->', fname)
                io.savef(torch.movedim(bias, 0, -1),
                         fname,
                         like=ref_native,
                         affine=mni_affine,
                         dtype='float32')
            else:
                for c, (bias1, ref1) in enumerate(zip(bias, options.input)):
                    format_dict1 = get_format_dict(ref1, options.output)
                    fname = fname.format(**format_dict1)
                    if options.verbose > 0:
                        print(f'bias.wrp.{c+1}   ->', fname)
                    io.savef(bias1,
                             fname,
                             like=ref1,
                             affine=mni_affine,
                             dtype='float32')

        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            fname = options.nobias_wrp or '{dir}{sep}{base}.nobias.wrp{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('nobias.wrp   ->', fname)
                io.savef(torch.movedim(nobias, 0, -1),
                         fname,
                         like=ref_native,
                         affine=mni_affine)
            else:
                for c, (nobias1, ref1) in enumerate(zip(nobias, options.input)):
                    format_dict1 = get_format_dict(ref1, options.output)
                    fname = fname.format(**format_dict1)
                    if options.verbose > 0:
                        print(f'nobias.wrp.{c+1} ->', fname)
                    io.savef(nobias1, fname, like=ref1, affine=mni_affine)
            del nobias

        del bias
Example #20
    def forward(self, image, **overload):
        """

        Parameters
        ----------
        image : (batch, channel, *shape) tensor
            Input image
        overload : dict
            All parameters defined at build time can be overridden at call time

        Returns
        -------
        warped : (batch, channel, *shape) tensor
            Deformed image
        grid : (batch, *shape, 3) tensor
            Resampling grid

        """

        image = torch.as_tensor(image)
        dim = image.dim() - 2
        batch, channel, *shape = image.shape
        info = {'dtype': image.dtype, 'device': image.device}

        # get arguments
        opt_grid = {
            'dim': dim,
            'shape': shape,
            'amplitude': overload.get('vel_amplitude', self.grid.amplitude),
            'fwhm': overload.get('vel_fwhm', self.grid.fwhm),
            'bound': overload.get('vel_bound', self.grid.bound),
            'interpolation': overload.get('interpolation',
                                          self.grid.interpolation),
            'dtype': overload.get('dtype', self.grid.dtype),
            'device': overload.get('device', self.grid.device),
        }
        opt_affine = {
            'dim': dim,
            'translation': overload.get('translation',
                                        self.affine.translation),
            'rotation': overload.get('rotation', self.affine.rotation),
            'zoom': overload.get('zoom', self.affine.zoom),
            'shear': overload.get('shear', self.affine.shear),
            'dtype': overload.get('dtype', self.affine.dtype),
            'device': overload.get('device', self.affine.device),
        }
        opt_pull = {
            'bound':
            overload.get('image_bound', self.pull.bound),
            'interpolation':
            overload.get('interpolation', self.pull.interpolation),
        }

        grid = self.grid(batch, **opt_grid)
        aff = self.affine(batch, **opt_affine)

        # shift center of rotation
        aff_shift = torch.cat(
            (torch.eye(dim, **info),
             -torch.as_tensor(opt_grid['shape'], **info)[:, None] / 2),
            dim=1)
        aff = affine_matmul(aff, aff_shift)
        aff = affine_lmdiv(aff_shift, aff)

        # compose
        aff = unsqueeze(aff, dim=-3, ndim=dim)
        lin = aff[..., :dim, :dim]
        off = aff[..., :dim, -1]
        grid = matvec(lin, grid) + off

        # pull
        warped = self.pull(image, grid, **opt_pull)

        return warped, grid