Example #1
 def __init__(self, dat, affine=None, dim=None, mask=None,
              bound='dct2', extrapolate=False, **backend):
     # I don't call super().__init__() on purpose
     if torch.is_tensor(affine):
         affine = [affine] * len(dat)
     elif affine is None:
         affine = []
         for dat1 in dat:
             dim1 = dim or dat1.dim
             if callable(dim1):
                 dim1 = dim1()
             if hasattr(dat1, 'affine'):
                 aff1 = dat1.affine
             else:
                 shape1 = dat1.shape[-dim1:]
                 aff1 = spatial.affine_default(shape1, **utils.backend(dat1))
             affine.append(aff1)
     affine = py.make_list(affine, len(dat))
     mask = py.make_list(mask, len(dat))
     self._dat = []
     for dat1, aff1, mask1 in zip(dat, affine, mask):
         if not isinstance(dat1, Image):
             dat1 = Image(dat1, aff1, mask=mask1, dim=dim,
                          bound=bound, extrapolate=extrapolate)
         self._dat.append(dat1)
Example #2
    def _map_image(self):
        image = self.image
        self._map = None
        self._fdata = None
        self._affine = None
        self._shape = None
        if isinstance(image, str):
            self._map = io.map(image)
        else:
            self._map = None

        if self._map is None:
            if not isinstance(image, (list, tuple)):
                image = [image]
            if len(image) < 2:
                image = [*image, None, None]
            dat, aff, *_ = image
            dat = torch.as_tensor(dat)
            if aff is None:
                aff = spatial.affine_default(dat.shape[-3:])
            self._fdata = dat
            self._affine = aff
            self._shape = tuple(dat.shape[-3:])
        else:
            self._affine = self._map.affine
            self._shape = tuple(self._map.shape[-3:])
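
A minimal sketch (not part of the class above) of the in-memory input form that `_map_image` handles, assuming nitorch's `spatial` module for the default affine:

import torch
from nitorch import spatial

dat = torch.rand(64, 64, 32)                   # (X, Y, Z) volume
aff = spatial.affine_default(dat.shape[-3:])   # default orientation matrix
image = (dat, aff)   # also accepted: `dat` alone, or a filename string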
Example #3
def _get_default_native(affines, shapes):
    """Get default native space

    Parameters
    ----------
    affines : [sequence of] (4, 4) tensor_like or None
    shapes : [sequence of] (3,) tensor_like

    Returns
    -------
    affines : (N, 4, 4) tensor
    shapes : (N, 3) tensor

    """
    shapes = utils.as_tensor(shapes).reshape([-1, 3])
    shapes = shapes.unbind(dim=0)
    if affines is None:
        affines = [None]
    elif torch.is_tensor(affines):
        affines = affines.reshape([-1, 4, 4])
        affines = affines.unbind(dim=0)
    n = max(len(shapes), len(affines))
    shapes = py.make_list(shapes, n)
    affines = py.make_list(affines, n)

    # default affines
    affines = [spatial.affine_default(shape) if affine is None else affine
               for shape, affine in zip(shapes, affines)]

    affines = utils.as_tensor(affines)
    shapes = utils.as_tensor(shapes)
    affines, shapes = utils.to_max_device(affines, shapes)
    return affines, shapes
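
A hedged usage sketch relying on the `affines is None` path above (the exact stacking behavior of `utils.as_tensor` is assumed):

import torch

affines, shapes = _get_default_native(None, torch.tensor([192, 192, 128]))
# expected: affines.shape == (1, 4, 4) and shapes.shape == (1, 3)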
Example #4
def _reset_origin(dat, mat, interpolation):
    """Reset affine matrix.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like, dtype=float32
        Image data.
    mat : (4, 4) tensor_like, dtype=float64
        Affine matrix.
    interpolation : int, default=1 (linear)
        Interpolation order.

    Returns
    -------
    dat : (X1, Y1, Z1) tensor_like, dtype=float32
        New image data.
    mat : (4, 4) tensor_like, dtype=float64
        New affine matrix.

    """
    device = dat.device
    # Reslice image data to world FOV
    dat, mat = _world_reslice(dat, mat, interpolation=interpolation)
    # Compute new, reset, affine matrix
    vx = voxel_size(mat)
    if mat[:3, :3].det() < 0:
        vx[0] = -vx[0]
    vx = vx.tolist()
    mat = affine_default(dat.shape, vx, dtype=torch.float64, device=device)

    return dat, mat
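
Why negate the first voxel size? The rebuilt default affine is generated from the zooms, and the sign of one zoom controls the handedness, i.e. the sign of det(mat[:3, :3]). A small self-contained check of that rule:

import torch

vx = torch.tensor([1., 1., 1.])
mat = torch.diag(torch.cat([vx, torch.ones(1)]))
assert mat[:3, :3].det() > 0          # right-handed
vx[0] = -vx[0]                        # flip one zoom
mat = torch.diag(torch.cat([vx, torch.ones(1)]))
assert mat[:3, :3].det() < 0          # now left-handed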
Example #5
 def _affine(self):
     """Affine orientation matrix of a series+level"""
     # TODO: I don't know yet how we should use GeoTiff to encode
     #   affine matrices. In the matrix/zooms, their voxels are ordered
     #   as [x, y, z] even though their dimensions in the returned array
     #   are ordered as [Z, Y, X]. If we want to keep the same convention
     #   as nitorch, I need to permute the matrix/zooms.
     if '_affine' not in self._cache:
         with self.tiffobj() as tiff:
             omexml = tiff.ome_metadata
             geotags = tiff.geotiff_metadata or {}
         zooms, units, axes = ome_zooms(omexml, self.series)
         if zooms:
             # convert to mm + drop non-spatial zooms
             units = [parse_unit(u) for u in units]
             zooms = [z * (f / 1e-3) for z, (f, type) in zip(zooms, units)
                      if type in ('m', 'pixel')]
             if 'ModelPixelScaleTag' in geotags:
                 warn("Both OME and GeoTiff pixel scales are present: "
                      "{} vs {}. Using OME."
                      .format(zooms, geotags['ModelPixelScaleTag']))
         elif 'ModelPixelScaleTag' in geotags:
             zooms = geotags['ModelPixelScaleTag']
             axes = 'XYZ'
         else:
             zooms = 1.
             axes = [ax for ax in self._axes if ax in 'XYZ']
         if 'ModelTransformation' in geotags:
             aff = geotags['ModelTransformation']
             aff = torch.as_tensor(aff, dtype=torch.double).reshape(4, 4)
             self._cache['_affine'] = aff
         elif 'ModelTiepointTag' in geotags:
             # copied from tifffile
             sx, sy, sz = py.make_list(zooms, n=3)
             tiepoints = torch.as_tensor(geotags['ModelTiepointTag'])
             affines = []
             for tiepoint in tiepoints:
                 i, j, k, x, y, z = tiepoint
                 affines.append(torch.as_tensor(
                     [[sx,  0.0, 0.0, x - i * sx],
                      [0.0, -sy, 0.0, y + j * sy],
                      [0.0, 0.0, sz,  z - k * sz],
                      [0.0, 0.0, 0.0, 1.0]], dtype=torch.double))
             affines = torch.stack(affines, dim=0)
             if len(tiepoints) == 1:
                 affines = affines[0]
             # cache in all cases, otherwise the final lookup would
             # raise a KeyError when several tiepoints are present
             self._cache['_affine'] = affines
         else:
             zooms = py.make_list(zooms, n=len(axes))
             ax2zoom = {ax: zoom for ax, zoom in zip(axes, zooms)}
             axes = [ax for ax in self._axes if ax in 'XYZ']
             shape = [shp for shp, msk in zip(self._shape, self._spatial)
                      if msk]
             zooms = [ax2zoom.get(ax, 1.) for ax in axes]
             layout = [('R' if ax == 'Z' else 'P' if ax == 'Y' else 'S')
                       for ax in axes]
             aff = affine_default(shape, zooms, layout=''.join(layout))
             self._cache['_affine'] = aff
     return self._cache['_affine']
Example #6
    def __init__(self,
                 moving,
                 fixed,
                 loss,
                 basis='CSO',
                 dim=None,
                 affine_moving=None,
                 affine_fixed=None,
                 verbose=True,
                 plot=False,
                 max_iter=100,
                 bound='dct2',
                 extrapolate=True,
                 **prm):
        if dim is None:
            if affine_fixed is not None:
                dim = affine_fixed.shape[-1] - 1
            elif affine_moving is not None:
                dim = affine_moving.shape[-1] - 1
        dim = dim or fixed.dim() - 1
        self.dim = dim
        self.moving = moving  # moving image
        self.fixed = fixed  # fixed image
        self.loss = loss  # similarity loss (`OptimizationLoss`)
        self.verbose = verbose  # print stuff
        self.plot = plot  # plot stuff
        self.prm = prm  # dict of regularization parameters
        self.bound = bound
        self.extrapolate = extrapolate
        self.basis = basis
        if affine_fixed is None:
            affine_fixed = spatial.affine_default(fixed.shape[-dim:],
                                                  **utils.backend(fixed))
        if affine_moving is None:
            affine_moving = spatial.affine_default(moving.shape[-dim:],
                                                   **utils.backend(moving))
        self.affine_fixed = affine_fixed
        self.affine_moving = affine_moving

        # pretty printing
        self.max_iter = max_iter  # max number of iterations
        self.n_iter = 0  # current iteration
        self.ll_prev = None  # previous loss value
        self.ll_max = 0  # max loss value
        self.id = None
Example #7
 def reset_attributes(self):
     """Reset attributes to their default values (when they have one)."""
     if isinstance(self.volume, io.MappedArray):
         self.affine = self.volume.affine
         if isinstance(self.volume.affine, (tuple, list)):
             self.affine = self.affine[0]
     elif torch.is_tensor(self.volume):
         self.affine = affine_default(self.volume.shape[:3])
     return self
Example #8
 def prepare_one(inp):
     if isinstance(inp, (list, tuple)):
         has_aff = len(inp) > 1
         if has_aff:
             aff0 = inp[1]
         inp, aff = prepare_one(inp[0])
         if has_aff:
             aff = aff0
         return [inp, aff]
     if isinstance(inp, str):
         inp = io.map(inp)[None, None]
     if isinstance(inp, io.MappedArray):
         return inp.fdata(rand=True), inp.affine[None]
     inp = torch.as_tensor(inp)
     aff = spatial.affine_default(inp.shape)[None]
     return [inp, aff]
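
A hedged usage sketch: a bare 3D tensor comes back with a default affine carrying a leading batch dimension (assumes nitorch's `spatial.affine_default`):

import torch

dat, aff = prepare_one(torch.rand(32, 32, 32))
# expected: aff.shape == (1, 4, 4)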
Example #9
 def __init__(self, dat, affine=None, dim=None, **backend):
     """
     Parameters
     ----------
     dat : ([C], *spatial) tensor
     affine : tensor, optional
     dim : int, default=`dat.dim() - 1`
     **backend : dtype, device
     """
     if isinstance(dat, str):
         dat = io.map(dat)[None]
     if isinstance(dat, io.MappedArray):
         if affine is None:
             affine = dat.affine
         dat = dat.fdata(rand=True, **backend)
     self.dim = dim or dat.dim() - 1
     self.dat = dat
     if affine is None:
         affine = spatial.affine_default(self.shape, **utils.backend(dat))
     self.affine = affine.to(utils.backend(self.dat)['device'])
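
A hedged usage sketch; the enclosing class is not shown above, so the name `Image` is assumed here for illustration:

import torch

im = Image(torch.rand(1, 32, 32, 32))   # ([C], *spatial) layout
# expected: im.dim == 3 and im.affine.shape == (4, 4)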
Example #10
    def __init__(self, dat, levels=1, affine=None, dim=None,
                 mask=None, preview=None,
                 bound='dct2', extrapolate=False, method='gauss', **backend):
        """

        Parameters
        ----------
        dat : [list of] (..., *shape) tensor or Image
        levels : int or list[int] or range, default=1
            If an int, it is the number of levels.
            If a range or list, they are the indices of levels to compute.
            `0` is the native resolution, `1` is half of it, etc.
        affine : [list of] tensor, optional
        dim : int, optional
        mask : [list of] tensor, optional
        preview : [list of] tensor, optional
        bound : str, default='dct2'
        extrapolate : bool, default=False
        method : {'gauss', 'average', 'median', 'stride'}, default='gauss'
        """
        # I don't call super().__init__() on purpose
        self.method = method
        
        if isinstance(dat, Image):
            if affine is None:
                affine = dat.affine
            dim = dat.dim
            mask = dat.mask
            preview = dat._preview
            bound = dat.bound
            extrapolate = dat.extrapolate
            dat = dat.dat
        if isinstance(dat, str):
            dat = io.map(dat)
        if isinstance(dat, io.MappedArray):
            if affine is None:
                affine = dat.affine
            dat = dat.fdata(rand=True, **backend)[None]

        if isinstance(levels, int):
            levels = range(levels)
        if isinstance(levels, range):
            levels = list(levels)

        if not isinstance(dat, (list, tuple)):
            dim = dim or dat.dim() - 1
            if not levels:
                raise ValueError('levels required to compute pyramid')
            if affine is None:
                shape = dat.shape[-dim:]
                affine = spatial.affine_default(shape, **utils.backend(dat[0]))
            dat, mask, preview = self._build_pyramid(
                dat, levels, method, dim, bound, mask, preview)
        dat = list(dat)
        if not mask:
            mask = [None] * len(dat)
        if not preview:
            preview = [None] * len(dat)

        if all(isinstance(d, Image) for d in dat):
            self._dat = dat
            return

        dim = dim or (dat[0].dim() - 1)
        if affine is None:
            shape = dat[0].shape[-dim:]
            affine = spatial.affine_default(shape, **utils.backend(dat[0]))
        if not isinstance(affine, (list, tuple)):
            if not levels:
                raise ValueError('levels required to compute affine pyramid')
            shape = dat[0].shape[-dim:]
            affine = self._build_affine_pyramid(affine, shape, levels, method)
        self._dat = [Image(d, aff, dim=dim,  mask=m, preview=p,
                           bound=bound, extrapolate=extrapolate)
                     for d, m, p, aff in zip(dat, mask, preview, affine)]
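
A short illustration of the `levels` semantics documented above:

# levels=3            -> range(3) -> [0, 1, 2] (native, 1/2, 1/4)
# levels=range(1, 3)  -> [1, 2]   (skip the native resolution)
# levels=[0, 2]       -> explicit level indices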
Example #11
def align_tpm(dat,
              tpm=None,
              weights=None,
              spacing=(8, 4),
              device=None,
              basis='affine',
              joint=False,
              progressive=False,
              bins=256,
              fwhm=None,
              max_iter_gn=100,
              max_iter_em=32,
              max_line_search=6,
              verbose=1):
    """Align a Tissue Probability Map to an image

    Input Parameters
    ----------------
    dat : file(s) or tensor or (tensor, affine)
        Input image(s)
    tpm : file(s) or tensor or (tensor, affine), optional
        Input tissue probability map. Uses SPM's TPM by default.
    weights : file(s) or tensor
        Input mask or weight map
    device : torch.device, optional
        Specify device
    verbose : int, default=1
        0 = Write nothing
        1 = Outer loop
        2 = Line search
        3 = Inner loop

    Option Parameters
    -----------------
    spacing : float(s), default=(8, 4)
        Sampling distance in mm. If multiple values are given, the fit
        is coarse-to-fine. Larger is faster but less accurate.
    basis : {'trans', 'rot', 'rigid', 'sim', 'affine'}, default='affine'
        Transformation model
    joint : bool, default=False
        Estimate a single affine for all images
    progressive : bool, default=False
        Fit parameters progressively (translation, then rigid, then affine)

    Optimization parameters
    -----------------------
    bins : int, default=256
        Number of bins to use to discretize the input image
    fwhm : float, default=bins/64
        Full-width at half-maximum used to smooth the joint histogram
    max_iter_gn : int, default=100
        Maximum number of Gauss-Newton iterations
    max_iter_em : int, default=32
        Maximum number of EM iterations
    max_line_search : int, default=6
        Maximum number of line search steps

    Returns
    -------
    aff : ([B], 4, 4) tensor
        Affine matrix.
        Can be applied to the TPM by `aff \ tpm.affine`
        or to the image by `aff @ dat.affine`.

    """
    # ------------------------------------------------------------------
    #       LOAD DATA
    # ------------------------------------------------------------------
    affine_dat = affine_tpm = None
    if isinstance(dat, (list, tuple)) and torch.is_tensor(dat[0]):
        affine_dat = dat[1] if len(dat) > 1 else None
        dat = dat[0]
    if isinstance(tpm, (list, tuple)) and torch.is_tensor(tpm[0]):
        affine_tpm = tpm[1] if len(tpm) > 1 else None
        tpm = tpm[0]
    backend = get_backend(dat, tpm, device)
    tpm, affine_tpm = get_prior(tpm, affine_tpm, **backend)
    dim = tpm.dim() - 1
    dat, weights, affine_dat = get_data(dat, weights, affine_dat, dim,
                                        **backend)

    if weights is None:
        weights = 1
    weights = weights * torch.isfinite(dat)

    # ------------------------------------------------------------------
    #       DEFAULT ORIENTATION MATRICES
    # ------------------------------------------------------------------
    if affine_tpm is not None:
        affine_tpm = affine_tpm.to(dat.dtype)
    else:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:], **backend)
    if affine_dat is None:
        affine_dat = spatial.affine_default(dat.shape[-dim:], **backend)

    dat = dat.unsqueeze(1)  # [B, 1, *spatial]
    weights = weights.unsqueeze(1)  # [B, 1, *spatial]
    tpm = tpm.unsqueeze(0)  # [1, K, *spatial]

    # ------------------------------------------------------------------
    #       DISCRETIZE
    # ------------------------------------------------------------------
    dat = discretize(dat, nbins=bins, mask=weights)

    # ------------------------------------------------------------------
    #       OPTIONS
    # ------------------------------------------------------------------
    opt = dict(
        basis=basis,
        joint=joint,
        progressive=progressive,
        fwhm=fwhm,
        max_iter_gn=max_iter_gn,
        max_iter_em=max_iter_em,
        max_line_search=max_line_search,
        verbose=verbose,
    )

    # ------------------------------------------------------------------
    #       SPACING
    # ------------------------------------------------------------------
    spacing = py.make_list(spacing) or [0]
    dat0, affine_dat0, weights0 = dat, affine_dat, weights
    vx = spatial.voxel_size(affine_dat0).tolist()
    prm = None
    for sp in spacing:

        if sp:
            sp = [max(1, int(pymath.floor(sp / vx1))) for vx1 in vx]
            sp = [slice(None, None, sp1) for sp1 in sp]
            affine_dat, _ = spatial.affine_sub(affine_dat0, dat0.shape[-dim:],
                                               tuple(sp))
            dat = dat0[(Ellipsis, *sp)]
            if weights is not None:
                weights = weights0[(Ellipsis, *sp)]

        _, aff, prm = fit_affine_tpm(dat,
                                     tpm,
                                     affine_dat,
                                     affine_tpm,
                                     weights,
                                     **opt,
                                     prm=prm)

    return aff.squeeze()
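
A hedged usage sketch with synthetic tensors (a real call would typically pass file paths or SPM's TPM; all names below are illustrative):

import torch
from nitorch import spatial

dat = torch.rand(64, 64, 48)
aff = spatial.affine_default(dat.shape, voxel_size=[2., 2., 2.])
tpm = torch.rand(6, 32, 32, 32)
tpm = tpm / tpm.sum(0, keepdim=True)          # (K, *spatial) probabilities
aff_tpm = spatial.affine_default(tpm.shape[-3:])
q = align_tpm((dat, aff), tpm=(tpm, aff_tpm), spacing=(8,), basis='rigid')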
Example #12
def fit_affine_tpm(dat,
                   tpm,
                   affine=None,
                   affine_tpm=None,
                   weights=None,
                   basis='affine',
                   fwhm=None,
                   joint=False,
                   prm=None,
                   max_iter_gn=100,
                   max_iter_em=32,
                   max_line_search=6,
                   progressive=False,
                   verbose=1):
    """

    Parameters
    ----------
    dat : (B, J|1, *spatial) tensor
    tpm : (B|1, K, *spatial) tensor
    affine : (4, 4) tensor
    affine_tpm : (4, 4) tensor
    weights : (B, 1, *spatial) tensor
    basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
    fwhm : float, default=J/32
    joint : bool, default=False
    prm : (B, F) tensor, optional
    max_iter_gn : int, default=100
    max_iter_em : int, default=32
    max_line_search : int, default=6
    progressive : bool, default=False
    verbose : int, default=1

    Returns
    -------
    mi : (B,) tensor
    aff : (B, 4, 4) tensor
    prm : (B, F) tensor

    """
    dim = dat.dim() - 2

    # ------------------------------------------------------------------
    #       RECURSIVE PROGRESSIVE FIT
    # ------------------------------------------------------------------
    if progressive:
        nb_se = dim * (dim + 1) // 2
        nb_aff = dim * (dim + 1)
        basis_recursion = {'Aff+': 'CSO', 'CSO': 'SE', 'SE': 'T'}
        basis_nb_feat = {'Aff+': nb_aff, 'CSO': nb_se + 1, 'SE': nb_se}
        basis = convert_basis(basis)
        next_basis = basis_recursion.get(basis, None)
        if next_basis:
            *_, prm = fit_affine_tpm(dat,
                                     tpm,
                                     affine,
                                     affine_tpm,
                                     weights,
                                     basis=next_basis,
                                     fwhm=fwhm,
                                     joint=joint,
                                     prm=prm,
                                     max_iter_gn=max_iter_gn,
                                     max_iter_em=max_iter_em,
                                     max_line_search=max_line_search)
            B = len(dat)
            F = basis_nb_feat[basis]
            prm0 = prm
            prm = prm0.new_zeros([1 if joint else B, F])
            if basis == 'SE':
                prm[:, :dim] = prm0[:, :dim]
            else:
                nb_se = dim * (dim + 1) // 2
                prm[:, :nb_se] = prm0[:, :nb_se]
                if basis == 'Aff+':
                    # use a (B, 1) slice so it broadcasts over `dim` scales
                    prm[:, nb_se:nb_se + dim] = prm0[:, nb_se:nb_se+1] * (dim**(-0.5))

    basis_name = basis

    # ------------------------------------------------------------------
    #       PREPARE
    # ------------------------------------------------------------------

    B = len(dat)
    if affine is None:
        affine = spatial.affine_default(dat.shape[-dim:])
    if affine_tpm is None:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:])
    affine = affine.to(**utils.backend(tpm))
    affine_tpm = affine_tpm.to(**utils.backend(tpm))
    shape = dat.shape[-dim:]

    tpm = tpm.to(dat.device)
    basis = make_basis(basis, dim, **utils.backend(tpm))
    F = len(basis)

    if prm is None:
        prm = tpm.new_zeros([1 if joint else B, F])
    aff, gaff = linalg._expm(prm, basis, grad_X=True)

    em_opt = dict(fwhm=fwhm,
                  max_iter=max_iter_em,
                  weights=weights,
                  verbose=verbose - 2)
    drv_opt = dict(weights=weights)
    pull_opt = dict(bound='replicate', extrapolate=True)

    # ------------------------------------------------------------------
    #       OPTIMIZE
    # ------------------------------------------------------------------
    prior = None
    mi = torch.as_tensor(-float('inf'))
    delta = torch.zeros_like(prm)
    for n_iter in range(max_iter_gn):

        # --------------------------------------------------------------
        #       LINE SEARCH
        # --------------------------------------------------------------
        prior0, prm0, mi0 = prior, prm, mi
        armijo = 1
        success = False
        for n_ls in range(max_line_search):

            # --- take a step ------------------------------------------
            prm = prm0 - armijo * delta

            # --- build transformation field ---------------------------
            aff, gaff = linalg._expm(prm, basis, grad_X=True)
            phi = lmdiv(affine_tpm, mm(aff, affine))
            phi = spatial.affine_grid(phi, shape)

            # --- warp TPM ---------------------------------------------
            mov = spatial.grid_pull(tpm, phi, **pull_opt)

            # --- mutual info ------------------------------------------
            mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
            mi = mi / Nm

            success = mi.sum() > mi0.sum()
            if verbose >= 2:
                end = '\n' if verbose >= 3 else '\r'
                happy = ':D' if success else ':('
                print(f'(search) | {n_ls:02d} | {mi.mean():12.6g} | {happy}',
                      end=end)
            if success:
                break
            armijo *= 0.5
        # if verbose == 2:
        #     print('')

        # --------------------------------------------------------------
        #       DID IT WORK?
        # --------------------------------------------------------------

        if not success:
            prior, prm, mi = prior0, prm0, mi0
            break

        # DEBUG
        # plot_registration(dat, mov, f'{basis_name} | {n_iter}')

        space = ' ' * max(0, 6 - len(basis_name))
        if verbose >= 1:
            end = '\n' if verbose >= 2 else '\r'
            print(
                f'({basis_name[:6]}){space} | {n_iter:02d} | {mi.mean():12.6g}',
                end=end)

        if mi.mean() - mi0.mean() < 1e-5:
            break

        # --------------------------------------------------------------
        #       GAUSS-NEWTON
        # --------------------------------------------------------------

        # --- derivatives ----------------------------------------------
        g, h = derivatives_intensity(mov, dat, prior, **drv_opt)

        # --- chain rule -----------------------------------------------
        gmov = spatial.grid_grad(tpm, phi, **pull_opt)
        if joint and len(mov) == 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)
        else:
            gmov = gmov.expand([B, *gmov.shape[1:]])
        gaff = lmdiv(affine_tpm, mm(gaff, affine))
        g, h = chain_rule(g, h, gmov, gaff, maj=False)
        del gmov

        if joint and len(g) > 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)

        # --- Gauss-Newton ---------------------------------------------
        delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)

    if verbose == 1:
        print('')
    return mi, aff, prm
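
The loop above pairs Gauss-Newton updates with an Armijo-style backtracking line search. A generic sketch of that pattern (plain Python, not nitorch code):

def backtrack(objective, x0, delta, max_ls=6):
    """Halve the step until the objective improves (maximization)."""
    f0, step = objective(x0), 1.0
    for _ in range(max_ls):
        x = x0 - step * delta
        f = objective(x)
        if f > f0:                 # improved, accept the step
            return x, f, True
        step *= 0.5
    return x0, f0, False           # no improvement: keep previous point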
Example #13
def vexp(inp,
         type='displacement',
         unit='voxel',
         inverse=False,
         bound='dft',
         steps=8,
         device=None,
         output=None):
    """Exponentiate a stationary velocity fields.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    type : {'displacement', 'transformation'}, default='displacement'
        Whether to return a displacement field (coord-to-shift) or a
        transformation field (coord-to-coord).
    unit : {'voxel', 'mm'}, default='voxel'
        Whether to return displacement/coordinates in voxel or in mm.
        If mm, the input orientation matrix is used to convert voxels to mm.
    inverse : bool, default=False
        Whether to return the inverse field.
    bound : str, default='dft'
        Boundary conditions.
    steps : int, default=8
        Number of scaling and squaring steps.
    device : str, optional
        Device to use.
    output : str, optional
        Output filename(s).
        If the input is not a path, the output data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.vexp{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the output tensor and orientation matrix are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = None

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(device=device), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.vexp{ext}'
        dir, base, ext = py.fileparts(fname)
    else:
        if torch.is_tensor(inp):
            inp = (inp.clone(), spatial.affine_default(shape=inp.shape[:3]))
    dat, aff = inp
    dat = dat.to(device=device)
    aff = aff.to(device=device)

    # exponentiate
    dat = spatial.exp(dat[None],
                      inverse=inverse,
                      steps=steps,
                      bound=bound,
                      inplace=True,
                      displacement=(type.lower()[0] == 'd'))[0]
    if unit == 'mm':
        # if type.lower()[0] == 'd':
        #     vx = spatial.voxel_size(aff)
        #     dat *= vx
        # else:
        dat = spatial.affine_matvec(aff, dat)

    if output:
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff.cpu())
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff.cpu())

    if is_file:
        return output
    else:
        return dat, aff
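
A hedged usage sketch: exponentiate an in-memory velocity field, so nothing is written to disk (assumes nitorch's `spatial` module):

import torch
from nitorch import spatial

vel = 0.1 * torch.randn(32, 32, 32, 3)        # (*spatial, 3) velocity
aff = spatial.affine_default(vel.shape[:3])
disp, aff = vexp((vel, aff), type='displacement', unit='voxel')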
Example #14
 def affine(self, value):
     if value is None:
         value = spatial.affine_default(self.shape)
     self._affine = value
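
A hedged sketch of the property pair such a setter usually belongs to (the getter and class scaffolding are assumed, not shown above):

import torch
from nitorch import spatial

class Volume:
    def __init__(self, dat):
        self.dat = dat
        self.affine = None        # routes through the setter -> default

    @property
    def shape(self):
        return tuple(self.dat.shape[:3])

    @property
    def affine(self):
        return self._affine

    @affine.setter
    def affine(self, value):
        if value is None:
            value = spatial.affine_default(self.shape)
        self._affine = value

vol = Volume(torch.rand(32, 32, 32))
# vol.affine is now spatial.affine_default((32, 32, 32))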
Example #15
def get_orthogonal_oriented_slices(image, index=None, affine=None,
                                   space=None, bbox=None, interpolation=1,
                                   transpose_sagittal=False, return_index=False,
                                   return_mat=False):
    """Sample orthogonal slices in a RAS system

    Parameters
    ----------
    image : (..., *shape3) tensor
        Input image
    index : (3,) sequence or tensor, default=shape//2
        Coordinate (in voxel) of the slice to extract
    affine : (4, 4) tensor, optional
        Orientation matrix of the image
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        Whether to transpose the sagittal slice.
    return_index : bool, default=False
        Also return the slice index in each direction.
    return_mat : bool, default=False
        Also return the orientation matrix of each slice.

    Returns
    -------
    slice : tuple of (..., *shape2) tensor
        Slices in the visualisation space.
    index, mat : tuple, optional
        Additional per-slice outputs, returned when `return_index`
        and/or `return_mat` are set.

    """
    shape = image.shape[-3:]
    if affine is None:
        affine = spatial.affine_default(shape)
    affine = torch.as_tensor(affine)

    # compute default space (mn/mx are in voxels)
    affines = affine.reshape([-1, 4, 4])
    shapes = [shape] * len(affines)
    space, mn, mx = _get_default_space(affines, shapes, space, bbox)
    voxel_size = spatial.voxel_size(space)
    mn *= voxel_size
    mx *= voxel_size

    prm = {
        'index': index,
        'affine': affine,
        'space': space,
        'bbox': [mn, mx],
        'interpolation': interpolation,
        'transpose_sagittal': transpose_sagittal,
        'return_index': return_index,
        'return_mat': return_mat
    }
    slices = tuple(get_oriented_slice(image, dim=d, **prm) for d in (-1, -2, -3))
    if return_index and return_mat:
        return (tuple(sl[0] for sl in slices),
                tuple(sl[1] for sl in slices),
                tuple(sl[2] for sl in slices))
    elif return_index or return_mat:
        return (tuple(sl[0] for sl in slices),
                tuple(sl[1] for sl in slices))
    else:
        return slices
Example #16
def orient(inp,
           layout=None,
           voxel_size=None,
           center=None,
           like=None,
           output=None):
    """Overwrite the orientation matrix

    Parameters
    ----------
    inp : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    layout : str or layout-like, default=None (= preserve)
        Target orientation.
    voxel_size : [sequence of] float, default=None (= preserve)
        Target voxel size.
    center : [sequence of] float, default=None (= preserve)
        World coordinate of the center of the field of view.
    like : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    output : str, optional
        Output filename.
        If the input is not a path, the reoriented data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{layout}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tuple, tensor)
        If the input is a path, the output path is returned.
        Else, the new shape and orientation matrix are returned.

    """
    dir = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        dim = f.affine.shape[-1] - 1
        inp = (f.shape[:dim], f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{layout}{ext}'
        dir, base, ext = py.fileparts(fname)

    like_is_file = isinstance(like, str) and like
    if like_is_file:
        f = io.volumes.map(like)
        dim = f.affine.shape[-1] - 1
        like = (f.shape[:dim], f.affine)

    shape, aff0 = inp
    dim = aff0.shape[-1] - 1
    if like:
        shape_like, aff_like = like
    else:
        shape_like, aff_like = (shape, aff0)

    if voxel_size in (None, 'like') or len(voxel_size) == 0:
        voxel_size = spatial.voxel_size(aff_like)
    elif voxel_size == 'self':
        voxel_size = spatial.voxel_size(aff0)
    voxel_size = utils.make_vector(voxel_size, dim)

    if not layout or layout == 'like':
        layout = spatial.affine_to_layout(aff_like)
    elif layout == 'self':
        layout = spatial.affine_to_layout(aff0)
    layout = spatial.volume_layout(layout)

    if center in (None, 'like') or len(center) == 0:
        center = torch.as_tensor(shape_like, dtype=torch.float) * 0.5
        center = spatial.affine_matvec(aff_like, center)
    elif center == 'self':
        center = torch.as_tensor(shape, dtype=torch.float) * 0.5
        center = spatial.affine_matvec(aff0, center)

    center = utils.make_vector(center, dim)

    aff = spatial.affine_default(shape,
                                 voxel_size=voxel_size,
                                 layout=layout,
                                 center=center,
                                 dtype=torch.double)

    if output:
        dat = io.volumes.load(fname, numpy=True)
        layout = spatial.volume_layout_to_name(layout)
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep,
                                   layout=layout)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep, layout=layout)
            io.volumes.save(dat, output, affine=aff)

    if is_file:
        return output
    else:
        return shape, aff
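
A hedged usage sketch: rebuild a clean affine for an in-memory header without writing a file (assumes nitorch's `spatial` module):

from nitorch import spatial

shape = (128, 128, 96)
aff0 = spatial.affine_default(shape, voxel_size=[1., 1., 1.5])
shape, aff = orient((shape, aff0), layout='RAS')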