Example #1
def _warp_image1(image,
                 target,
                 shape=None,
                 affine=None,
                 nonlin=None,
                 backward=False,
                 reslice=False):
    """Returns the warped image, with channel dimension last"""
    # build transform
    aff_right = target
    aff_left = spatial.affine_inv(image.affine)
    aff = None
    if affine:
        # exp = affine.iexp if backward else affine.exp
        exp = affine.exp
        aff = exp(recompute=False, cache_result=True)
        if backward:
            aff = spatial.affine_inv(aff)
    if nonlin:
        if affine:
            if affine.position[0].lower() in ('ms' if backward else 'fs'):
                aff_right = spatial.affine_matmul(aff, aff_right)
            if affine.position[0].lower() in ('fs' if backward else 'ms'):
                aff_left = spatial.affine_matmul(aff_left, aff)
        exp = nonlin.iexp if backward else nonlin.exp
        phi = exp(recompute=False, cache_result=True)
        aff_left = spatial.affine_matmul(aff_left, nonlin.affine)
        aff_right = spatial.affine_lmdiv(nonlin.affine, aff_right)
        if _almost_identity(aff_right) and nonlin.shape == shape:
            phi = nonlin.add_identity(phi)
        else:
            tmp = spatial.affine_grid(aff_right, shape)
            phi = regutils.smart_pull_grid(phi, tmp).add_(tmp)
            del tmp
        if not _almost_identity(aff_left):
            phi = spatial.affine_matvec(aff_left, phi)
    else:
        # no nonlin: single affine even if position == 'symmetric'
        if reslice:
            aff = aff_right if aff is None else spatial.affine_matmul(aff, aff_right)
            aff = spatial.affine_matmul(aff_left, aff)
            phi = spatial.affine_grid(aff, shape)
        else:
            phi = None

    # warp image
    if phi is not None:
        warped = image.pull(phi)
    else:
        warped = image.dat

    # write to disk
    if len(warped) == 1:
        warped = warped[0]
    else:
        warped = utils.movedim(warped, 0, -1)
    return warped
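The helper `_almost_identity` used above is not defined in this snippet; a minimal sketch of what such a check could look like (an assumption, not necessarily the library's implementation):

import torch

def _almost_identity(aff, tol=1e-5):
    # True if the (possibly truncated) homogeneous affine is numerically
    # indistinguishable from the identity
    aff = torch.as_tensor(aff)
    eye = torch.eye(aff.shape[-1], dtype=aff.dtype, device=aff.device)
    return torch.allclose(aff, eye[:aff.shape[-2]], atol=tol)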
Example #2
    def forward(self, affine, **overload):
        """

        Parameters
        ----------
        affine : (batch, ndim[+1], ndim+1) tensor
            Affine matrix
        overload : dict
            All parameters of the module can be overridden at call time.

        Returns
        -------
        grid : (batch, *shape, ndim) tensor
            Dense transformation grid

        """

        nb_dim = affine.shape[-1] - 1
        info = {'dtype': affine.dtype, 'device': affine.device}
        shape = make_list(overload.get('shape', self.shape), nb_dim)
        shift = overload.get('shift', self.shift)

        if shift:
            affine_shift = torch.cat((
                torch.eye(nb_dim, **info),
                -torch.as_tensor(shape, **info)[:, None]/2),
                dim=1)
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

        grid = spatial.affine_grid(affine, shape)
        return grid
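The `shift` branch above conjugates the affine with a translation to the volume centre, so that rotations and zooms are applied about the centre of the field of view rather than the corner voxel. A standalone sketch of the same pattern, assuming the nitorch-style `spatial` namespace used throughout these examples:

import torch
from nitorch import spatial  # assumption: same `spatial` module as in the examples

def centered_affine_grid(affine, shape):
    """Dense grid for `affine`, applied about the centre of `shape`."""
    nb_dim = affine.shape[-1] - 1
    backend = dict(dtype=affine.dtype, device=affine.device)
    # M: translation moving the origin to the volume centre
    shift = torch.cat((torch.eye(nb_dim, **backend),
                       -torch.as_tensor(shape, **backend)[:, None] / 2), dim=1)
    affine = spatial.affine_matmul(affine, shift)   # A @ M
    affine = spatial.affine_lmdiv(shift, affine)    # M \ (A @ M)
    return spatial.affine_grid(affine, shape)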
Example #3
    def forward(self, affine, shape=None):
        """

        Parameters
        ----------
        affine : (batch, ndim[+1], ndim+1) tensor
            Affine matrix
        shape : sequence[int], default=self.shape

        Returns
        -------
        grid : (batch, *shape, ndim) tensor
            Dense transformation grid

        """

        nb_dim = affine.shape[-1] - 1
        backend = {'dtype': affine.dtype, 'device': affine.device}
        shape = shape or self.shape

        if self.shift:
            affine_shift = torch.eye(nb_dim + 1, **backend)
            affine_shift[:nb_dim, -1] = torch.as_tensor(shape, **backend)
            affine_shift[:nb_dim, -1].sub_(1).div_(2).neg_()
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

        grid = spatial.affine_grid(affine, shape)
        return grid
Example #4
 def pull(q, grid):
     aff = core.linalg.expm(q, basis)
     aff = spatial.affine_matmul(aff, target_aff)
     aff = spatial.affine_lmdiv(source_aff, aff)
     expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))
     grid = spatial.affine_matvec(aff[expd], grid)
     moved = spatial.grid_pull(source, grid, **pull_opt)
     return moved
Example #5
 def pull(q, vel):
     grid = spatial.exp(vel)
     aff = core.linalg.expm(q, basis)
     aff = spatial.affine_matmul(aff, target_aff)
     aff = spatial.affine_lmdiv(source_aff, aff)
     grid = spatial.affine_matvec(aff, grid)
     moved = spatial.grid_pull(source, grid, **pull_opt)
     return moved
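In both `pull` closures above (Examples #4 and #5), the resampling grid is the fixed-space grid (either the raw `grid` or the exponentiated velocity) mapped through the composed affine. Schematically, in the notation of the code,

$$\phi \;=\; M_{\text{source}}^{-1}\,\exp\Big(\sum_k q_k B_k\Big)\,M_{\text{target}}\; g,$$

where $g$ is the grid, $M_{\text{target}}$ is `target_aff`, $M_{\text{source}}$ is `source_aff` and the $B_k$ are the Lie basis matrices stacked in `basis`.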
Example #6
    def set_fdata(self, affine):
        affine = torch.as_tensor(affine)
        backend = dict(dtype=affine.dtype, device=affine.device)
        if affine.shape[-2:] != (4, 4):
            raise ValueError('Expected a batch of 4x4 matrices')

        # we may need to convert from RAS space to a weird space
        afftype = self.type()[0]
        if afftype != 'ras':
            src, _ = self.source_space(afftype, 'ras', **backend)
            dst, _ = self.destination_space('ras', afftype, **backend)
            if src is not None and dst is not None:
                affine = affine_matmul(dst, affine_matmul(affine, src))

        affine = np.asarray(affine).reshape([-1, 4, 4])
        self._struct.affine = affine
        self._struct.nxform = affine.shape[0]
        return self
Example #7
 def slice_to(self, stack, cache_result=False, recompute=True):
     aff = self.exp(cache_result=cache_result, recompute=recompute)
     if recompute or not hasattr(self, '_sliced'):
         aff = spatial.affine_matmul(aff, self.affine)
         aff_reorient = spatial.affine_reorient(self.affine, self.shape, stack.layout)
         aff = spatial.affine_lmdiv(aff_reorient, aff)
         aff = spatial.affine_grid(aff, self.shape)
         sliced = spatial.grid_pull(self.dat, aff, bound=self.bound,
                                    extrapolate=self.extrapolate)
         fwhm = [0] * self.dim
         fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
         sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)
         slices = []
         for stack_slice in stack.slices:
             aff = spatial.affine_matmul(stack.affine, )
             aff = spatial.affine_lmdiv(aff_reorient, )
     else:
         sliced = self._sliced
     if cache_result:
         self._sliced = sliced
     return sliced
Example #8
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch shape.

        Other Parameters
        ----------------
        shape : sequence[int], optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        grid : (batch, *shape, 3) tensor
            Resampling grid

        """
        shape = overload.get('shape', self.grid.velocity.field.shape)
        dtype = overload.get('dtype', self.grid.velocity.field.dtype)
        device = overload.get('device', self.grid.velocity.field.device)
        backend = dict(dtype=dtype, device=device)

        if self.grid.velocity.field.amplitude == 0:
            grid = identity_grid(shape, **backend)
        else:
            grid = self.grid(batch, shape=shape, **backend)
        dtype = grid.dtype
        device = grid.device
        backend = dict(dtype=dtype, device=device)

        shape = grid.shape[1:-1]
        dim = len(shape)
        aff = self.affine(batch, dim=dim, **backend)

        # shift center of rotation
        aff_shift = torch.cat((
            torch.eye(dim, **backend),
            torch.as_tensor(shape, **backend)[:, None].sub_(1).div_(-2)),
            dim=1)
        aff_shift = as_euclidean(aff_shift)

        aff = affine_matmul(aff, aff_shift)
        aff = affine_lmdiv(aff_shift, aff)

        # compose
        aff = utils.unsqueeze(aff, dim=-3, ndim=dim)
        lin = aff[..., :dim, :dim]
        off = aff[..., :dim, -1]
        grid = linalg.matvec(lin, grid) + off

        return grid
Example #9
    def fdata(self, dtype=None, device=None, numpy=False):
        dtype = dtype or torch.get_default_dtype()
        backend = dict(dtype=dtype, device=device)
        affine = self.data(**backend)
        if affine is None:
            return None

        # we may need to convert from a weird space to RAS space
        afftype = self.type()[0]
        if afftype != 'ras':
            src, _ = self.source_space('ras', afftype, **backend)
            dst, _ = self.destination_space(afftype, 'ras', **backend)
            if src is not None and dst is not None:
                affine = affine_matmul(dst, affine_matmul(affine, src))

        affine = cast(affine, dtype)
        if numpy:
            return np.asarray(affine)
        else:
            return torch.as_tensor(affine, dtype=dtype, device=device)
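`fdata` above and `set_fdata` (Example #6) apply the same change-of-space sandwich, in opposite directions. Writing $S$ and $D$ for the matrices returned by `source_space` and `destination_space`, the two nested `affine_matmul` calls compute

$$A' \;=\; D \, A \, S,$$

converting the stored transform $A$ between the file's native space and RAS.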
Example #10
    def exp(self, velocity, affine=None, displacement=False):
        """Generate a deformation grid from tangent parameters.

        Parameters
        ----------
        velocity : (batch, *spatial, nb_dim)
            Stationary velocity field
        affine : (batch, nb_prm)
            Affine parameters
        displacement : bool, default=False
            Return a displacement field (voxel to shift) rather than
            a transformation field (voxel to voxel).

        Returns
        -------
        grid : (batch, *spatial, nb_dim)
            Deformation grid (transformation or displacement).

        """
        info = {'dtype': velocity.dtype, 'device': velocity.device}

        # generate grid
        shape = velocity.shape[1:-1]
        velocity_small = self.resize(velocity, type='displacement')
        grid = self.velexp(velocity_small)
        grid = self.resize(grid, shape=shape, type='grid')

        if affine is not None:
            # exponentiate
            affine_prm = affine
            affine = []
            for prm in affine_prm:
                affine.append(self.affexp(prm))
            affine = torch.stack(affine, dim=0)

            # shift center of rotation
            affine_shift = torch.cat(
                (torch.eye(self.dim, **info),
                 -torch.as_tensor(shape, **info)[:, None] / 2),
                dim=1)
            affine = spatial.affine_matmul(affine, affine_shift)
            affine = spatial.affine_lmdiv(affine_shift, affine)

            # compose
            affine = unsqueeze(affine, dim=-3, ndim=self.dim)
            lin = affine[..., :self.dim, :self.dim]
            off = affine[..., :self.dim, -1]
            grid = matvec(lin, grid) + off

        if displacement:
            grid = grid - spatial.identity_grid(grid.shape[1:-1], **info)

        return grid
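A hypothetical call to the `exp` method above, only to illustrate the shapes documented in its docstring (the names `model`, `vel` and `prm`, and the number of affine parameters, are assumptions):

import torch

vel = torch.zeros([1, 32, 32, 32, 3])    # (batch, *spatial, nb_dim) velocity
prm = torch.zeros([1, 12])               # (batch, nb_prm) affine parameters
grid = model.exp(vel, affine=prm)                      # voxel-to-voxel transform
disp = model.exp(vel, affine=prm, displacement=True)   # same grid minus the identity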
Example #11
def affine_to_fs(affine, shape, source='voxel', dest='ras'):
    """Convert an affine matrix into FS parameters (vx/cosine/shift)

    Parameters
    ----------
    affine : (4, 4) tensor
    shape : (int, int, int)
    source : {'voxel', 'physical', 'ras'}, default='voxel'
    dest : {'voxel', 'physical', 'ras'}, default='ras'

    Returns
    -------
    voxel_size : (float, float, float)
    x : (float, float, float)
    y : (float, float, float)
    z : (float, float, float)
    c : (float, float, float)

    """

    affine = torch.as_tensor(affine)
    backend = dict(dtype=affine.dtype, device=affine.device)
    vx = get_voxel_size(affine)
    shape = torch.as_tensor(shape, **backend)
    source = source.lower()[0]
    dest = dest.lower()[0]

    shift = shape / 2.
    shift = -shift * vx
    vox2phys = Orientation(shift, vx).affine()

    if (source, dest) in (('v', 'p'), ('p', 'v')):
        phys2ras = torch.eye(4, **backend)

    elif (source, dest) in (('v', 'r'), ('r', 'v')):
        if source == 'r':
            affine = affine_inv(affine)
        phys2vox = affine_inv(vox2phys)
        phys2ras = affine_matmul(affine, phys2vox)

    else:
        assert (source, dest) in (('p', 'r'), ('r', 'p'))
        if source == 'r':
            affine = affine_inv(affine)
        phys2ras = affine

    phys2ras = HomogeneousAffineMatrix(phys2ras)
    return (vx.tolist(), phys2ras.xras().tolist(), phys2ras.yras().tolist(),
            phys2ras.zras().tolist(), phys2ras.cras().tolist())
Example #12
def collapse_transforms(options):
    """Pre-invert affines and combine sequential affines"""
    trfs = []
    last_trf = None
    for trf in options.transformations:
        if isinstance(trf, Linear):
            if trf.inv:
                trf.affine = spatial.affine_inv(trf.affine)
                trf.inv = False
            if isinstance(last_trf, Linear):
                last_trf.affine = spatial.affine_matmul(
                    last_trf.affine, trf.affine)
            else:
                last_trf = trf
        else:
            if isinstance(last_trf, Linear):
                trfs.append(last_trf)
                last_trf = None
            trfs.append(trf)
    if isinstance(last_trf, Linear):
        trfs.append(last_trf)
    options.transformations = trfs
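For illustration, a hedged sketch of the collapsing behaviour with two hypothetical `Linear` transforms placed before a non-linear one (the `Linear` constructor, the matrices `A` and `B`, and the `options` container are assumptions):

a = Linear(affine=A, inv=True)    # pre-inverted: a.affine becomes affine_inv(A)
b = Linear(affine=B, inv=False)
options.transformations = [a, b, nonlin_trf]

collapse_transforms(options)
# options.transformations is now [Linear(affine_matmul(affine_inv(A), B)), nonlin_trf]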
Example #13
def fs_to_affine(shape,
                 voxel_size=1.,
                 x=None,
                 y=None,
                 z=None,
                 c=0.,
                 source='voxel',
                 dest='ras'):
    """Transform FreeSurfer orientation parameters into an affine matrix.

    The returned matrix is effectively a "<source> to <dest>" transform.

    Parameters
    ----------
    shape : sequence of int
    voxel_size : [sequence of] float, default=1
    x : [sequence of] float, default=[1, 0, 0]
    y : [sequence of] float, default=[0, 1, 0]
    z : [sequence of] float, default=[0, 0, 1]
    c : [sequence of] float, default=0
    source : {'voxel', 'physical', 'ras'}, default='voxel'
    dest : {'voxel', 'physical', 'ras'}, default='ras'

    Returns
    -------
    affine : (4, 4) tensor

    """
    dim = len(shape)
    # set defaults before moving everything to a common backend
    if x is None:
        x = [1, 0, 0]
    if y is None:
        y = [0, 1, 0]
    if z is None:
        z = [0, 0, 1]
    shape, voxel_size, x, y, z, c \
        = utils.to_max_backend(shape, voxel_size, x, y, z, c)
    backend = dict(dtype=shape.dtype, device=shape.device)
    voxel_size = utils.make_vector(voxel_size, dim)
    x = utils.make_vector(x, dim)
    y = utils.make_vector(y, dim)
    z = utils.make_vector(z, dim)
    c = utils.make_vector(c, dim)

    shift = shape / 2.
    shift = -shift * voxel_size
    vox2phys = Orientation(shift, voxel_size).affine()
    phys2ras = XYZC(x, y, z, c).affine()

    affines = []
    if source.lower().startswith('vox'):
        affines.append(vox2phys)
        middle_space = 'phys'
    elif source.lower().startswith('phys'):
        if dest.lower().startswith('vox'):
            affines.append(affine_inv(vox2phys))
            middle_space = 'vox'
        else:
            affines.append(phys2ras)
            middle_space = 'ras'
    elif source.lower() == 'ras':
        affines.append(affine_inv(phys2ras))
        middle_space = 'phys'
    else:
        # We need a matrix to switch orientations
        affines.append(layout_matrix(source, **backend))
        middle_space = 'ras'

    if dest.lower().startswith('phys'):
        if middle_space == 'vox':
            affines.append(vox2phys)
        elif middle_space == 'ras':
            affines.append(affine_inv(phys2ras))
    elif dest.lower().startswith('vox'):
        if middle_space == 'phys':
            affines.append(affine_inv(vox2phys))
        elif middle_space == 'ras':
            affines.append(affine_inv(phys2ras))
            affines.append(affine_inv(vox2phys))
    elif dest.lower().startswith('ras'):
        if middle_space == 'phys':
            affines.append(phys2ras)
        elif middle_space.lower().startswith('vox'):
            affines.append(vox2phys)
            affines.append(phys2ras)
    else:
        if middle_space == 'phys':
            affines.append(affine_inv(phys2ras))
        elif middle_space == 'vox':
            affines.append(vox2phys)
            affines.append(phys2ras)
        layout = layout_matrix(dest, **backend)
        affines.append(affine_inv(layout))

    affine, *affines = affines
    for aff in affines:
        affine = affine_matmul(aff, affine)
    return affine
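Together with `affine_to_fs` from Example #11, this gives a round trip between a voxel-to-RAS matrix and FreeSurfer's voxel-size/cosine/shift parameters; a sketch (assuming both helpers are importable from the same module):

import torch

aff = torch.eye(4)              # hypothetical voxel-to-RAS matrix
shape = (128, 128, 128)
vx, x, y, z, c = affine_to_fs(aff, shape, source='voxel', dest='ras')
aff2 = fs_to_affine(shape, vx, x, y, z, c, source='voxel', dest='ras')
# aff2 should reproduce aff up to floating-point error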
Example #14
    def forward(self, grid, **overload):
        """

        Parameters
        ----------
        grid : (N, *spatial, dim)
            Displacement grid
        overload : dict

        Returns
        -------
        aff : (N, dim+1, dim+1)
            Affine matrix that is closest to grid in the least square sense

        """
        shift = overload.get('shift', self.shift)
        grid = torch.as_tensor(grid)
        info = dict(dtype=grid.dtype, device=grid.device)
        nb_dim = grid.shape[-1]
        shape = grid.shape[1:-1]

        if shift:
            affine_shift = torch.cat((torch.eye(
                nb_dim, **info), -torch.as_tensor(shape, **info)[:, None] / 2),
                                     dim=1)
            affine_shift = spatial.as_euclidean(affine_shift)

        # the forward model is:
        #   phi(x) = M\A*M*x
        # where phi is a *transformation* field, M is the shift matrix
        # and A is the affine matrix.
        # We can decompose phi(x) = x + d(x), where d is a *displacement*
        # field, yielding:
        #   d(x) = M\A*M*x - x = (M\A*M - I)*x := B*x
        # If we write `d(x)` and `x` as large vox*(dim+1) matrices `D`
        # and `G`, we have:
        #   D = G*B'
        # Therefore, the least squares B is obtained as:
        #   B' = inv(G'*G) * (G'*D)
        # Then, A is
        #   A = M*(B + I)/M
        #
        # Finally, we project the affine matrix to its tangent space:
        #   prm[k] = <log(A), B[k]>
        #   where <X,Y> = trace(X'*Y) is the Frobenius inner product.

        def igg(identity):
            # Compute inv(g'*g), where g has homogeneous coordinates.
            #   Instead of appending ones, we compute each element of
            #   the block matrix ourselves:
            #       [[g'*g,   g'*1],
            #        [1'*g,   1'*1]]
            #    where 1'*1 = N, the number of voxels.
            g = identity.reshape([identity.shape[0], -1, nb_dim])
            nb_vox = torch.as_tensor([[[g.shape[1]]]], **info)
            sumg = g.sum(dim=1, keepdim=True)
            gg = torch.matmul(g.transpose(-1, -2), g)
            gg = torch.cat((gg, sumg), dim=1)
            sumg = sumg.transpose(-1, -2)
            sumg = torch.cat((sumg, nb_vox), dim=1)
            gg = torch.cat((gg, sumg), dim=2)
            return gg.inverse()

        def gd(identity, disp):
            # compute g'*d, where g and d have homogeneous coordinates.
            #       [[g'*d,   g'*1],
            #        [1'*d,   1'*1]]
            g = identity.reshape([identity.shape[0], -1, nb_dim])
            d = disp.reshape([disp.shape[0], -1, nb_dim])
            nb_vox = torch.as_tensor([[[g.shape[1]]]], **info)
            sumg = g.sum(dim=1, keepdim=True)
            sumd = d.sum(dim=1, keepdim=True)
            gd = torch.matmul(g.transpose(-1, -2), d)
            gd = torch.cat((gd, sumd), dim=1)
            sumg = sumg.transpose(-1, -2)
            sumg = torch.cat((sumg, nb_vox), dim=1)
            sumg = sumg.expand([d.shape[0], sumg.shape[1], sumg.shape[2]])
            gd = torch.cat((gd, sumg), dim=2)
            return gd

        def eye(d):
            x = torch.eye(d, **info)
            z = x.new_zeros([1, d], **info)
            x = torch.cat((x, z), dim=0)
            z = x.new_zeros([d + 1, 1], **info)
            x = torch.cat((x, z), dim=1)
            return x

        identity = spatial.identity_grid(shape, **info)[None, ...]
        affine = torch.matmul(igg(identity), gd(identity, grid))
        affine = affine.transpose(-1, -2) + eye(nb_dim)
        affine = affine[..., :-1, :]
        if shift:
            affine = spatial.as_euclidean(affine)
            affine = spatial.affine_matmul(affine_shift, affine)
            affine = spatial.as_euclidean(affine)
            affine = spatial.affine_rmdiv(affine, affine_shift)
        affine = spatial.affine_make_square(affine)

        return affine
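The long comment inside `forward` can be restated compactly. With $M$ the centre-shift matrix, $A$ the affine, and $G$, $D$ the identity coordinates and displacements written as large homogeneous matrices (one row per voxel), the model and its least-squares solution read

$$\phi(x) = M^{-1} A M x, \qquad d(x) = \big(M^{-1} A M - I\big)\,x =: B\,x,$$
$$D = G B^{\top} \;\Longrightarrow\; B^{\top} = (G^{\top} G)^{-1}\,(G^{\top} D), \qquad A = M\,(B + I)\,M^{-1},$$

followed by the projection onto the tangent space, $\mathrm{prm}_k = \langle \log A,\, B_k \rangle$ with $\langle X, Y\rangle = \operatorname{tr}(X^{\top} Y)$ and $B_k$ the Lie basis matrices (not the $B$ above).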
Example #15
    def forward():
        """Forward pass up to the loss"""

        loss = 0

        # affine matrix
        A = None
        for trf in options.transformations:
            trf.update()
            if isinstance(trf, struct.Linear):
                q = trf.optdat.to(**backend)
                # print(q.tolist())
                B = trf.basis.to(**backend)
                A = linalg.expm(q, B)
                if torch.is_tensor(trf.shift):
                    # include shift
                    shift = trf.shift.to(**backend)
                    eye = torch.eye(options.dim, **backend)
                    A = A.clone()  # needed because expm is a custom autograd.Function
                    A[:-1, -1] += torch.matmul(A[:-1, :-1] - eye, shift)
                for loss1 in trf.losses:
                    loss += loss1.call(q)
                break

        # non-linear displacement field
        d = None
        d_aff = None
        for trf in options.transformations:
            if not trf.isfree():
                continue
            if isinstance(trf, struct.FFD):
                d = trf.dat.to(**backend)
                d = ffd_exp(d, trf.shape, returns='disp')
                for loss1 in trf.losses:
                    loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break
            elif isinstance(trf, struct.Diffeo):
                d = trf.dat.to(**backend)
                if not trf.smalldef:
                    # penalty on velocity fields
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d = spatial.exp(d[None], displacement=True)[0]
                if trf.smalldef:
                    # penalty on exponentiated transform
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break

        # loop over image pairs
        for match in options.losses:
            if not match.fixed:
                continue
            nb_levels = len(match.fixed.dat)
            prm = dict(interpolation=match.moving.interpolation,
                       bound=match.moving.bound,
                       extrapolate=match.moving.extrapolate)
            # loop over pyramid levels
            for moving, fixed in zip(match.moving.dat, match.fixed.dat):
                moving_dat, moving_aff = moving
                fixed_dat, fixed_aff = fixed

                moving_dat = moving_dat.to(**backend)
                moving_aff = moving_aff.to(**backend)
                fixed_dat = fixed_dat.to(**backend)
                fixed_aff = fixed_aff.to(**backend)

                # affine-corrected moving space
                if A is not None:
                    Ms = affine_matmul(A, moving_aff)
                else:
                    Ms = moving_aff

                if d is not None:
                    # fixed to param
                    Mt = affine_lmdiv(d_aff, fixed_aff)
                    if samespace(Mt, d.shape[:-1], fixed_dat.shape[1:]):
                        g = smalldef(d)
                    else:
                        g = affine_grid(Mt, fixed_dat.shape[1:])
                        g = g + pull_grid(d, g)
                    # param to moving
                    Ms = affine_lmdiv(Ms, d_aff)
                    g = affine_matvec(Ms, g)
                else:
                    # fixed to moving
                    Mt = fixed_aff
                    Ms = affine_lmdiv(Ms, Mt)
                    g = affine_grid(Ms, fixed_dat.shape[1:])

                # pull moving image
                warped_dat = pull(moving_dat, g, **prm)
                loss += match.call(warped_dat, fixed_dat) / float(nb_levels)

                # import matplotlib.pyplot as plt
                # plt.subplot(1, 2, 1)
                # plt.imshow(fixed_dat[0, :, :, fixed_dat.shape[-1]//2].detach())
                # plt.axis('off')
                # plt.subplot(1, 2, 2)
                # plt.imshow(warped_dat[0, :, :, warped_dat.shape[-1]//2].detach())
                # plt.axis('off')
                # plt.show()

        return loss
Example #16
def _warp_image(option,
                affine=None,
                nonlin=None,
                dim=None,
                device=None,
                odir=None):
    """Warp and save the moving and fixed images from a loss object"""

    if not (option.mov.output or option.mov.resliced or option.fix.output
            or option.fix.resliced):
        return

    fix, fix_affine = _map_image(option.fix.files, dim=dim)
    mov, mov_affine = _map_image(option.mov.files, dim=dim)
    fix_affine = fix_affine.float()
    mov_affine = mov_affine.float()
    dim = dim or (fix.dim - 1)

    if option.fix.world:  # overwrite orientation matrix
        fix_affine = io.transforms.map(option.fix.world).fdata().squeeze()
    for transform in (option.fix.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        fix_affine = spatial.affine_lmdiv(transform, fix_affine)

    if option.mov.world:  # overwrite orientation matrix
        mov_affine = io.transforms.map(option.mov.world).fdata().squeeze()
    for transform in (option.mov.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        mov_affine = spatial.affine_lmdiv(transform, mov_affine)

    # moving
    if option.mov.output or option.mov.resliced:
        ifname = option.mov.files[0]
        idir, base, ext = py.fileparts(ifname)
        odir_mov = odir or idir or '.'

        image = objects.Image(mov.fdata(rand=True, device=device),
                              dim=dim,
                              affine=mov_affine,
                              bound=option.mov.bound,
                              extrapolate=option.mov.extrapolate)

        if option.mov.output:
            target_affine = mov_affine
            target_shape = image.shape
            if affine and affine.position[0].lower() in 'ms':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_lmdiv(aff, target_affine)

            fname = option.mov.output.format(dir=odir_mov,
                                             base=base,
                                             sep=os.path.sep,
                                             ext=ext)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

        if option.mov.resliced:
            target_affine = fix_affine
            target_shape = fix.shape[1:]

            fname = option.mov.resliced.format(dir=odir_mov,
                                               base=base,
                                               sep=os.path.sep,
                                               ext=ext)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin,
                                  reslice=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

    # fixed
    if option.fix.output or option.fix.resliced:
        ifname = option.fix.files[0]
        idir, base, ext = py.fileparts(ifname)
        odir_fix = odir or idir or '.'

        image = objects.Image(fix.fdata(rand=True, device=device),
                              dim=dim,
                              affine=fix_affine,
                              bound=option.fix.bound,
                              extrapolate=option.fix.extrapolate)

        if option.fix.output:
            target_affine = fix_affine
            target_shape = image.shape
            if affine and affine.position[0].lower() in 'fs':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_matmul(aff, target_affine)

            fname = option.fix.output.format(dir=odir_fix,
                                             base=base,
                                             sep=os.path.sep,
                                             ext=ext)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin,
                                  backward=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

        if option.fix.resliced:
            target_affine = mov_affine
            target_shape = mov.shape[1:]

            fname = option.fix.resliced.format(dir=odir_fix,
                                               base=base,
                                               sep=os.path.sep,
                                               ext=ext)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin,
                                  backward=True,
                                  reslice=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped
Example #17
def reslice(moving, fname, like, inv=False, lin=None, nonlin=None,
           interpolation=1, bound='dct2', extrapolate=False, device=None,
           verbose=True):
    """Apply the linear and non-linear components of the transform and
    reslice the image to the target space.

    Notes
    -----
    .. The shape and general orientation of the moving image are kept untouched.
    .. The linear transform is composed with the original orientation matrix.
    .. The non-linear component is "wrapped" in the input space, where it
       is applied.
    .. This function writes a new file (it does not modify the input files
       in place).

    Parameters
    ----------
    moving : ImageFile
        An object describing a moving image.
    fname : list of str
        Output filename for each input file of the moving image
        (since images can be encoded over multiple volumes)
    like : ImageFile
        An object describing the target space
    inv : bool, default=False
        True if we are warping the fixed image to the moving space.
        In that case, `moving` should be a `FixedImageFile` and `like` a
        `MovingImageFile`. Otherwise, `moving` should be a `MovingImageFile`
        and `like` a `FixedImageFile`.
    lin : (4, 4) tensor, optional
        Linear (or rather affine) transformation
    nonlin : dict, optional
        Non-linear displacement field, with keys:
        disp : (..., 3) tensor
            Displacement field (in voxels)
        affine : (4, 4) tensor
            Orientation matrix of the displacement field
    interpolation : int, default=1
    bound : str, default='dct2'
    extrapolate : bool, default = False
    device : default='cpu'
    verbose : bool, default=True

    """
    nonlin = nonlin or dict(disp=None, affine=None)
    prm = dict(interpolation=interpolation, bound=bound, extrapolate=extrapolate)

    moving_affine = moving.affine.to(device)
    fixed_affine = like.affine.to(device)

    if inv:
        # affine-corrected fixed space
        if lin is not None:
            fix2lin = affine_matmul(lin, fixed_affine)
        else:
            fix2lin = fixed_affine

        if nonlin['disp'] is not None:
            # fixed voxels to param voxels (warps param to fixed)
            fix2nlin = affine_lmdiv(nonlin['affine'].to(device), fix2lin)
            if samespace(fix2nlin, nonlin['disp'].shape[:-1], like.shape):
                g = smalldef(nonlin['disp'].to(device))
            else:
                g = affine_grid(fix2nlin, like.shape)
                g += pull_grid(nonlin['disp'].to(device), g)
            # param to moving
            nlin2mov = affine_lmdiv(moving_affine, nonlin['affine'].to(device))
            g = affine_matvec(nlin2mov, g)
        else:
            g = affine_lmdiv(moving_affine, fix2lin)
            g = affine_grid(g, like.shape)

    else:
        # affine-corrected moving space
        if lin is not None:
            mov2nlin = affine_matmul(lin, moving_affine)
        else:
            mov2nlin = moving_affine

        if nonlin['disp'] is not None:
            # fixed voxels to param voxels (warps param to fixed)
            fix2nlin = affine_lmdiv(nonlin['affine'].to(device), fixed_affine)
            if samespace(fix2nlin, nonlin['disp'].shape[:-1], like.shape):
                g = smalldef(nonlin['disp'].to(device))
            else:
                g = affine_grid(fix2nlin, like.shape)
                g += pull_grid(nonlin['disp'].to(device), g)
            # param voxels to moving voxels (warps moving to fixed)
            nlin2mov = affine_lmdiv(mov2nlin, nonlin['affine'].to(device))
            g = affine_matvec(nlin2mov, g)
        else:
            g = affine_lmdiv(mov2nlin, fixed_affine)
            g = affine_grid(g, like.shape)

    if moving.type == 'labels':
        prm['interpolation'] = 0
    for file, ofname in zip(moving.files, fname):
        if verbose:
            print(f'Resliced:   {file.fname}\n'
                  f'         -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, device=device)
        dat = dat.reshape([*file.shape, file.channels])
        if g is not None:
            dat = utils.movedim(dat, -1, 0)
            dat = pull(dat, g, **prm)
            dat = utils.movedim(dat, 0, -1)
        io.savef(dat.cpu(), ofname, like=file.fname, affine=like.affine.cpu())
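A hypothetical call, following the docstring above (the image objects `mov` and `fix` and the tensors `A`, `disp` and `disp_aff` are assumptions used only for illustration):

reslice(mov, ['warped.nii.gz'], like=fix,
        lin=A,                                    # (4, 4) affine transform
        nonlin=dict(disp=disp, affine=disp_aff),  # displacement field + its affine
        interpolation=1, bound='dct2', extrapolate=False, device='cpu')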
Example #18
    def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters

        """
        # This method performs the forward pass, and computes
        # derivatives along the way.

        # select correct gradient mode
        if grad:
            logaff.requires_grad_()
            if logaff.grad is not None:
                logaff.grad.zero_()
        if grad and not torch.is_grad_enabled():
            with torch.enable_grad():
                return self(logaff, grad, in_line_search=in_line_search)
        elif not grad and torch.is_grad_enabled():
            with torch.no_grad():
                return self(logaff, grad, in_line_search=in_line_search)

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                   and (self.n_iter - 1) % logplot == 0

        # jitter
        # idj = spatial.identity_grid(self.fixed.shape[-self.dim:], jitter=True,
        #                             **utils.backend(self.fixed))
        # fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        # del idj
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape, **utils.backend(logaff))
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        llx = self.loss.loss(warped, fixed)
        del warped

        # print objective
        lll = llx
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [lll]
        if grad is not False:
            lll.backward()
            grad = logaff.grad.clone()
            out.append(grad)
        logaff.requires_grad_(False)
        return tuple(out) if len(out) > 1 else out[0]
Example #19
    def forward(self, image, **overload):
        """

        Parameters
        ----------
        image : (batch, channel, *shape) tensor
            Input image
        overload : dict
            All parameters defined at build time can be overridden at call time

        Returns
        -------
        warped : (batch, channel, *shape) tensor
            Deformed image
        grid : (batch, *shape, 3) tensor
            Resampling grid

        """

        image = torch.as_tensor(image)
        dim = image.dim() - 2
        batch, channel, *shape = image.shape
        info = {'dtype': image.dtype, 'device': image.device}

        # get arguments
        opt_grid = {
            'dim': dim,
            'shape': shape,
            'amplitude': overload.get('vel_amplitude', self.grid.amplitude),
            'fwhm': overload.get('vel_fwhm', self.grid.fwhm),
            'bound': overload.get('vel_bound', self.grid.bound),
            'interpolation': overload.get('interpolation',
                                          self.grid.interpolation),
            'dtype': overload.get('dtype', self.grid.dtype),
            'device': overload.get('device', self.grid.device),
        }
        opt_affine = {
            'dim': dim,
            'translation': overload.get('translation',
                                        self.affine.translation),
            'rotation': overload.get('rotation', self.affine.rotation),
            'zoom': overload.get('zoom', self.affine.zoom),
            'shear': overload.get('shear', self.affine.shear),
            'dtype': overload.get('dtype', self.affine.dtype),
            'device': overload.get('device', self.affine.device),
        }
        opt_pull = {
            'bound':
            overload.get('image_bound', self.pull.bound),
            'interpolation':
            overload.get('interpolation', self.pull.interpolation),
        }

        grid = self.grid(batch, **opt_grid)
        aff = self.affine(batch, **opt_affine)

        # shift center of rotation
        aff_shift = torch.cat(
            (torch.eye(dim, **info),
             -torch.as_tensor(opt_grid['shape'], **info)[:, None] / 2),
            dim=1)
        aff = affine_matmul(aff, aff_shift)
        aff = affine_lmdiv(aff_shift, aff)

        # compose
        aff = unsqueeze(aff, dim=-3, ndim=dim)
        lin = aff[..., :dim, :dim]
        off = aff[..., :dim, -1]
        grid = matvec(lin, grid) + off

        # pull
        warped = self.pull(image, grid, **opt_pull)

        return warped, grid
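A hypothetical use of the module above as a random augmenter (the instance name `augmenter` and the overload values are assumptions; the keyword names match the keys read from `overload` in the code):

import torch

image = torch.randn(2, 1, 64, 64, 64)   # (batch, channel, *shape)
warped, grid = augmenter(image, vel_amplitude=3., rotation=10., translation=5.)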
Example #20
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This method performs the forward pass, and computes
        # derivatives along the way.

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
Example #21
    def compose(self,
                orient_in,
                deformation,
                orient_mean,
                affine=None,
                orient_out=None,
                shape_out=None):
        """Composes a deformation defined in a mean space to an image space.

        Parameters
        ----------
        orient_in : (4, 4) tensor
            Orientation of the input image
        deformation : (*shape_mean, 3) tensor
            Random deformation
        orient_mean : (4, 4) tensor
            Orientation of the mean space (where the deformation is)
        affine : (4, 4) tensor, default=identity
            Random affine
        orient_out : (4, 4) tensor, default=orient_in
            Orientation of the output image
        shape_out : sequence[int], default=shape_mean
            Shape of the output image

        Returns
        -------
        grid : (*shape_out, 3)
            Voxel-to-voxel transform

        """
        if orient_out is None:
            orient_out = orient_in
        if shape_out is None:
            shape_out = deformation.shape[:-1]
        if affine is None:
            affine = torch.eye(4,
                               4,
                               device=orient_in.device,
                               dtype=orient_in.dtype)
        shape_mean = deformation.shape[:-1]

        orient_in, affine, deformation, orient_mean, orient_out \
            = utils.to_max_backend(orient_in, affine, deformation, orient_mean, orient_out)
        backend = utils.backend(deformation)
        eye = torch.eye(4, **backend)

        # Compose deformation on the right
        right_affine = spatial.affine_lmdiv(orient_mean, orient_out)
        if not (shape_mean == shape_out and right_affine.allclose(eye)):
            # the mean space and native space are not the same
            # we must compose the diffeo with a dense affine transform
            # we write the diffeo as an identity plus a displacement
            # (id + disp)(aff) = aff + disp(aff)
            # -------
            # to displacement
            deformation = deformation - spatial.identity_grid(
                deformation.shape[:-1], **backend)
            trf = spatial.affine_grid(right_affine, shape_out)
            deformation = spatial.grid_pull(utils.movedim(deformation, -1,
                                                          0)[None],
                                            trf[None],
                                            bound='dft',
                                            extrapolate=True)
            deformation = utils.movedim(deformation[0], 0, -1)
            trf = trf + deformation  # add displacement
        else:
            # mean space and output space match: the deformation already is the grid
            trf = deformation

        # Compose deformation on the left
        #   the output of the diffeo(right) are mean_space voxels
        #   we must compose on the left with `in\(aff(mean))`
        # -------
        left_affine = spatial.affine_matmul(spatial.affine_inv(orient_in),
                                            affine)
        left_affine = spatial.affine_matmul(left_affine, orient_mean)
        trf = spatial.affine_matvec(left_affine, trf)

        return trf
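The right-composition step above relies on the identity spelled out in the in-code comment: writing the diffeomorphism as identity plus displacement, composition with a dense affine grid reduces to pulling the displacement through that grid and adding it,

$$(\mathrm{id} + d) \circ A \;=\; A + d \circ A,$$

which is what the `affine_grid` / `grid_pull` / addition sequence implements before the left affine $\text{orient\_in}^{-1}\cdot \text{affine}\cdot \text{orient\_mean}$ is applied with `affine_matvec`.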