Example #1
 def build_from_target(target):
     """Compose all transformations, starting from the final orientation"""
     grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
     for trf in reversed(options.transformations):
         if isinstance(trf, Linear):
             grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
         else:
             mat = trf.affine.to(**backend)
             if trf.inv:
                 vx0 = spatial.voxel_size(mat)
                 vx1 = spatial.voxel_size(target.affine.to(**backend))
                 factor = vx0 / vx1
                 disp, mat = spatial.resize_grid(trf.dat[None],
                                                 factor,
                                                 affine=mat,
                                                 interpolation=trf.spline)
                 disp = spatial.grid_inv(disp[0], type='disp')
                 order = 1
             else:
                 disp = trf.dat
                 order = trf.spline
             imat = spatial.affine_inv(mat)
             grid = spatial.affine_matvec(imat, grid)
             grid += helpers.pull_grid(disp, grid, interpolation=order)
             grid = spatial.affine_matvec(mat, grid)
     return grid
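
For reference, the core primitive this snippet leans on can be written in a few lines of plain PyTorch. The helper below is a hypothetical stand-in, not nitorch's implementation: it assumes `spatial.affine_matvec` broadcasts one homogeneous (D+1, D+1) matrix over a dense (..., D) coordinate grid.

import torch

def affine_matvec_dense(affine, grid):
    """Apply a homogeneous (D+1, D+1) affine to a (..., D) coordinate grid:
    rotate/scale every coordinate with the linear part, then translate."""
    lin = affine[:-1, :-1]    # (D, D) linear part
    trl = affine[:-1, -1]     # (D,) translation
    return grid @ lin.transpose(-1, -2) + trl

# identity grid of a small 3D volume, in voxel coordinates
shape = (4, 4, 4)
coords = torch.stack(torch.meshgrid(
    *[torch.arange(s, dtype=torch.float32) for s in shape],
    indexing='ij'), dim=-1)                       # (4, 4, 4, 3)
aff = torch.eye(4)
aff[:3, -1] = torch.tensor([1., 2., 3.])          # pure translation
moved = affine_matvec_dense(aff, coords)
assert torch.allclose(moved - coords, aff[:3, -1].expand_as(moved))
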
Example #2
 def pull(q, grid):
     aff = core.linalg.expm(q, basis)
     aff = spatial.affine_matmul(aff, target_aff)
     aff = spatial.affine_lmdiv(source_aff, aff)
     expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))
     grid = spatial.affine_matvec(aff[expd], grid)
     moved = spatial.grid_pull(source, grid, **pull_opt)
     return moved
Example #3
 def pull(q, vel):
     grid = spatial.exp(vel)
     aff = core.linalg.expm(q, basis)
     aff = spatial.affine_matmul(aff, target_aff)
     aff = spatial.affine_lmdiv(source_aff, aff)
     grid = spatial.affine_matvec(aff, grid)
     moved = spatial.grid_pull(source, grid, **pull_opt)
     return moved
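
The matrix algebra inside `pull` can be checked in isolation with stock PyTorch. A minimal sketch, assuming `core.linalg.expm` behaves like `torch.matrix_exp` applied to a weighted sum of basis matrices, and that `affine_lmdiv(A, B)` is the left-division `inv(A) @ B`:

import torch

# hypothetical z-rotation generator and a single Lie parameter
basis = torch.zeros(1, 4, 4)
basis[0, 0, 1], basis[0, 1, 0] = -1., 1.
q = torch.tensor([0.1])

aff = torch.matrix_exp((q[:, None, None] * basis).sum(0))
target_aff = torch.eye(4)
target_aff[:3, :3] *= 2.0          # 2 mm target voxels
source_aff = torch.eye(4)          # 1 mm source voxels

# voxel-to-voxel map: inv(source_aff) @ (exp(q) @ target_aff)
vox2vox = torch.linalg.solve(source_aff, aff @ target_aff)
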
Example #4
def _warp_image1(image,
                 target,
                 shape=None,
                 affine=None,
                 nonlin=None,
                 backward=False,
                 reslice=False):
    """Returns the warped image, with channel dimension last"""
    # build transform
    aff_right = target
    aff_left = spatial.affine_inv(image.affine)
    aff = None
    if affine:
        # exp = affine.iexp if backward else affine.exp
        exp = affine.exp
        aff = exp(recompute=False, cache_result=True)
        if backward:
            aff = spatial.affine_inv(aff)
    if nonlin:
        if affine:
            if affine.position[0].lower() in ('ms' if backward else 'fs'):
                aff_right = spatial.affine_matmul(aff, aff_right)
            if affine.position[0].lower() in ('fs' if backward else 'ms'):
                aff_left = spatial.affine_matmul(aff_left, aff)
        exp = nonlin.iexp if backward else nonlin.exp
        phi = exp(recompute=False, cache_result=True)
        aff_left = spatial.affine_matmul(aff_left, nonlin.affine)
        aff_right = spatial.affine_lmdiv(nonlin.affine, aff_right)
        if _almost_identity(aff_right) and nonlin.shape == shape:
            phi = nonlin.add_identity(phi)
        else:
            tmp = spatial.affine_grid(aff_right, shape)
            phi = regutils.smart_pull_grid(phi, tmp).add_(tmp)
            del tmp
        if not _almost_identity(aff_left):
            phi = spatial.affine_matvec(aff_left, phi)
    else:
        # no nonlin: single affine even if position == 'symmetric'
        if reslice:
            # `aff` may still be None when there is no affine component
            aff = aff_right if aff is None else spatial.affine_matmul(aff, aff_right)
            aff = spatial.affine_matmul(aff_left, aff)
            phi = spatial.affine_grid(aff, shape)
        else:
            phi = None

    # warp image
    if phi is not None:
        warped = image.pull(phi)
    else:
        warped = image.dat

    # write to disk
    if len(warped) == 1:
        warped = warped[0]
    else:
        warped = utils.movedim(warped, 0, -1)
    return warped
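
The `add_identity`/`.add_(tmp)` steps above turn a displacement field into an absolute sampling grid. A minimal plain-PyTorch sketch, assuming 3D fields stored as (..., X, Y, Z, 3):

import torch

def add_identity(disp):
    """Add voxel coordinates to a dense displacement field so it becomes
    an absolute sampling grid (zero displacement -> identity grid)."""
    shape = disp.shape[-4:-1]
    coords = torch.stack(torch.meshgrid(
        *[torch.arange(s, dtype=disp.dtype) for s in shape],
        indexing='ij'), dim=-1)
    return disp + coords

phi = add_identity(torch.zeros(8, 8, 8, 3))   # identity transform
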
Example #5
    def forward(self, source, target, source_affine=None, target_affine=None):
        """

        Parameters
        ----------
        source : (sX, sY, sZ) tensor or str
        target : (tX, tY, tZ) tensor or str
        source_affine : (4, 4) tensor, optional
        target_affine : (4, 4) tensor, optional

        Returns
        -------
        warped : (tX, tY, tZ) tensor
            Source warped to target
        velocity : (vX, vY, vZ, 3) tensor
            Stationary velocity field
        affine : (4, 4) tensor, optional
            Affine of the velocity space

        """
        if self.verbose:
            print('Preprocessing... ', end='', flush=True)
        source, source_affine, source_orig, source_affine_orig \
            = self.load(source, source_affine)
        target, target_affine, target_orig, target_affine_orig \
            = self.load(target, target_affine)
        source = spatial.reslice(source, source_affine, target_affine,
                                 target.shape)
        if self.verbose:
            print('done.', flush=True)
            print('Registering... ', end='', flush=True)
        source = source[None, None]
        target = target[None, None]
        warped, vel, grid = super().forward(source, target)
        if self.verbose:
            print('done.', flush=True)
        del source, target, warped
        vel = vel[0]
        grid = grid[0]
        grid -= spatial.identity_grid(grid.shape[:-1],
                                      dtype=grid.dtype,
                                      device=grid.device)
        right_affine = target_affine.inverse() @ target_affine_orig
        right_affine = spatial.affine_grid(right_affine, target_orig.shape)
        grid = spatial.grid_pull(utils.movedim(grid, -1, 0),
                                 right_affine,
                                 bound='nearest',
                                 extrapolate=True)
        grid = utils.movedim(grid, 0, -1).add_(right_affine)
        left_affine = source_affine_orig.inverse() @ target_affine
        grid = spatial.affine_matvec(left_affine, grid)
        warped = spatial.grid_pull(source_orig, grid)
        return warped, vel, target_affine
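
`spatial.grid_pull` does the final resampling here. Below is a reduced stand-in built on `torch.nn.functional.grid_sample`, assuming align_corners conventions and a single-channel 3D volume (nitorch's version supports many more boundary conditions and interpolation orders):

import torch
import torch.nn.functional as F

def grid_pull_like(image, grid):
    """Trilinear sampling of `image` (X, Y, Z) at voxel coordinates
    `grid` (*shape, 3). grid_sample wants normalized coordinates in
    [-1, 1], in reversed (x, y, z) order."""
    shape = torch.as_tensor(image.shape, dtype=grid.dtype)
    norm = 2 * grid / (shape - 1) - 1          # voxels -> [-1, 1]
    norm = norm.flip(-1)                       # (i, j, k) -> (x, y, z)
    out = F.grid_sample(image[None, None], norm[None],
                        mode='bilinear', align_corners=True)
    return out[0, 0]
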
Example #6
def voxelize_rois(rois, shape, roi_to_vox=None, device=None):
    """Create a volume of labels from a parametric ROI.

    Parameters
    ----------
    rois : dict
        Object returned by `read_asc`
    shape : sequence[int]
    roi_to_vox : (d+1, d+1) tensor

    Returns
    -------
    roi : (*shape) tensor[int]
    names : list[str]

    """
    out = torch.zeros(shape, dtype=torch.long)  # zero = background label
    grid = spatial.identity_grid(shape[:2], device=device)
    if roi_to_vox is not None:
        roi_to_vox = roi_to_vox.to(device=device)

    names = list(rois['regions'].keys())

    for l, (name, shapes) in enumerate(rois['regions'].items()):
        print(name)
        label = l + 1
        for i, contour in enumerate(shapes):
            print(i + 1, '/', len(shapes), end='\r')
            vertices = [[p['x'], p['y'], p['z']] for p in contour['points']]
            vertices = torch.as_tensor(vertices, device=device)
            if roi_to_vox is not None:
                vertices = spatial.affine_matvec(roi_to_vox, vertices)
            z = math.round(vertices[0, 2]).int().item()
            if not (0 <= z < out.shape[-1]):
                print('Contour not in FOV. Skipping it...')
                continue
            vertices = vertices[:, :2]
            faces = [(i, i + 1 if i + 1 < len(vertices) else 0)
                     for i in range(len(vertices))]

            mask = is_inside(grid, vertices, faces).cpu()
            out[..., z][mask] = label
        print('')

    return out, names
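
`is_inside` is the rasterization workhorse here. One plausible sketch is the classic even-odd (ray casting) rule: a point is inside the polygon if a horizontal ray crosses the edges an odd number of times. This is an illustration of the idea, not the helper's actual implementation:

import torch

def is_inside(grid, vertices, faces):
    """Even-odd point-in-polygon test (hypothetical implementation).
    `grid` is (*spatial, 2), `vertices` (N, 2), `faces` pairs of
    vertex indices forming the polygon's edges."""
    x, y = grid[..., 0], grid[..., 1]
    inside = torch.zeros(grid.shape[:-1], dtype=torch.bool)
    for i0, i1 in faces:
        (x0, y0), (x1, y1) = vertices[i0], vertices[i1]
        crosses = ((y0 > y) != (y1 > y)) & \
                  (x < (x1 - x0) * (y - y0) / (y1 - y0 + 1e-12) + x0)
        inside ^= crosses
    return inside
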
Example #7
def _crop_to_param(aff0, aff, shape):
    dim = aff0.shape[-1] - 1
    shape = shape[:dim]
    layout0 = spatial.affine_to_layout(aff0)
    layout = spatial.affine_to_layout(aff)
    if (layout0 != layout).any():
        raise ValueError('Input and Ref do not have the same layout: '
                         f'{spatial.volume_layout_to_name(layout0)} vs '
                         f'{spatial.volume_layout_to_name(layout)}.')
    size = shape
    layout = None
    unit = 'vox'

    center = torch.as_tensor(shape, dtype=torch.float).sub_(1).mul_(0.5)
    like_aff = spatial.affine_lmdiv(aff0, aff)
    center = spatial.affine_matvec(like_aff, center)

    return size, center, unit, layout
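
`affine_lmdiv(A, B)` used above is the matrix left-division `inv(A) @ B`, here the map taking Ref voxels into Input voxels. In plain PyTorch one would write it with a solve rather than an explicit inverse:

import torch

A = torch.eye(4)                            # aff0 (input orientation)
B = torch.eye(4)
B[:3, -1] = torch.tensor([5., 0., 0.])      # aff (ref orientation)
like_aff = torch.linalg.solve(A, B)         # == inv(A) @ B, more stable
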
Example #8
    def propagate_grad(self, g, h, moving, phi, left=None, right=None, inv=False):
        """Convert derivatives wrt warped image in loss space to
        to derivatives wrt parameters
        parameters:
            g (tensor) : gradient wrt warped image
            h (tensor) : hessian wrt warped image
            moving (Image) : moving image
            phi (tensor) : dense (exponentiated) displacement field
            left (matrix) : left affine
            right (matrix) : right affine
            inv (bool) : whether we're in a backward symmetric pass
        Returns:
            g (tensor) : pushed gradient
            h (tensor) : pushed hessian
            gmu (tensor) : rotated spatial gradients
        """
        if inv:
            g = g.neg_()

        # build bits of warp
        dim = phi.shape[-1]
        fixed_shape = g.shape[-dim:]
        moving_shape = moving.shape

        # differentiate wrt δ in: Left o Phi o (Id + δ) o Right
        # we'll then propagate them through Phi by scaling and squaring
        if right is not None:
            right = spatial.affine_grid(right, fixed_shape)
        g = regutils.smart_push(g, right, shape=self.shape)
        h = regutils.smart_push(h, right, shape=self.shape)
        del right

        phi_left = spatial.identity_grid(self.shape, **utils.backend(phi))
        phi_left += phi
        if left is not None:
            phi_left = spatial.affine_matvec(left, phi_left)
        mugrad = moving.pull_grad(phi_left, rotate=False)
        del phi_left

        mugrad = _rotate_grad(mugrad, left, phi)

        return g, h, mugrad
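
`smart_push` is the adjoint of pulling: instead of sampling values at the warped locations it splats them onto those locations. A much-reduced nearest-neighbour sketch (the real operator is trilinear and channel-aware):

import torch

def push_nearest(g, grid, shape):
    """Splat `g` (*spatial_in) onto a `shape`-sized lattice at the
    rounded voxel coordinates in `grid` (*spatial_in, 3)."""
    idx = grid.round().long().reshape(-1, 3)
    ok = ((idx >= 0) & (idx < torch.tensor(shape))).all(-1)
    out = torch.zeros(shape, dtype=g.dtype)
    out.index_put_((idx[ok, 0], idx[ok, 1], idx[ok, 2]),
                   g.reshape(-1)[ok], accumulate=True)
    return out
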
Example #9
def orient(inp,
           layout=None,
           voxel_size=None,
           center=None,
           like=None,
           output=None):
    """Overwrite the orientation matrix

    Parameters
    ----------
    inp : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    layout : str or layout-like, default=None (= preserve)
        Target orientation.
    voxel_size : [sequence of] float, default=None (= preserve)
        Target voxel size.
    center : [sequence of] float, default=None (= preserve)
        World coordinate of the center of the field of view.
    like : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    output : str, optional
        Output filename.
        If the input is not a path, the reoriented data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{layout}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tuple, tensor)
        If the input is a path, the output path is returned.
        Else, the new shape and orientation matrix are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        dim = f.affine.shape[-1] - 1
        inp = (f.shape[:dim], f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{layout}{ext}'
        dir, base, ext = py.fileparts(fname)

    like_is_file = isinstance(like, str) and like
    if like_is_file:
        f = io.volumes.map(like)
        dim = f.affine.shape[-1] - 1
        like = (f.shape[:dim], f.affine)

    shape, aff0 = inp
    dim = aff0.shape[-1] - 1
    if like:
        shape_like, aff_like = like
    else:
        shape_like, aff_like = (shape, aff0)

    if voxel_size in (None, 'like') or len(voxel_size) == 0:
        voxel_size = spatial.voxel_size(aff_like)
    elif voxel_size == 'self':
        voxel_size = spatial.voxel_size(aff0)
    voxel_size = utils.make_vector(voxel_size, dim)

    if not layout or layout == 'like':
        layout = spatial.affine_to_layout(aff_like)
    elif layout == 'self':
        layout = spatial.affine_to_layout(aff0)
    layout = spatial.volume_layout(layout)

    if center in (None, 'like') or len(center) == 0:
        # centre voxel of the field of view is (shape - 1) / 2
        center = (torch.as_tensor(shape_like, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff_like, center)
    elif center == 'self':
        center = (torch.as_tensor(shape, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff0, center)

    center = utils.make_vector(center, dim)

    aff = spatial.affine_default(shape,
                                 voxel_size=voxel_size,
                                 layout=layout,
                                 center=center,
                                 dtype=torch.double)

    if output:
        dat = io.volumes.load(fname, numpy=True)
        layout = spatial.volume_layout_to_name(layout)
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep,
                                   layout=layout)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep, layout=layout)
            io.volumes.save(dat, output, affine=aff)

    if is_file:
        return output
    else:
        return shape, aff
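
`spatial.affine_default` is asked to build a fresh orientation matrix from a layout, voxel size and world centre. A minimal axis-aligned sketch of that construction, ignoring layout permutations:

import torch

def make_ras_affine(shape, voxel_size, center):
    """Axis-aligned RAS affine with the given voxel size, whose field of
    view is centred on the world coordinate `center` (hypothetical
    stand-in for the simplest case of `spatial.affine_default`)."""
    shape = torch.as_tensor(shape, dtype=torch.double)
    voxel_size = torch.as_tensor(voxel_size, dtype=torch.double)
    center = torch.as_tensor(center, dtype=torch.double)
    aff = torch.eye(4, dtype=torch.double)
    aff[:3, :3] = torch.diag(voxel_size)
    # place the FOV centre (voxel (shape - 1) / 2) at `center`
    aff[:3, -1] = center - voxel_size * (shape - 1) / 2
    return aff

aff = make_ras_affine([128, 128, 96], [1., 1., 1.2], [0., 0., 0.])
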
Example #10
def orient(inp, affine=None, layout=None, voxel_size=None, center=None,
           like=None, output=None, output_transform=None):
    """Overwrite the orientation matrix

    Parameters
    ----------
    inp : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    affine : {'self', 'like'} or (4, 4) tensor_like, default='like'
        Target affine matrix
    layout : {'self', 'like'} or layout-like, default='like'
        Target orientation.
    voxel_size : {'self', 'like'} or [sequence of] float, default='like'
        Target voxel size.
    center : {'self', 'like'} or [sequence of] float, default='like'
        World coordinate of the center of the field of view.
    like : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    output : str, optional
        Output filename.
        If the input is not a path, the reoriented data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{layout}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    output_transform : str, optional
        Filename of output transform.
        If the input is not a path, the reoriented data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}_to_{layout}.lta', where `dir` and `base`
        are the directory and base name of the input file.

    Returns
    -------
    output : str or (tuple, tensor)
        If the input is a path, the output path is returned.
        Else, the new shape and orientation matrix are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        dim = f.affine.shape[-1] - 1
        inp = (f.shape[:dim], f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{layout}{ext}'
        if output_transform is None:
            output_transform = '{dir}{sep}{base}_to_{layout}.lta'
        dir, base, ext = py.fileparts(fname)

    like_is_file = isinstance(like, str) and like
    if like_is_file:
        f = io.volumes.map(like)
        dim = f.affine.shape[-1] - 1
        like = (f.shape[:dim], f.affine)

    shape, aff0 = inp
    dim = aff0.shape[-1] - 1
    if like:
        shape_like, aff_like = like
    else:
        shape_like, aff_like = (shape, aff0)

    if voxel_size in (None, 'like') or len(voxel_size) == 0:
        voxel_size = spatial.voxel_size(aff_like)
    elif voxel_size == 'self':
        voxel_size = spatial.voxel_size(aff0)
    elif voxel_size == 'standard':
        voxel_size = 1.
    voxel_size = utils.make_vector(voxel_size, dim)

    if not layout or layout == 'like':
        layout = spatial.affine_to_layout(aff_like)
    elif layout == 'self':
        layout = spatial.affine_to_layout(aff0)
    elif layout == 'standard':
        layout = 'RAS'
    layout = spatial.volume_layout(layout)

    if center in (None, 'like') or len(center) == 0:
        center = (torch.as_tensor(shape_like, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff_like, center)
    elif center == 'self':
        center = (torch.as_tensor(shape, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff0, center)
    elif center == 'standard':
        center = 0.
    center = utils.make_vector(center, dim)

    if affine in (None, 'like') or len(affine) == 0:
        affine = aff_like
    elif affine == 'self':
        affine = aff0
    elif affine == 'standard':
        affine = torch.eye(dim+1, dim+1)
    affine = torch.as_tensor(affine, dtype=torch.float)
    if affine.numel() == dim*(dim+1):
        affine = spatial.affine_make_rect(affine.reshape(dim, dim+1))
    elif affine.numel() == (dim+1)**2:
        affine = affine.reshape(dim+1, dim+1)
    else:
        raise ValueError(f'Input affine should have {dim*(dim+1)} or '
                         f'{(dim+1)**2} elements but got {affine.numel()}.')

    affine = spatial.affine_modify(affine, shape, voxel_size=voxel_size,
                                   layout=layout, center=center)
    affine = affine.double()

    if output:
        dat = io.volumes.load(fname, numpy=True)
        layout = spatial.volume_layout_to_name(layout)
        if is_file:
            output = output.format(dir=dir or '.', base=base, ext=ext,
                                   sep=os.path.sep, layout=layout)
            io.volumes.save(dat, output, like=fname, affine=affine)
        else:
            output = output.format(sep=os.path.sep, layout=layout)
            io.volumes.save(dat, output, affine=affine)

    if output_transform:
        transform = spatial.affine_rmdiv(affine, aff0)
        output_transform = output_transform.format(
            dir=dir or '.', base=base, sep=os.path.sep, layout=layout)
        io.transforms.savef(transform.cpu(), output_transform, type=2)

    if is_file:
        return output
    else:
        return shape, affine
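
The exported `.lta` holds a world-to-world map. Assuming `affine_rmdiv(A, B)` is the right-division `A @ inv(B)` (consistent with the naming), the saved transform re-expresses each voxel's position under the new orientation matrix:

import torch

aff0 = torch.eye(4)
aff0[:3, -1] = torch.tensor([1., 2., 3.])   # old orientation
affine = torch.eye(4)                       # new orientation
transform = affine @ torch.linalg.inv(aff0)
assert torch.allclose(transform @ aff0, affine)
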
Example #11
def write_data(options):

    backend = dict(dtype=torch.float32, device=options.device)

    # Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            trf.dat = spatial.exp(trf.dat[None],
                                  displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            trf.spline = 1  # exponentiated field is resampled linearly
        elif isinstance(trf, Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=mat,
                                                    interpolation=trf.spline)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.spline
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    if options.target:
        # If target is provided, we build a dense transformation field
        grid = build_from_target(options.target)
        oaffine = options.target.affine
        if options.output_unit[0] == 'v':
            grid = spatial.affine_matvec(spatial.affine_inv(oaffine), grid)
            grid = grid - spatial.identity_grid(grid.shape[:-1],
                                                **utils.backend(grid))
        else:
            grid = grid - spatial.affine_grid(
                oaffine.to(**utils.backend(grid)), grid.shape[:-1])
        io.volumes.savef(grid,
                         options.output.format(ext='.nii.gz'),
                         affine=oaffine)
    else:
        if len(options.transformations) > 1:
            raise RuntimeError('Something weird happened: '
                               'multiple transforms and no target')
        io.transforms.savef(options.transformations[0].affine,
                            options.output.format(ext='.lta'))
Example #12
def write_data(options):

    backend = dict(dtype=torch.float32, device=options.device)

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            trf.dat = spatial.exp(trf.dat[None],
                                  displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    if (options.transformations
            and isinstance(options.transformations[0], struct.Linear)):
        trf = options.transformations[0]
        for file in options.files:
            mat = file.affine.to(**backend)
            aff = trf.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(aff, mat)
        options.transformations = options.transformations[1:]

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=mat,
                                                    interpolation=trf.order)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.order
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If target is provided, we can build most of the transform once
    #    and just multiply it with a input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt = dict(interpolation=options.interpolation,
               bound=options.bound,
               extrapolate=options.extrapolate)
    output = utils.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing:   {file.fname}\n' f'          -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, **backend)
        dat = dat.reshape([*file.shape, file.channels])
        dat = utils.movedim(dat, -1, 0)

        if not options.target:
            grid = build_from_target(file)
            oaffine = file.affine
        mat = file.affine.to(**backend)
        imat = spatial.affine_inv(mat)
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt)
        dat = utils.movedim(dat, 0, -1)

        io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)
Example #13
def get_oriented_slice(image, dim=-1, index=None, affine=None,
                       space=None, bbox=None, interpolation=1,
                       transpose_sagittal=False, return_index=False,
                       return_mat=False):
    """Sample a slice in a RAS system

    Parameters
    ----------
    image : (..., *shape3)
    dim : int, default=-1
        Index of spatial dimension to sample in the visualization space
        If RAS: -1 = axial / -2 = coronal / -3 = sagittal
    index : int, default=shape//2
        Coordinate (in voxel) of the slice to extract
    affine : (4, 4) tensor, optional
        Orientation matrix of the image
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        Transpose the extracted sagittal plane (display convention).
    return_index : bool, default=False
        Also return the cursor index within the slice.
    return_mat : bool, default=False
        Also return the orientation matrix of the slice.

    Returns
    -------
    slice : (..., *shape2) tensor
        Slice in the visualisation space.

    """
    # preproc dim
    if isinstance(dim, str):
        dim = dim.lower()[0]
        if dim == 'a':
            dim = -1
        if dim == 'c':
            dim = -2
        if dim == 's':
            dim = -3

    backend = utils.backend(image)

    # compute default space (mn/mx are in voxels)
    affine, shape = _get_default_native(affine, image.shape[-3:])
    space, mn, mx = _get_default_space(affine, [shape], space, bbox)
    affine, shape = (affine[0], shape[0])

    # compute default cursor (in voxels)
    if index is None:
        index = (mx + mn) / 2
    else:
        index = torch.as_tensor(index)
        index = spatial.affine_matvec(spatial.affine_inv(space), index)

    # include slice to volume matrix
    shape = tuple(((mx-mn) + 1).round().int().tolist())
    if dim == -1:
        # axial
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [0, 0, 1, - index[2]],
                 [0, 0, 0, 1]]
        shift = utils.as_tensor(shift, **backend)
        shape = shape[:-1]
        index = (index[0] - mn[0] + 1, index[1] - mn[1] + 1)
    elif dim == -2:
        # coronal
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - index[1]],
                 [0, 0, 0, 1]]
        shift = utils.as_tensor(shift, **backend)
        shape = (shape[0], shape[2])
        index = (index[0] - mn[0] + 1, index[2] - mn[2] + 1)
    elif dim == -3:
        # sagittal
        if not transpose_sagittal:
            shift = [[0, 0, 1, - mn[2] + 1],
                     [0, 1, 0, - mn[1] + 1],
                     [1, 0, 0, - index[0]],
                     [0, 0, 0, 1]]
            shift = utils.as_tensor(shift, **backend)
            shape = (shape[2], shape[1])
            index = (index[2] - mn[2] + 1, index[1] - mn[1] + 1)
        else:
            shift = [[0, -1, 0,   mx[1] + 1],
                     [0,  0, 1, - mn[2] + 1],
                     [1,  0, 0, - index[0]],
                     [0,  0, 0, 1]]
            shift = utils.as_tensor(shift, **backend)
            shape = (shape[1], shape[2])
            index = (mx[1] + 1 - index[1], index[2] - mn[2] + 1)
    else:
        raise ValueError(f'Unknown dimension {dim}')

    # sample
    space = spatial.affine_rmdiv(space, shift)
    affine = spatial.affine_lmdiv(affine, space)
    affine = affine.to(**backend)
    grid = spatial.affine_grid(affine, [*shape, 1])
    *channel, s0, s1, s2 = image.shape
    imshape = (s0, s1, s2)
    image = image.reshape([1, -1, *imshape])
    image = spatial.grid_pull(image, grid[None], interpolation=interpolation,
                              bound='dct2', extrapolate=False)
    image = image.reshape([*channel, *shape])
    return ((image, index, space) if return_index and return_mat else
            (image, index) if return_index else
            (image, space) if return_mat else image)
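
The `image.reshape([1, -1, *imshape])` line folds every leading dimension into a single channel axis, so one batched sampling call covers them all; afterwards the leading dimensions are restored. The same trick in isolation:

import torch

image = torch.randn(2, 5, 16, 16, 16)        # e.g. (time, coil, X, Y, Z)
*channel, s0, s1, s2 = image.shape
flat = image.reshape(1, -1, s0, s1, s2)       # (1, 10, 16, 16, 16)
# ... one grid_pull / grid_sample call on `flat` goes here ...
restored = flat.reshape(*channel, s0, s1, s2)
assert torch.equal(image, restored)
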
Example #14
    def __call__(self,
                 logaff,
                 grad=False,
                 hess=False,
                 gradmov=False,
                 hessmov=False,
                 in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`
        gradmov : Whether to compute and return the gradient wrt `moving`
        hessmov : Whether to compute and return the Hessian wrt `moving`

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., nb) tensor, optional, Gradient wrt Lie parameters
        h : (..., nb, nb) tensor, optional, Hessian wrt Lie parameters
        gm : (..., *spatial, dim) tensor, optional, Gradient wrt moving
        hm : (..., *spatial, ?) tensor, optional, Hessian wrt moving

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                  and (self.n_iter - 1) % logplot == 0

        # jitter
        # if not hasattr(self, '_fixed'):
        #     idj = spatial.identity_grid(self.fixed.shape[-self.dim:],
        #                                 jitter=True,
        #                                 **utils.backend(self.fixed))
        #     self._fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        #     del idj
        # fixed = self._fixed
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        with torch.no_grad():
            _, gaff = linalg._expm(logaff,
                                   self.basis,
                                   grad_X=True,
                                   hess_X=False)

        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        # /!\ derivatives are not "homogeneous" (they do not have a one
        # on the bottom right): we should *not* use affine_matmul and
        # such (I only lost a day...)
        gaff = torch.matmul(gaff, self.affine_fixed)
        gaff = linalg.lmdiv(self.affine_moving, gaff)
        # haff = torch.matmul(haff, self.affine_fixed)
        # haff = linalg.lmdiv(self.affine_moving, haff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape,
                                            **utils.backend(logaff),
                                            jitter=False)
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        if not grad and not hess:
            llx = self.loss.loss(warped, fixed)
        elif not hess:
            llx, grad = self.loss.loss_grad(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
        else:
            llx, grad, hess = self.loss.loss_grad_hess(warped, fixed)
            if gradmov:
                gradmov = spatial.grid_push(grad, grid, **pullopt)
            if hessmov:
                hessmov = spatial.grid_push(hess, grid, **pullopt)
        del warped

        # compose with spatial gradients + dot product with grid
        if grad is not False or hess is not False:
            mugrad = spatial.grid_grad(self.moving, grid, **pullopt)
            grad = jg(mugrad, grad)
            if hess is not False:
                hess = jhj(mugrad, hess)
                grad, hess = regutils.affine_grid_backward(grad,
                                                           hess,
                                                           grid=self.id)
            else:
                grad = regutils.affine_grid_backward(grad)  # , grid=self.id)
            dim2 = self.dim * (self.dim + 1)
            grad = grad.reshape([*grad.shape[:-2], dim2])
            gaff = gaff[..., :-1, :]
            gaff = gaff.reshape([*gaff.shape[:-2], dim2])
            grad = linalg.matvec(gaff, grad)
            if hess is not False:
                hess = hess.reshape([*hess.shape[:-4], dim2, dim2])
                # haff = haff[..., :-1, :, :-1, :]
                # haff = haff.reshape([*gaff.shape[:-4], dim2, dim2])
                hess = gaff.matmul(hess).matmul(gaff.transpose(-1, -2))
                hess = hess.abs().sum(-1).diag_embed()
            del mugrad

        # print objective
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [ll]
        if grad is not False:
            out.append(grad)
        if hess is not False:
            out.append(hess)
        if gradmov is not False:
            out.append(gradmov)
        if hessmov is not False:
            out.append(hessmov)
        return tuple(out) if len(out) > 1 else out[0]
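
`regutils.affine_grid_backward` converts a gradient with respect to the dense grid into a gradient with respect to the affine matrix. Since `grid = x @ A_lin^T + t` at every voxel `x`, the result is a sum of outer products; a sketch of the gradient half, under assumed (D, D+1) conventions:

import torch

def affine_grid_backward_sketch(grad, grid):
    """Backpropagate a (*spatial, D) grid gradient to a (D, D+1) affine
    gradient: sum over voxels of grad_i (x_i, 1)^T (hypothetical
    simplification; the real helper also handles Hessians/batches)."""
    D = grad.shape[-1]
    x = grid.reshape(-1, D)
    g = grad.reshape(-1, D)
    xh = torch.cat([x, torch.ones_like(x[:, :1])], dim=-1)   # homogeneous
    return g.t() @ xh                                        # (D, D+1)
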
Example #15
    def do_affine(self, logaff, grad=False, hess=False, in_line_search=False):
        """Forward pass for updating the affine component (nonlin is not None)"""

        sumloss = None
        sumgrad = None
        sumhess = None

        # ==============================================================
        #                     EXPONENTIATE TRANSFORMS
        # ==============================================================
        logaff0 = logaff
        aff_pos = self.affine.position[0].lower()
        if any(loss.backward for loss in self.losses):
            aff0, iaff0, gaff0, igaff0 = \
                self.affine.exp2(logaff0, grad=True,
                                 cache_result=not in_line_search)
            phi0, iphi0 = self.nonlin.exp2(cache_result=True, recompute=False)
        else:
            iaff0, igaff0, iphi0 = None, None, None
            aff0, gaff0 = self.affine.exp(logaff0, grad=True,
                                          cache_result=not in_line_search)
            phi0 = self.nonlin.exp(cache_result=True, recompute=False)

        has_printed = False
        for loss in self.losses:

            moving, fixed, factor = loss.moving, loss.fixed, loss.factor
            if loss.backward:
                phi00, aff00, gaff00 = iphi0, iaff0, igaff0
            else:
                phi00, aff00, gaff00 = phi0, aff0, gaff0

            # ----------------------------------------------------------
            # build left and right affine matrices
            # ----------------------------------------------------------
            aff_right, gaff_right = fixed.affine, None
            if aff_pos in 'fs':
                gaff_right = gaff00 @ aff_right
                gaff_right = linalg.lmdiv(self.nonlin.affine, gaff_right)
                aff_right = aff00 @ aff_right
            aff_right = linalg.lmdiv(self.nonlin.affine, aff_right)
            aff_left, gaff_left = self.nonlin.affine, None
            if aff_pos in 'ms':
                gaff_left = gaff00 @ aff_left
                gaff_left = linalg.lmdiv(moving.affine, gaff_left)
                aff_left = aff00 @ aff_left
            aff_left = linalg.lmdiv(moving.affine, aff_left)

            # ----------------------------------------------------------
            # build full transform
            # ----------------------------------------------------------
            if _almost_identity(aff_right) and fixed.shape == self.nonlin.shape:
                right = None
                phi = spatial.add_identity_grid(phi00)
            else:
                right = spatial.affine_grid(aff_right, fixed.shape)
                phi = regutils.smart_pull_grid(phi00, right)
                phi += right
            phi_right = phi
            if _almost_identity(aff_left) and moving.shape == self.nonlin.shape:
                left = None
            else:
                left = spatial.affine_grid(aff_left, self.nonlin.shape)
                phi = spatial.affine_matvec(aff_left, phi)

            # ----------------------------------------------------------
            # forward pass
            # ----------------------------------------------------------
            warped, mask = moving.pull(phi, mask=True)
            if fixed.masked:
                if mask is None:
                    mask = fixed.mask
                else:
                    mask = mask * fixed.mask

            do_print = not (has_printed or self.verbose < 3 or in_line_search
                            or loss.backward)
            if do_print:
                has_printed = True
                if moving.previewed:
                    preview = moving.pull(phi, preview=True, dat=False)
                else:
                    preview = warped
                init = spatial.affine_lmdiv(moving.affine, fixed.affine)
                if _almost_identity(init) and moving.shape == fixed.shape:
                    init = moving.dat
                else:
                    init = spatial.affine_grid(init, fixed.shape)
                    init = moving.pull(init, preview=True, dat=False)
                self.mov2fix(fixed.dat, init, preview, dim=fixed.dim,
                             title=f'(affine) {self.n_iter:03d}')

            # ----------------------------------------------------------
            # derivatives wrt moving
            # ----------------------------------------------------------
            g = h = None
            loss_args = (warped, fixed.dat)
            loss_kwargs = dict(dim=fixed.dim, mask=mask)
            state = loss.loss.get_state()
            if not grad and not hess:
                llx = loss.loss.loss(*loss_args, **loss_kwargs)
            elif not hess:
                llx, g = loss.loss.loss_grad(*loss_args, **loss_kwargs)
            else:
                llx, g, h = loss.loss.loss_grad_hess(*loss_args, **loss_kwargs)
            del loss_args, loss_kwargs
            if in_line_search:
                loss.loss.set_state(state)

            # ----------------------------------------------------------
            # chain rule -> derivatives wrt Lie parameters
            # ----------------------------------------------------------

            def compose_grad(g, h, g_mu, g_aff):
                """
                g, h : gradient/Hessian of loss wrt moving image
                g_mu : spatial gradients of moving image
                g_aff : gradient of affine matrix wrt Lie parameters
                returns g, h: gradient/Hessian of loss wrt Lie parameters
                """
                # Note that `h` can be `None`, but the functions I
                # use deal with this case correctly.
                dim = g_mu.shape[-1]
                g = jg(g_mu, g)
                h = jhj(g_mu, h)
                g, h = regutils.affine_grid_backward(g, h)
                dim2 = dim * (dim + 1)
                g = g.reshape([*g.shape[:-2], dim2])
                g_aff = g_aff[..., :-1, :]
                g_aff = g_aff.reshape([*g_aff.shape[:-2], dim2])
                g = linalg.matvec(g_aff, g)
                if h is not None:
                    h = h.reshape([*h.shape[:-4], dim2, dim2])
                    h = g_aff.matmul(h).matmul(g_aff.transpose(-1, -2))
                    # h = h.abs().sum(-1).diag_embed()
                return g, h

            if grad or hess:
                g0, g = g, None
                h0, h = h, None
                if aff_pos in 'ms':
                    g_left = regutils.smart_push(g0, phi_right, shape=self.nonlin.shape)
                    h_left = regutils.smart_push(h0, phi_right, shape=self.nonlin.shape)
                    mugrad = moving.pull_grad(left, rotate=False)
                    g_left, h_left = compose_grad(g_left, h_left, mugrad, gaff_left)
                    g, h = g_left, h_left
                if aff_pos in 'fs':
                    g_right, h_right = g0, h0
                    mugrad = moving.pull_grad(phi, rotate=False)
                    jac = spatial.grid_jacobian(phi0, right, type='disp', extrapolate=False)
                    jac = torch.matmul(aff_left[:-1, :-1], jac)
                    mugrad = linalg.matvec(jac.transpose(-1, -2), mugrad)
                    g_right, h_right = compose_grad(g_right, h_right, mugrad, gaff_right)
                    g = g_right if g is None else g.add_(g_right)
                    h = h_right if h is None else h.add_(h_right)

                if loss.backward:
                    g = g.neg_()
                sumgrad = (g.mul_(factor) if sumgrad is None else
                           sumgrad.add_(g, alpha=factor))
                if hess:
                    sumhess = (h.mul_(factor) if sumhess is None else
                               sumhess.add_(h, alpha=factor))
            sumloss = (llx.mul_(factor) if sumloss is None else
                       sumloss.add_(llx, alpha=factor))

        # TODO add regularization term
        lla = 0

        # ==============================================================
        #                           VERBOSITY
        # ==============================================================
        llx = sumloss.item()
        sumloss += lla
        sumloss += self.llv
        self.loss_value = sumloss.item()
        if self.verbose and (self.verbose > 1 or not in_line_search):
            ll = sumloss.item()
            llv = self.llv
            if in_line_search:
                line = '(search) | '
            else:
                line = '(affine) | '
            line += f'{self.n_iter:03d} | {llx:12.6g} + {llv:12.6g} + {lla:12.6g} = {ll:12.6g}'
            if not in_line_search:
                if self.ll_prev is not None:
                    gain = self.ll_prev - ll
                    # gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                    line += f' | {gain:12.6g}'
                self.all_ll.append(ll)
                self.ll_prev = ll
                self.ll_max = max(self.ll_max, ll)
                self.n_iter += 1
            print(line, end='\r')

        # ==============================================================
        #                           RETURN
        # ==============================================================
        out = [sumloss]
        if grad:
            out.append(sumgrad)
        if hess:
            out.append(sumhess)
        return tuple(out) if len(out) > 1 else out[0]
Example #16
def ffd(source,
        target,
        grid_shape=10,
        group='SE',
        image_loss=None,
        def_loss=None,
        pull=None,
        preproc=True,
        max_iter=1000,
        device=None,
        origin='center',
        init=None,
        lr=1e-4,
        optim_affine=True,
        scheduler=ReduceLROnPlateau):
    """FFD (= cubic spline) registration
    
    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across 
       the channel dimension. The loss function is then responsible 
       for unstacking the tensor and computing the appropriate 
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as 
       well as the same interpolation order. The advantage is that 
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
    target : tensor or (tensor, affine)
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
    image_loss : Loss, default=MutualInfoLoss()
    def_loss : Loss, default=BendingLoss()
    pull : dict
        interpolation : int, default=1
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
    preproc : bool, default=True
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
    init : tensor_like, default=0
    lr : [sequence of] float, default=1e-4
    optim_affine : bool, default=True
    scheduler : Scheduler, default=ReduceLROnPlateau

    Returns
    -------
    q : tensor
        Parameters
    aff : (D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    grid : (batch, *grid_shape, D) tensor
        FFD control-point displacements.
    moved : tensor
        Source image moved to target space.


    """
    pull = pull or dict()
    pull['interpolation'] = pull.get('interpolation', 'linear')
    pull['bound'] = pull.get('bound', 'dft')
    pull['extrapolate'] = pull.get('extrapolate', False)
    pull_opt = pull

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    src_channels = source.shape[1]
    trg_channels = target.shape[1]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        shift = torch.as_tensor(target.shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        affine_parameters = torch.as_tensor(init, **backend).clone().detach()
        affine_parameters = affine_parameters.reshape([batch, nb_prm])
    else:
        affine_parameters = torch.zeros([batch, nb_prm], **backend)
    affine_parameters = nn.Parameter(affine_parameters,
                                     requires_grad=optim_affine)
    grid_shape = core.pyutils.make_list(grid_shape, dim)
    grid_parameters = torch.zeros([batch, *grid_shape, dim], **backend)
    grid_parameters = nn.Parameter(grid_parameters, requires_grad=True)

    def pull(q, grid):
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))
        grid = spatial.affine_matvec(aff[expd], grid)
        moved = spatial.grid_pull(source, grid, **pull_opt)
        return moved

    def exp(prm):
        disp = spatial.resize_grid(prm,
                                   type='displacement',
                                   shape=target.shape[2:],
                                   interpolation=3,
                                   bound='dft')
        grid = disp + spatial.identity_grid(target.shape[2:], **backend)
        return disp, grid

    # Prepare loss and optimizer
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: factor * image_loss_fn(x, y)

    if not callable(def_loss):
        def_loss_fn = nni.BendingLoss(bound='dft')
        factor = 1. if def_loss is None else def_loss
        def_loss = lambda x: factor * def_loss_fn(core.utils.last2channel(x))

    lr = core.utils.make_list(lr, 2)
    opt_prm = [{
        'params': affine_parameters,
        'lr': lr[1]
    }, {
        'params': grid_parameters,
        'lr': lr[0]
    }] if optim_affine else [grid_parameters]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)


    # with torch.no_grad():
    #     disp, grid = exp(grid_parameters)
    #     moved = pull(affine_parameters, grid)
    #     plt.imshow(torch.cat([target, moved, source], dim=1).detach().cpu())
    #     plt.show()

    # Optim loop
    loss_val = core.constants.inf
    loss_avg = 0
    for n_iter in range(max_iter):

        loss_val0 = loss_val
        zero_grad_([affine_parameters, grid_parameters])
        disp, grid = exp(grid_parameters)
        moved = pull(affine_parameters, grid)
        loss_val = image_loss(moved, target) + def_loss(disp[0])
        loss_val.backward()
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val

        if n_iter % 10 == 0:
            # print(affine_parameters)
            # plt.imshow(torch.cat([target, moved, source], dim=1).detach().cpu())
            # plt.show()

            loss_avg /= 10
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_avg)
                else:
                    scheduler.step()

            with torch.no_grad():
                print('{:4d} {:12.6f} | lr={:g}'.format(
                    n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                    end='\r')

            loss_avg = 0

    print('')
    with torch.no_grad():
        moved = pull(affine_parameters, grid)
        aff = core.linalg.expm(affine_parameters, basis)
        if origin == 'center':
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return affine_parameters, aff, grid_parameters, moved
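
A hypothetical invocation, assuming `prepare` accepts raw (batch, channel, *spatial) tensors and assigns identity orientation matrices when none are given:

import torch

source = torch.rand(1, 1, 32, 32, 32)
target = torch.rand(1, 1, 32, 32, 32)
q, aff, grid_prm, moved = ffd(source, target, grid_shape=8,
                              group='SE', max_iter=200, lr=1e-3)
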
Example #17
 def _index_from_cursor(self, x, y, image, n_ax):
     p = utils.as_tensor([x, y, 0])
     mat = image._mats[n_ax]
     self.index = spatial.affine_matvec(mat, p)
Example #18
def diffeo(source,
           target,
           group='SE',
           image_loss=None,
           vel_loss=None,
           pull=None,
           preproc=False,
           max_iter=1000,
           device=None,
           origin='center',
           init=None,
           lr=1e-4,
           optim_affine=True,
           scheduler=ReduceLROnPlateau):
    """

    Parameters
    ----------
    source : path or tensor or (tensor, affine)
    target : path or tensor or (tensor, affine)
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
    image_loss : Loss, default=MutualInfoLoss()
    vel_loss : Loss, default=BendingLoss()
    pull : dict
        interpolation : int, default=1
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
    preproc : bool, default=False
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
    init : tensor_like, default=0
    lr: float, default=1e-4
    optim_affine : bool, default=True

    Returns
    -------
    q : tensor
        Parameters
    aff : (D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *spatial, D) tensor
        Initial velocity of the diffeomorphic transform.
        The full warp is `(aff @ aff_src).inv() @ aff_trg @ exp(vel)`
    moved : tensor
        Source image moved to target space.


    """
    pull = pull or dict()
    pull['interpolation'] = pull.get('interpolation', 'linear')
    pull['bound'] = pull.get('bound', 'dct2')
    pull['extrapolate'] = pull.get('extrapolate', False)
    pull_opt = pull

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    src_channels = source.shape[1]
    trg_channels = target.shape[1]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    source = rescale(source)
    target = rescale(target)

    # Shift origin
    if origin == 'center':
        shift = torch.as_tensor(target.shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *target.shape[2:], dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    def pull(q, vel):
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff, grid)
        moved = spatial.grid_pull(source, grid, **pull_opt)
        return moved

    # Prepare loss and optimizer
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: factor * image_loss_fn(x, y)

    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: factor * vel_loss_fn(core.utils.last2channel(x))

    lr = core.utils.make_list(lr, 2)
    opt_prm = [{'params': parameters}, {'params': velocity, 'lr': lr[1]}] \
              if optim_affine else [velocity]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    if scheduler is not None:
        scheduler = scheduler(optim, cooldown=5)

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        moved = pull(parameters, velocity)
        loss_val = image_loss(moved, target) + vel_loss(velocity)
        loss_val.backward()
        optim.step()
        with torch.no_grad():
            loss_avg += loss_val

        if n_iter % 10 == 0:
            loss_avg /= 10
            if scheduler is not None:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(loss_avg)
                else:
                    scheduler.step()

            with torch.no_grad():
                print('{:4d} {:12.6f} | lr={:g}'.format(
                    n_iter, loss_avg.item(), optim.param_groups[0]['lr']),
                    end='\r')

            loss_avg = 0

    print('')
    with torch.no_grad():
        moved = pull(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return parameters, aff, velocity, moved
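
A quick self-contained check (not library code) of the origin-shift bookkeeping above: subtracting `shift`, rotating it, and adding it back is exactly the translation conjugation `T(-shift) @ aff @ T(shift)`, which maps the transform optimized in the recentred frame back to the native world frame.

import torch

def T(s):
    """4x4 homogeneous translation matrix."""
    M = torch.eye(4)
    M[:3, -1] = s
    return M

s = torch.randn(3)
A = torch.eye(4)
A[:3, :3] = torch.matrix_exp(torch.randn(3, 3) * 0.1)  # some linear part
A[:3, -1] = torch.randn(3)

B = A.clone()
B[:3, -1] -= s
B[:3, -1] += A[:3, :3] @ s            # the same update as in the code above
assert torch.allclose(B, T(-s) @ A @ T(s), atol=1e-5)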
Ejemplo n.º 19
0
def diffeo(source, target, group='SE', origin='center',
           image_loss=None, vel_loss=None, pull=None, optim_affine=True,
           max_iter=1000, lr=0.1, min_lr=1e-7, init=None, device=None):
    """Diffeomorphic registration

    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across
       the channel dimension. The loss function is then responsible
       for unstacking the tensor and computing the appropriate
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as
       well as the same interpolation order. The advantage is that
       it simplifies the signature of this function.

    Parameters
    ----------
    source :  tensor or (tensor, affine)
        The source (moving) image, with shape (batch, channel, *spatial).
    target : tensor or (tensor, affine)
        The target (fixed) image, with shape (batch, channel, *spatial).
    group : {'tr', 'rot', 'rigid', 'sim', 'lin', 'aff'}, default='SE' (rigid)
        Affine sub-group to optimize.
    origin : {'native', 'center'}, default='center'
        Whether to rotate about the origin of the world-space ('native')
        or the center of the target field-of-view ('center').
        When the origin of the world-space is far off (say you are
        registering smaller blocks cropped from a larger MRI), it can
        be beneficial to rotate about the center of the FOV.
    image_loss : callable(mov, fix) -> loss, default=MutualInfoLoss()
        A loss function that takes two inputs of shape
        (batch, channel, *spatial).
    vel_loss : float or callable(vel) -> loss, default=BendingLoss()
        Either a factor to multiply the bending loss with, or a loss
        function that takes a velocity field of shape (batch, *spatial, dim).
    pull : dict
        interpolation : int, default=1
            Interpolation order
        bound : bound_like, default='dct2'
            Boundary condition
        extrapolate : bool, default=False
            Extrapolate out-of-bound data using the boundary conditions.
    max_iter : int, default=1000
        Maximum number of iterations
    lr : float, default=0.1
        Initial learning rate.
    min_lr : float, default=1e-7
        Minimum learning rate. The optimization is stopped once this
        learning rate is reached.
    device : {'cpu', 'cuda', 'cuda:<id>'}, optional
        Backend to use
    init : ([batch], nb_prm) tensor_like, default=0
        Initial guess for the affine parameters.

    Returns
    -------
    q : (batch, nb_prm) tensor
        Parameters
    aff : (batch, D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    vel : (batch, *shape, D) tensor
        Initial velocity
    moved : tensor
        Source image moved to target space.
    """
    group = affine_group_converter(group)
    pull = pull or dict()
    pull['interpolation'] = pull.get('interpolation', 'linear')
    pull['bound'] = pull.get('bound', 'dct2')
    pull['extrapolate'] = pull.get('extrapolate', False)
    pull_opt = pull

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    dim = source.dim() - 2

    # Shift origin
    if origin == 'center':
        shift = torch.as_tensor(target.shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=optim_affine)
    velocity = torch.zeros([batch, *target.shape[2:], dim], **backend)
    velocity = nn.Parameter(velocity, requires_grad=True)

    def pull(q, vel):
        grid = spatial.exp(vel)
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        grid = spatial.affine_matvec(aff, grid)
        moved = spatial.grid_pull(source, grid, **pull_opt)
        return moved

    # Prepare loss and optimizer
    if not callable(image_loss):
        image_loss_fn = nni.MutualInfoLoss()
        factor = 1. if image_loss is None else image_loss
        image_loss = lambda x, y: factor * image_loss_fn(x, y)

    if not callable(vel_loss):
        vel_loss_fn = nni.BendingLoss(bound='dft')
        factor = 1. if vel_loss is None else vel_loss
        vel_loss = lambda x: factor * vel_loss_fn(core.utils.last2channel(x))

    lr = core.utils.make_list(lr, 2)
    min_lr = core.utils.make_list(min_lr, 2)
    opt_prm = [{'params': parameters}, {'params': velocity, 'lr': lr[1]}] \
              if optim_affine else [velocity]
    optim = torch.optim.Adam(opt_prm, lr=lr[0])
    scheduler = ReduceLROnPlateau(optim)

    def forward():
        moved = pull(parameters, velocity)
        loss_val = image_loss(moved, target) + vel_loss(velocity)
        return loss_val

    # Optim loop
    loss_avg = 0
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        loss_val = forward()
        loss_val.backward()
        optim.step()

        with torch.no_grad():
            loss_avg += loss_val
            if n_iter % 10 == 0:
                loss_avg /= 10
                scheduler.step(loss_avg)

                print('{:4d} {:12.6f} | lr={:g} '
                      .format(n_iter, loss_avg.item(),
                              optim.param_groups[0]['lr']),
                      end='\r')
                loss_avg = 0

        if (optim.param_groups[0]['lr'] < min_lr[0] and
                (len(optim.param_groups) == 1 or
                 optim.param_groups[1]['lr'] < min_lr[1])):
            print('\nConverged.')
            break

    print('')
    with torch.no_grad():
        moved = pull(parameters, velocity)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
    return (parameters.detach(),
            aff.detach(),
            velocity.detach(),
            moved.detach())
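
A minimal usage sketch for `diffeo` above, with hypothetical shapes; it assumes the surrounding module is importable and that `prepare` supplies default orientation matrices for raw tensors.

import torch

source = torch.rand(1, 1, 32, 32, 32)   # (batch, channel, *spatial)
target = torch.rand(1, 1, 32, 32, 32)
q, aff, vel, moved = diffeo(source, target, group='SE', max_iter=200)
# `aff` left-corrects the source orientation matrix; the full warp is
# (aff @ aff_src).inv() @ aff_trg @ exp(vel)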
Ejemplo n.º 20
0
def reslice(moving, fname, like, inv=False, lin=None, nonlin=None,
            interpolation=1, bound='dct2', extrapolate=False, device=None,
            verbose=True):
    """Apply the linear and non-linear components of the transform and
    reslice the image to the target space.

    Notes
    -----
    .. The shape and general orientation of the moving image are kept untouched.
    .. The linear transform is composed with the original orientation matrix.
    .. The non-linear component is "wrapped" in the input space, where it
       is applied.
    .. This function writes a new file (it does not modify the input files
       in place).

    Parameters
    ----------
    moving : ImageFile
        An object describing a moving image.
    fname : list of str
        Output filename for each input file of the moving image
        (since images can be encoded over multiple volumes)
    like : ImageFile
        An object describing the target space
    inv : bool, default=False
        True if we are warping the fixed image to the moving space.
        In that case, `moving` should be a `FixedImageFile` and `like` a
        `MovingImageFile`. Otherwise, `moving` should be a `MovingImageFile`
        and `like` a `FixedImageFile`.
    lin : (4, 4) tensor, optional
        Linear (or rather affine) transformation
    nonlin : dict, optional
        Non-linear displacement field, with keys:
        disp : (..., 3) tensor
            Displacement field (in voxels)
        affine : (4, 4) tensor
            Orientation matrix of the displacement field
    interpolation : int, default=1
    bound : str, default='dct2'
    extrapolate : bool, default = False
    device : device, default='cpu'

    """
    nonlin = nonlin or dict(disp=None, affine=None)
    prm = dict(interpolation=interpolation, bound=bound, extrapolate=extrapolate)

    moving_affine = moving.affine.to(device)
    fixed_affine = like.affine.to(device)

    if inv:
        # affine-corrected fixed space
        if lin is not None:
            fix2lin = affine_matmul(lin, fixed_affine)
        else:
            fix2lin = fixed_affine

        if nonlin['disp'] is not None:
            # fixed voxels to param voxels (warps param to fixed)
            fix2nlin = affine_lmdiv(nonlin['affine'].to(device), fix2lin)
            if samespace(fix2nlin, nonlin['disp'].shape[:-1], like.shape):
                g = smalldef(nonlin['disp'].to(device))
            else:
                g = affine_grid(fix2nlin, like.shape)
                g += pull_grid(nonlin['disp'].to(device), g)
            # param to moving
            nlin2mov = affine_lmdiv(moving_affine, nonlin['affine'].to(device))
            g = affine_matvec(nlin2mov, g)
        else:
            g = affine_lmdiv(moving_affine, fix2lin)
            g = affine_grid(g, like.shape)

    else:
        # affine-corrected moving space
        if lin is not None:
            mov2nlin = affine_matmul(lin, moving_affine)
        else:
            mov2nlin = moving_affine

        if nonlin['disp'] is not None:
            # fixed voxels to param voxels (warps param to fixed)
            fix2nlin = affine_lmdiv(nonlin['affine'].to(device), fixed_affine)
            if samespace(fix2nlin, nonlin['disp'].shape[:-1], like.shape):
                g = smalldef(nonlin['disp'].to(device))
            else:
                g = affine_grid(fix2nlin, like.shape)
                g += pull_grid(nonlin['disp'].to(device), g)
            # param voxels to moving voxels (warps moving to fixed)
            nlin2mov = affine_lmdiv(mov2nlin, nonlin['affine'].to(device))
            g = affine_matvec(nlin2mov, g)
        else:
            g = affine_lmdiv(mov2nlin, fixed_affine)
            g = affine_grid(g, like.shape)

    if moving.type == 'labels':
        prm['interpolation'] = 0
    for file, ofname in zip(moving.files, fname):
        if verbose:
            print(f'Resliced:   {file.fname}\n'
                  f'         -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=True, device=device)
        dat = dat.reshape([*file.shape, file.channels])
        if g is not None:
            dat = utils.movedim(dat, -1, 0)
            dat = pull(dat, g, **prm)
            dat = utils.movedim(dat, 0, -1)
        io.savef(dat.cpu(), ofname, like=file.fname, affine=like.affine.cpu())
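
The compositions above chain `affine_matmul` and `affine_lmdiv`. As a rough plain-torch mental model (an assumption about the library's conventions, not its actual implementation), `affine_lmdiv(A, B)` is the left division `A \ B`:

import torch

def affine_lmdiv(A, B):
    """Left division: inv(A) @ B, via a solve rather than an explicit inverse."""
    return torch.linalg.solve(A, B)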
Ejemplo n.º 21
0
    def forward():
        """Forward pass up to the loss"""

        loss = 0

        # affine matrix
        A = None
        for trf in options.transformations:
            trf.update()
            if isinstance(trf, struct.Linear):
                q = trf.optdat.to(**backend)
                B = trf.basis.to(**backend)
                A = linalg.expm(q, B)
                if torch.is_tensor(trf.shift):
                    # include shift
                    shift = trf.shift.to(**backend)
                    eye = torch.eye(options.dim, **backend)
                    A = A.clone()  # needed because expm is a custom autograd.Function
                    A[:-1, -1] += torch.matmul(A[:-1, :-1] - eye, shift)
                for loss1 in trf.losses:
                    loss += loss1.call(q)
                break

        # non-linear displacement field
        d = None
        d_aff = None
        for trf in options.transformations:
            if not trf.isfree():
                continue
            if isinstance(trf, struct.FFD):
                d = trf.dat.to(**backend)
                d = ffd_exp(d, trf.shape, returns='disp')
                for loss1 in trf.losses:
                    loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break
            elif isinstance(trf, struct.Diffeo):
                d = trf.dat.to(**backend)
                if not trf.smalldef:
                    # penalty on velocity fields
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d = spatial.exp(d[None], displacement=True)[0]
                if trf.smalldef:
                    # penalty on exponentiated transform
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break

        # loop over image pairs
        for match in options.losses:
            if not match.fixed:
                continue
            nb_levels = len(match.fixed.dat)
            prm = dict(interpolation=match.moving.interpolation,
                       bound=match.moving.bound,
                       extrapolate=match.moving.extrapolate)
            # loop over pyramid levels
            for moving, fixed in zip(match.moving.dat, match.fixed.dat):
                moving_dat, moving_aff = moving
                fixed_dat, fixed_aff = fixed

                moving_dat = moving_dat.to(**backend)
                moving_aff = moving_aff.to(**backend)
                fixed_dat = fixed_dat.to(**backend)
                fixed_aff = fixed_aff.to(**backend)

                # affine-corrected moving space
                if A is not None:
                    Ms = affine_matmul(A, moving_aff)
                else:
                    Ms = moving_aff

                if d is not None:
                    # fixed to param
                    Mt = affine_lmdiv(d_aff, fixed_aff)
                    if samespace(Mt, d.shape[:-1], fixed_dat.shape[1:]):
                        g = smalldef(d)
                    else:
                        g = affine_grid(Mt, fixed_dat.shape[1:])
                        g = g + pull_grid(d, g)
                    # param to moving
                    Ms = affine_lmdiv(Ms, d_aff)
                    g = affine_matvec(Ms, g)
                else:
                    # fixed to moving
                    Mt = fixed_aff
                    Ms = affine_lmdiv(Ms, Mt)
                    g = affine_grid(Ms, fixed_dat.shape[1:])

                # pull moving image
                warped_dat = pull(moving_dat, g, **prm)
                loss += match.call(warped_dat, fixed_dat) / float(nb_levels)

        return loss
Ejemplo n.º 22
0
def load_transforms(s):
    """Initialize transforms"""
    device = torch.device(s.device)

    def reshape3d(dat, channels=None, dim=3):
        """Reshape as (*spatial) or (C, *spatial) or (*spatial, C).
        `channels` should be in ('first', 'last', None).
        """
        while len(dat.shape) > dim:
            if dat.shape[-1] == 1:
                dat = dat[..., 0]
                continue
            elif dat.shape[dim] == 1:
                dat = dat[:, :, :, 0, ...]
                continue
            else:
                break
        if len(dat.shape) > dim + bool(channels):
            raise ValueError('Too many channel dimensions')
        if channels and len(dat.shape) == dim:
            dat = dat[..., None]
        if channels == 'first':
            dat = utils.movedim(dat, -1, 0)
        return dat

    # compute mean space
    #   it is used to define the space of the nonlinear transform, but
    #   also to shift the center of rotation of the linear transform.
    all_affines = []
    all_shapes = []
    all_affines_fixed = []
    all_shapes_fixed = []
    for loss in s.losses:
        if isinstance(loss, struct.NoLoss):
            continue
        if getattr(loss, 'exclude', False):
            continue
        all_shapes_fixed.append(loss.fixed.shape)
        all_affines_fixed.append(loss.fixed.affine)
        all_shapes.append(loss.fixed.shape)
        all_affines.append(loss.fixed.affine)
        all_shapes.append(loss.moving.shape)
        all_affines.append(loss.moving.affine)
    affine0, shape0 = mean_space(all_affines, all_shapes,
                                 pad=s.pad, pad_unit=s.pad_unit)
    affinef, shapef = mean_space(all_affines_fixed, all_shapes_fixed,
                                 pad=s.pad, pad_unit=s.pad_unit)
    backend = dict(dtype=affine0.dtype, device=affine0.device)

    for trf in s.transformations:
        for reg in trf.losses:
            if isinstance(reg.factor, (list, tuple)):
                reg.factor = [f * trf.factor for f in reg.factor]
            else:
                reg.factor = reg.factor * trf.factor
        if isinstance(trf, struct.Linear):
            # Affine
            if isinstance(trf.init, str):
                trf.dat = io.transforms.loadf(trf.init, dtype=torch.float32,
                                              device=device)
            else:
                trf.dat = torch.zeros(trf.nb_prm(3), dtype=torch.float32,
                                      device=device)
            if trf.shift:
                shift = torch.as_tensor(shapef, **backend) * 0.5
                trf.shift = -spatial.affine_matvec(affinef, shift)
            else:
                trf.shift = 0.
        else:
            affine, shape = (affine0, shape0)
            trf.pyramid = list(sorted(trf.pyramid))
            max_level = max(trf.pyramid)
            factor = 2**(max_level-1)
            affine, shape = affine_resize(affine, shape, 1/factor)

            # FFD/Diffeo
            if isinstance(trf.init, str):
                f = io.volumes.map(trf.init)
                trf.dat = reshape3d(f.loadf(dtype=torch.float32, device=device),
                                    'last')
                if trf.dat.shape[-1] != trf.dim:
                    raise ValueError(f'Field should have {trf.dim} channels')
                factor = [int(s//g) for g, s in zip(trf.shape[:-1], shape)]
                trf.affine, trf.shape = affine_resize(trf.affine, trf.shape[:-1], factor)
            else:
                trf.dat = torch.zeros([*shape, trf.dim], dtype=torch.float32,
                                      device=device)
                trf.affine = affine
                trf.shape = shape
Ejemplo n.º 23
0
def crop(inp,
         size=None,
         center=None,
         space='vox',
         like=None,
         output=None,
         transform=None):
    """Crop a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, optional
        Size of the patch to extract.
        Its unit and axes are defined by `space`.
    center : [sequence of] int, optional
        Coordinate of the center of the patch.
        Its unit and axes are defined by `space`.
        By default, the center of the FOV is used.
    space : [sequence of] {'vox', 'ras'}, default='vox'
        The space in which the `size` and `center` parameters are expressed.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the cropped data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.crop{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    transform : [sequence of] str, optional
        Input or output filename(s) of the corresponding transforms.
        Not written by default.
        If a transform is provided and all other parameters
        (i.e., `size` and `like`) are None, the transform is considered
        as an input transform to apply.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and orientation matrix are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = None
    transform_in = False

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.data(numpy=True), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dir, base, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]

    if size and like:
        raise ValueError('Cannot use both `size` and `like`.')

    # --- Open reference and compute size/center ---
    if like:
        like_is_file = isinstance(like, str)
        if like_is_file:
            f = io.volumes.map(like)
            like = (f.shape, f.affine)
        like_shape, like_aff = like
        if torch.is_tensor(like_shape):
            like_shape = like_shape.shape
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)

    # --- Open transformation file and compute size/center ---
    elif not size:
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        t = io.transforms.map(transform)
        if not isinstance(t, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = t.destination_space()
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)

    # --- use center of the FOV ---
    if not torch.is_tensor(center) and not center:
        center = torch.as_tensor(shape0[:dim], dtype=torch.float) * 0.5

    # --- convert size/center to voxels ---
    size = utils.make_vector(size, dim, dtype=torch.long)
    center = utils.make_vector(center, dim, dtype=torch.float)
    space_size, space_center = py.make_list(space, 2)
    if space_center.lower() == 'ras':
        center = spatial.affine_matvec(spatial.affine_inv(aff0), center)
    if space_size.lower() == 'ras':
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()]
        size = size / spatial.voxel_size(aff0)

    # --- compute first/last ---
    center = center.float()
    size = (size.ceil() if size.dtype.is_floating_point else size).long()
    first = (center - size.float() / 2).round().long()
    last = (first + size).tolist()
    first = [max(f, 0) for f in first.tolist()]
    last = [min(l, s) for l, s in zip(last, shape0[:dim])]
    verb = 'Cropping patch ['
    verb += ', '.join([f'{f}:{l}' for f, l in zip(first, last)])
    verb += f'] from volume with shape {shape0[:dim]}'
    print(verb)
    slicer = tuple(slice(f, l) for f, l in zip(first, last))

    # --- do crop ---
    dat = dat[slicer]
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    if output:
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    if transform and not transform_in:
        if is_file:
            transform = transform.format(dir=dir or '.',
                                         base=base,
                                         ext=ext,
                                         sep=os.path.sep)
        else:
            transform = transform.format(sep=os.path.sep)
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff0, shape0)
        trf.set_destination_space(aff, shape)
        trf.set_metadata({
            'src': {
                'filename': fname
            },
            'dst': {
                'filename': output
            },
            'type': 1
        })  # RAS_TO_RAS
        trf.set_fdata(torch.eye(4))
        trf.save()

    if is_file:
        return output
    else:
        return dat, aff
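
A worked instance of the center/size-to-slice conversion above (plain Python, voxel units):

center, size = 20.0, 8
first = round(center - size / 2)   # 16
last = first + size                # 24  -> slice(16, 24): 8 voxels centred on 20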
Ejemplo n.º 24
0
def vexp(inp,
         type='displacement',
         unit='voxel',
         inverse=False,
         bound='dft',
         steps=8,
         device=None,
         output=None):
    """Exponentiate a stationary velocity fields.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    type : {'displacement', 'transformation'}, default='displacement'
        Whether to return a displacement field (coord-to-shift) or a
        transformation field (coord-to-coord).
    unit : {'voxel', 'mm'}, default='voxel'
        Whether to return displacement/coordinates in voxel or in mm.
        If mm, the input orientation matrix is used to convert voxels to mm.
    inverse : bool, default=False
        Whether to return the inverse field.
    bound : str, default='dft'
        Boundary conditions.
    steps : int, default=8
        Number of scaling and squaring steps.
    device : str, optional
        Device to use.
    output : str, optional
        Output filename(s).
        If the input is not a path, the output data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.vexp{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the output tensor and orientation matrix are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = None

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(device=device), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.vexp{ext}'
        dir, base, ext = py.fileparts(fname)
    else:
        if torch.is_tensor(inp):
            inp = (inp.clone(), spatial.affine_default(shape=inp.shape[:3]))
    dat, aff = inp
    dat = dat.to(device=device)
    aff = aff.to(device=device)

    # exponentiate
    dat = spatial.exp(dat[None],
                      inverse=inverse,
                      steps=steps,
                      bound=bound,
                      inplace=True,
                      displacement=(type.lower()[0] == 'd'))[0]
    if unit == 'mm':
        # if type.lower()[0] == 'd':
        #     vx = spatial.voxel_size(aff)
        #     dat *= vx
        # else:
        dat = spatial.affine_matvec(aff, dat)

    if output:
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff.cpu())
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff.cpu())

    if is_file:
        return output
    else:
        return dat, aff
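
`spatial.exp` integrates the velocity by scaling and squaring: shrink by `2**steps`, take a one-step approximation, then self-compose `steps` times. The matrix analogue below illustrates the principle; it is a sketch, not the library's field implementation:

import torch

def expm_ss(X, steps=8):
    """exp(X) ~ (I + X / 2**steps) squared `steps` times."""
    Y = torch.eye(X.shape[-1], dtype=X.dtype) + X / (2 ** steps)
    for _ in range(steps):
        Y = Y @ Y
    return Y

X = torch.randn(4, 4) * 0.05
assert torch.allclose(expm_ss(X), torch.matrix_exp(X), atol=1e-3)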
Ejemplo n.º 25
0
    def __call__(self, logaff, grad=False, hess=False, in_line_search=False):
        """
        logaff : (..., nb) tensor, Lie parameters
        grad : Whether to compute and return the gradient wrt `logaff`
        hess : Whether to compute and return the Hessian wrt `logaff`
        in_line_search : Whether this evaluation happens inside a line search

        Returns
        -------
        ll : () tensor, loss value (objective to minimize)
        g : (..., logaff) tensor, optional, Gradient wrt Lie parameters
        h : (..., logaff) tensor, optional, Hessian wrt Lie parameters

        """
        # This loop performs the forward pass, and computes
        # derivatives along the way.

        # select correct gradient mode
        if grad:
            logaff.requires_grad_()
            if logaff.grad is not None:
                logaff.grad.zero_()
        if grad and not torch.is_grad_enabled():
            with torch.enable_grad():
                return self(logaff, grad, hess, in_line_search=in_line_search)
        elif not grad and torch.is_grad_enabled():
            with torch.no_grad():
                return self(logaff, grad, hess, in_line_search=in_line_search)

        pullopt = dict(bound=self.bound, extrapolate=self.extrapolate)

        logplot = max(self.max_iter // 20, 1)
        do_plot = (not in_line_search) and self.plot \
                   and (self.n_iter - 1) % logplot == 0

        # jitter
        # idj = spatial.identity_grid(self.fixed.shape[-self.dim:], jitter=True,
        #                             **utils.backend(self.fixed))
        # fixed = spatial.grid_pull(self.fixed, idj, **pullopt)
        # del idj
        fixed = self.fixed

        # forward
        if not torch.is_tensor(self.basis):
            self.basis = spatial.affine_basis(self.basis, self.dim,
                                              **utils.backend(logaff))
        aff = linalg.expm(logaff, self.basis)
        aff = spatial.affine_matmul(aff, self.affine_fixed)
        aff = spatial.affine_lmdiv(self.affine_moving, aff)
        if self.id is None:
            shape = self.fixed.shape[-self.dim:]
            self.id = spatial.identity_grid(shape, **utils.backend(logaff))
        grid = spatial.affine_matvec(aff, self.id)
        warped = spatial.grid_pull(self.moving, grid, **pullopt)
        if do_plot:
            iscat = isinstance(self.loss, losses.Cat)
            plt.mov2fix(self.fixed,
                        self.moving,
                        warped,
                        cat=iscat,
                        dim=self.dim)

        # gradient/Hessian of the log-likelihood in observed space
        llx = self.loss.loss(warped, fixed)
        del warped

        # print objective
        lll = llx
        llx = llx.item()
        ll = llx
        if self.verbose and not in_line_search:
            self.n_iter += 1
            if self.ll_prev is None:
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g}',
                    end='\n')
            else:
                gain = (self.ll_prev - ll) / max(abs(self.ll_max - ll), 1e-8)
                print(
                    f'{self.n_iter:03d} | {llx:12.6g} + {0:12.6g} = {ll:12.6g} | {gain:12.6g}',
                    end='\n')
            self.ll_prev = ll
            self.ll_max = max(self.ll_max, ll)

        out = [lll]
        if grad is not False:
            lll.backward()
            grad = logaff.grad.clone()
            out.append(grad)
        logaff.requires_grad_(False)
        return tuple(out) if len(out) > 1 else out[0]
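
The re-entry logic at the top of `__call__` (calling itself again under `enable_grad`/`no_grad`) is a generic pattern. A self-contained toy version, for illustration only:

import torch

def value_and_maybe_grad(x, grad=False):
    # re-enter under the right autograd mode, as __call__ does above
    if grad and not torch.is_grad_enabled():
        with torch.enable_grad():
            return value_and_maybe_grad(x, grad)
    if not grad and torch.is_grad_enabled():
        with torch.no_grad():
            return value_and_maybe_grad(x, grad)
    y = (x ** 2).sum()
    return (y, torch.autograd.grad(y, x)[0]) if grad else y

x = torch.randn(3, requires_grad=True)
loss, g = value_and_maybe_grad(x, grad=True)   # also works inside torch.no_grad()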
Ejemplo n.º 26
0
    def compose(self,
                orient_in,
                deformation,
                orient_mean,
                affine=None,
                orient_out=None,
                shape_out=None):
        """Composes a deformation defined in a mean space to an image space.

        Parameters
        ----------
        orient_in : (4, 4) tensor
            Orientation of the input image
        deformation : (*shape_mean, 3) tensor
            Random deformation
        orient_mean : (4, 4) tensor
            Orientation of the mean space (where the deformation is)
        affine : (4, 4) tensor, default=identity
            Random affine
        orient_out : (4, 4) tensor, default=orient_in
            Orientation of the output image
        shape_out : sequence[int], default=shape_mean
            Shape of the output image

        Returns
        -------
        grid : (*shape_out, 3)
            Voxel-to-voxel transform

        """
        if orient_out is None:
            orient_out = orient_in
        if shape_out is None:
            shape_out = deformation.shape[:-1]
        if affine is None:
            affine = torch.eye(4,
                               4,
                               device=orient_in.device,
                               dtype=orient_in.dtype)
        shape_mean = deformation.shape[:-1]

        orient_in, affine, deformation, orient_mean, orient_out \
            = utils.to_max_backend(orient_in, affine, deformation, orient_mean, orient_out)
        backend = utils.backend(deformation)
        eye = torch.eye(4, **backend)

        # Compose deformation on the right
        right_affine = spatial.affine_lmdiv(orient_mean, orient_out)
        if shape_mean == shape_out and right_affine.allclose(eye):
            # mean space and output space coincide: use the deformation as is
            trf = deformation
        else:
            # the mean space and native space are not the same
            # we must compose the diffeo with a dense affine transform
            # we write the diffeo as an identity plus a displacement
            # (id + disp)(aff) = aff + disp(aff)
            # -------
            # to displacement
            deformation = deformation - spatial.identity_grid(
                deformation.shape[:-1], **backend)
            trf = spatial.affine_grid(right_affine, shape_out)
            deformation = spatial.grid_pull(
                utils.movedim(deformation, -1, 0)[None],
                trf[None],
                bound='dft',
                extrapolate=True)
            deformation = utils.movedim(deformation[0], 0, -1)
            trf = trf + deformation  # add displacement

        # Compose deformation on the left
        #   the output of the diffeo(right) are mean_space voxels
        #   we must compose on the left with `in\(aff(mean))`
        # -------
        left_affine = spatial.affine_matmul(spatial.affine_inv(orient_in),
                                            affine)
        left_affine = spatial.affine_matmul(left_affine, orient_mean)
        trf = spatial.affine_matvec(left_affine, trf)

        return trf
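
For reference, a plain-torch sketch of the dense grids composed here. This is a guess at what `spatial.affine_grid` computes, up to the library's exact conventions (integer voxel coordinates, 3D, a (4, 4) affine):

import torch

def affine_grid(A, shape):
    """Apply A to every voxel coordinate: grid[x] = R @ x + t."""
    coords = torch.stack(torch.meshgrid(
        *(torch.arange(s, dtype=A.dtype) for s in shape),
        indexing='ij'), dim=-1)                 # (*shape, 3)
    return coords @ A[:3, :3].T + A[:3, -1]     # (*shape, 3)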
Ejemplo n.º 27
0
def crop(inp,
         size=None,
         center=None,
         space='vox',
         like=None,
         bbox=False,
         output=None,
         transform=None):
    """Crop a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, optional
        Size of the patch to extract.
        Its unit and axes are defined by `space`.
    center : [sequence of] int, optional
        Coordinate of the center of the patch.
        Its unit and axes are defined by `space`.
        By default, the center of the FOV is used.
    space : [sequence of] {'vox', 'ras'}, default='vox'
        The space in which the `size` and `center` parameters are expressed.
    bbox : bool or float, default=False
        Crop at the bounding box of `inp > threshold`.
            If `bbox` is a float, it is the threshold to use.
            If `bbox` is `True`, the threshold is 0.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the cropped data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.crop{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    transform : [sequence of] str, optional
        Input or output filename(s) of the corresponding transforms.
        Not written by default.
        If a transform is provided and all other parameters
        (i.e., `size` and `like`) are None, the transform is considered
        as an input transform to apply.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and orientation matrix are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = None
    transform_in = False
    use_bbox = bool(bbox or isinstance(bbox, float))

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.data(numpy=True) if use_bbox else f, f.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dir, base, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]
    layout0 = spatial.affine_to_layout(aff0)

    # save input space in case we reorient later
    aff00 = aff0
    shape00 = shape0

    if bool(size) + bool(like) + use_bbox > 1:
        raise ValueError('Can only use one of `size`, `like` and `bbox`.')

    # --- Open reference and compute size/center ---
    if like:
        like_is_file = isinstance(like, str)
        if like_is_file:
            f = io.volumes.map(like)
            like = (f.shape, f.affine)
        like_shape, like_aff = like
        like_layout = spatial.affine_to_layout(like_aff)
        if (layout0 != like_layout).any():
            aff0, dat = spatial.affine_reorient(aff0, dat, like_layout)
            shape0 = dat.shape[:dim]
        if torch.is_tensor(like_shape):
            like_shape = like_shape.shape
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)
        space = 'vox'

    elif use_bbox:
        if bbox is True:
            bbox = 0.
        mask = torch.as_tensor(dat > bbox)
        while mask.dim() > 3:
            mask = mask.any(dim=-1)
        mins = []
        maxs = []
        for d in range(dim):
            n = mask.shape[d]
            idx = utils.movedim(mask, d, 0).reshape([n, -1])
            idx = idx.any(-1).nonzero(as_tuple=False)
            mins.append(idx.min())
            maxs.append(idx.max())
        mins = utils.as_tensor(mins)
        maxs = utils.as_tensor(maxs)
        size = maxs + 1 - mins
        center = (maxs + 1 + mins).float() / 2
        space = 'vox'
        del mask

    # --- Open transformation file and compute size/center ---
    elif not size:
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        t = io.transforms.map(transform)
        if not isinstance(t, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = t.destination_space()
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)

    # --- use center of the FOV ---
    if not torch.is_tensor(center) and not center:
        center = torch.as_tensor(shape0[:dim], dtype=torch.float)
        center = center.sub_(1).mul_(0.5)

    # --- convert size/center to voxels ---
    size = utils.make_vector(size, dim, dtype=torch.long)
    center = utils.make_vector(center, dim, dtype=torch.float)
    space_size, space_center = py.make_list(space, 2)
    if space_center.lower() == 'ras':
        center = spatial.affine_matvec(spatial.affine_inv(aff0), center)
    if space_size.lower() == 'ras':
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()]
        size = size / spatial.voxel_size(aff0)

    # --- compute first/last ---
    center = center.float()
    size = (size.ceil() if size.dtype.is_floating_point else size).long()
    first = center - size.float().sub_(1).mul_(0.5)
    first = first.round().long()
    last = (first + size).tolist()
    first = [max(f, 0) for f in first.tolist()]
    last = [min(l, s) for l, s in zip(last, shape0[:dim])]
    verb = 'Cropping patch ['
    verb += ', '.join([f'{f}:{l}' for f, l in zip(first, last)])
    verb += f'] from volume with shape {shape0[:dim]}'
    print(verb)
    slicer = tuple(slice(f, l) for f, l in zip(first, last))

    # --- do crop ---
    if use_bbox and torch.is_tensor(dat):
        dat = dat.numpy()
    dat = dat[slicer]
    if not torch.is_tensor(dat) and is_file and not use_bbox:
        # still a memory-mapped file: load the cropped block
        dat = dat.data(numpy=True)
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    if output:
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    if transform and not transform_in:
        if is_file:
            transform = transform.format(dir=dir or '.',
                                         base=base,
                                         ext=ext,
                                         sep=os.path.sep)
        else:
            transform = transform.format(sep=os.path.sep)
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff00, shape00)
        trf.set_destination_space(aff, shape)
        trf.set_metadata({
            'src': {
                'filename': fname
            },
            'dst': {
                'filename': output
            },
            'type': 1
        })  # RAS_TO_RAS
        trf.set_fdata(torch.eye(4))
        trf.save()

    if is_file:
        return output
    else:
        return dat, aff
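
A standalone check of the per-axis bounding-box reduction above: move each axis to the front, flatten the rest, and take the first and last index where anything is on.

import torch

mask = torch.zeros(8, 8, 8, dtype=torch.bool)
mask[2:5, 3:6, 1:7] = True
for d in range(mask.dim()):
    n = mask.shape[d]
    idx = mask.movedim(d, 0).reshape(n, -1).any(-1).nonzero(as_tuple=False)
    print(d, idx.min().item(), idx.max().item())   # (2, 4), (3, 5), (1, 6)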
Ejemplo n.º 28
0
def info(inp, meta=None, stat=False):
    """Print information on a volume.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    meta : sequence of str
        List of fields to print.
        By default, a list of common fields is used.
    stat : bool, default=False
        Compute intensity statistics

    """

    meta = meta or []
    metadata = {}
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        if stat:
            inp = (f.fdata(), f.affine)
        else:
            inp = (f.shape, f.affine)
        metadata = f.metadata(meta)
        metadata['dtype'] = f.dtype
    dat, aff = inp
    if not is_file:
        metadata['dtype'] = dat.dtype
    if torch.is_tensor(dat):
        shape = dat.shape
    else:
        shape = dat

    pad = max([0] + [len(m) for m in metadata.keys()])
    if not meta:
        more_fields = ['shape', 'layout', 'filename']
        pad = max(pad, max(len(f) for f in more_fields))
    title = lambda tag: ('{tag:' + str(pad) + 's}').format(tag=tag)

    if not meta:
        if is_file:
            print(f'{title("filename")} : {fname}')
        print(f'{title("shape")} : {tuple(shape)}')
        layout = spatial.affine_to_layout(aff)
        layout = spatial.volume_layout_to_name(layout)
        print(f'{title("layout")} : {layout}')
        center = torch.as_tensor(shape[:3], dtype=torch.float) / 2
        center = spatial.affine_matvec(aff, center)
        print(f'{title("center")} : {tuple(center.tolist())} mm (RAS)')
        if stat and torch.is_tensor(dat):
            chandim = list(range(3, dat.ndim))
            if not chandim:
                vmin = dat.min().tolist()
                vmax = dat.max().tolist()
                vmean = dat.mean().tolist()
            else:
                dat1 = dat.reshape([-1, *dat.shape[3:]])
                vmin = dat1.min(dim=0).values.tolist()
                vmax = dat1.max(dim=0).values.tolist()
                vmean = dat1.mean(dim=0).tolist()
            print(f'{title("min")} : {vmin}')
            print(f'{title("max")} : {vmax}')
            print(f'{title("mean")} : {vmean}')

    for key, value in metadata.items():
        if value is None and not meta:
            continue
        if torch.is_tensor(value):
            value = str(value.numpy())
            value = value.split('\n')
            value = ('\n' + ' ' * (pad + 3)).join(value)
        print(f'{title(key)} : {value}')
Ejemplo n.º 29
0
def _compute_cost(q,
                  grid0,
                  dat_fix,
                  mat_fix,
                  dat,
                  mat,
                  mov,
                  cost_fun,
                  B,
                  mx_int,
                  fwhm,
                  return_res=False):
    """Compute registration cost function.

    Parameters
    ----------
    q : (N, Nq) tensor_like
        Lie algebra of affine registration fit.
    grid0 : (X1, Y1, Z1) tensor_like
        Sub-sampled image data's resampling grid.
    dat_fix : (X1, Y1, Z1) tensor_like
        Fixed image data.
    mat_fix : (4, 4) tensor_like
        Fixed affine matrix.
    dat : [N,] tensor_like
        List of input images.
    mat : [N,] tensor_like
        List of affine matrices.
    mov : [N,] int
        Indices of moving images.
    cost_fun : str
        Cost function to compute (see run_affine_reg).
    B : (Nq, N, N) tensor_like
        Affine basis.
    mx_int : int
        This parameter sets the max intensity in the images, which decides
        how many bins to use in the joint image histograms
        (e.g, mx_int=511 -> H.shape = (512, 512)).
    fwhm : float
        Full-width at half max of Gaussian kernel, for smoothing
        histogram.
    return_res : bool, default=False
        Return registration results for plotting.

    Returns
    -------
    c : float
        Cost of aligning images with current estimate of q. If
        optimiser='powell', array_like, else tensor_like.
    res : tensor_like
        Registration results, for visualisation (only if return_res=True).

    """
    # Init
    device = grid0.device
    q = q.flatten()
    was_numpy = False
    if isinstance(q, np.ndarray):
        was_numpy = True
        q = torch.from_numpy(q).to(device)  # To torch tensor
    dm_fix = dat_fix.shape
    Nq = B.shape[0]
    N = torch.tensor(len(dat), device=device,
                     dtype=torch.float32)  # For modulating NJTV cost

    if cost_fun in _costs_edge:
        jtv = dat_fix.clone()
        if cost_fun == 'njtv':
            njtv = -dat_fix.sqrt()

    for i, m in enumerate(mov):  # Loop over moving images
        # Get affine matrix
        mat_a = expm(q[torch.arange(i * Nq, i * Nq + Nq)], B)
        # Compose matrices
        M = torch.linalg.solve(mat[m], mat_a.mm(mat_fix)).type(
            torch.float32)  # mat_mov\(mat_a @ mat_fix)
        # Transform fixed grid
        grid = affine_matvec(M, grid0)
        # Resample to fixed grid
        dat_new = grid_pull(dat[m][None, None, ...],
                            grid[None, ...],
                            bound='dft',
                            extrapolate=False,
                            interpolation=1)[0, 0, ...]
        if cost_fun in _costs_edge:
            jtv += dat_new
            if cost_fun == 'njtv':
                njtv -= dat_new.sqrt()

    # Compute the cost function
    res = None
    if cost_fun in _costs_hist:
        # Histogram based costs
        # ----------
        # Compute joint histogram
        # OBS: This function expects both images to have the same max and min intensities,
        # this is ensured by the _data_loader() function.
        H = _hist_2d(dat_fix, dat_new, mx_int, fwhm)
        res = H

        # Get probabilities
        pxy = H / H.sum()
        px = pxy.sum(dim=0, keepdim=True)
        py = pxy.sum(dim=1, keepdim=True)

        # Compute cost
        if cost_fun == 'mi':
            # Mutual information
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            c = -mi
        elif cost_fun == 'ecc':
            # Entropy Correlation Coefficient
            # Maes, Collignon, Vandermeulen, Marchal & Suetens (1997).
            # "Multimodality image registration by maximisation of mutual
            # information". IEEE Transactions on Medical Imaging 16(2):187-198
            mi = torch.sum(pxy * torch.log2(pxy / py.mm(px)))
            ecc = -2 * mi / (torch.sum(px * px.log2()) +
                             torch.sum(py * py.log2()))
            c = -ecc
        elif cost_fun == 'nmi':
            # Normalised Mutual Information
            # Studholme,  Hill & Hawkes (1998).
            # "A normalized entropy measure of 3-D medical image alignment".
            # in Proc. Medical Imaging 1998, vol. 3338, San Diego, CA, pp. 132-143.
            nmi = (torch.sum(px * px.log2()) +
                   torch.sum(py * py.log2())) / torch.sum(pxy * pxy.log2())
            c = -nmi
        elif cost_fun == 'ncc':
            # Normalised Cross Correlation
            i = torch.arange(1,
                             pxy.shape[0] + 1,
                             device=device,
                             dtype=torch.float32)
            j = torch.arange(1,
                             pxy.shape[1] + 1,
                             device=device,
                             dtype=torch.float32)
            m1 = torch.sum(py * i[..., None])
            m2 = torch.sum(px * j[None, ...])
            sig1 = torch.sqrt(torch.sum(py[..., 0] * (i - m1)**2))
            sig2 = torch.sqrt(torch.sum(px[0, ...] * (j - m2)**2))
            i, j = torch.meshgrid(i - m1, j - m2, indexing='ij')
            ncc = torch.sum(torch.sum(pxy * i * j)) / (sig1 * sig2)
            c = -ncc
    elif cost_fun in _costs_edge:
        # Normalised Joint Total Variation
        # M Brudfors, Y Balbastre, J Ashburner (2020).
        # "Groupwise Multimodal Image Registration using Joint Total Variation".
        # in MIUA 2020.
        jtv.sqrt_()
        if cost_fun == 'njtv':
            njtv += torch.sqrt(N) * jtv
            res = njtv
            c = torch.sum(njtv)
        else:
            res = jtv
            c = torch.sum(jtv)

    # _ = show_slices(res, fig_num=1, cmap='coolwarm')  # Can be uncommented for testing

    if was_numpy:
        # Back to numpy array
        c = c.cpu().numpy()

    if return_res:
        return c, res
    else:
        return c
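
The 'mi' branch above in isolation, with an explicit epsilon instead of relying on histogram smoothing to keep every bin positive (a hedged sketch, not the library's loss):

import torch

def mutual_info(H, eps=1e-12):
    """Mutual information (bits) from a joint histogram H of shape (nx, ny)."""
    pxy = H / H.sum()
    px = pxy.sum(dim=0, keepdim=True)    # (1, ny)
    py = pxy.sum(dim=1, keepdim=True)    # (nx, 1)
    return torch.sum(pxy * torch.log2((pxy + eps) / (py @ px + eps)))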
Ejemplo n.º 30
0
def affine(source,
           target,
           group='SE',
           loss=None,
           pull=None,
           preproc=True,
           max_iter=1000,
           device=None,
           origin='center',
           init=None,
           lr=0.1,
           scheduler=ReduceLROnPlateau):
    """Affine registration
    
    Note
    ----
    .. Tensors must have shape (batch, channel, *spatial)
    .. Composite losses (e.g., computed on both intensity and categorical
       images) can be obtained by stacking all types of inputs across 
       the channel dimension. The loss function is then responsible 
       for unstacking the tensor and computing the appropriate 
       losses. The drawback of this approach is that all inputs
       must share the same lattice and orientation matrix, as 
       well as the same interpolation order. The advantage is that 
       it simplifies the signature of this function.

    Parameters
    ----------
    source : tensor or (tensor, affine)
    target : tensor or (tensor, affine)
    group : {'T', 'SO', 'SE', 'CSO', 'GL+', 'Aff+'}, default='SE'
    loss : Loss, default=MutualInfoLoss()
    pull : dict
        interpolation : int, default=1
        bound : bound_like, default='dct2'
        extrapolate : bool, default=False
    preproc : bool, default=True
    max_iter : int, default=1000
    device : device, optional
    origin : {'native', 'center'}, default='center'
    init : tensor_like, default=0
    lr : float, default=0.1
    scheduler : Scheduler, default=ReduceLROnPlateau

    Returns
    -------
    q : tensor
        Parameters
    aff : (D+1, D+1) tensor
        Affine transformation matrix.
        The source affine matrix can be "corrected" by left-multiplying
        it with `aff`.
    moved : tensor
        Source image moved to target space.


    """
    pull = pull or dict()
    pull['interpolation'] = pull.get('interpolation', 'linear')
    pull['bound'] = pull.get('bound', 'dct2')
    pull['extrapolate'] = pull.get('extrapolate', False)
    pull_opt = pull

    # prepare all data tensors
    ((source, source_aff), (target, target_aff)) = prepare([source, target],
                                                           device)
    backend = get_backend(source)
    batch = source.shape[0]
    src_channels = source.shape[1]
    trg_channels = target.shape[1]
    dim = source.dim() - 2

    # Rescale to [0, 1]
    if preproc:
        source = rescale(source)
        target = rescale(target)

    # Shift origin
    if origin == 'center':
        shift = torch.as_tensor(target.shape, **backend) / 2
        shift = -spatial.affine_matvec(target_aff, shift)
        target_aff = target_aff.clone()
        source_aff = source_aff.clone()
        target_aff[..., :-1, -1] += shift
        source_aff[..., :-1, -1] += shift

    # Prepare affine utils + Initialize parameters
    basis = spatial.affine_basis(group, dim, **backend)
    nb_prm = spatial.affine_basis_size(group, dim)
    if init is not None:
        parameters = torch.as_tensor(init, **backend).clone().detach()
        parameters = parameters.reshape([batch, nb_prm])
    else:
        parameters = torch.zeros([batch, nb_prm], **backend)
    parameters = nn.Parameter(parameters, requires_grad=True)
    identity = spatial.identity_grid(target.shape[2:], **backend)

    def pull(q):
        aff = core.linalg.expm(q, basis)
        aff = spatial.affine_matmul(aff, target_aff)
        aff = spatial.affine_lmdiv(source_aff, aff)
        expd = (slice(None), ) + (None, ) * dim + (slice(None), slice(None))
        grid = spatial.affine_matvec(aff[expd], identity)
        moved = spatial.grid_pull(source, grid, **pull_opt)
        return moved

    # Prepare loss and optimizer
    if loss is None:
        loss = nni.MutualInfoLoss()

    optim = torch.optim.Adam([parameters], lr=lr)
    if scheduler is not None:
        scheduler = scheduler(optim)

    # Optim loop
    for n_iter in range(1, max_iter + 1):

        optim.zero_grad(set_to_none=True)
        moved = pull(parameters)
        loss_val = loss(moved, target)
        loss_val.backward()
        optim.step()
        if scheduler is not None and n_iter % 10 == 0:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(loss_val)
            else:
                scheduler.step()

        with torch.no_grad():
            if n_iter % 10 == 0:
                print('{:4d} {:12.6f} | lr={:g}'.format(
                    n_iter, loss_val.item(), optim.param_groups[0]['lr']),
                      end='\r')

    print('')
    with torch.no_grad():
        moved = pull(parameters)
        aff = core.linalg.expm(parameters, basis)
        if origin == 'center':
            aff[..., :-1, -1] -= shift
            shift = core.linalg.matvec(aff[..., :-1, :-1], shift)
            aff[..., :-1, -1] += shift
        aff = aff.inverse()
        aff.requires_grad_(False)
    return parameters, aff, moved
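
Finally, a minimal usage sketch for `affine` above (hypothetical shapes, same assumptions as the `diffeo` sketch earlier):

import torch

src = torch.rand(1, 1, 24, 24, 24)      # (batch, channel, *spatial)
trg = torch.rand(1, 1, 24, 24, 24)
q, aff, moved = affine(src, trg, group='SE', max_iter=200, lr=0.1)
# left-multiply the source orientation matrix by `aff` to correct it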