def __init__(self, dat, affine=None, dim=None, mask=None, bound='dct2',
             extrapolate=False, **backend):
    """Wrap a sequence of images, converting each entry to an `Image`.

    Parameters
    ----------
    dat : sequence of tensor or Image
        Input images. Entries that already are `Image` objects are kept
        as-is.
    affine : tensor or sequence of tensor, optional
        Orientation matrix (or matrices). A single tensor is shared by all
        images. If None, each image's own `affine` attribute is used when
        it exists; otherwise a default matrix is built from its last
        `dim` dimensions.
    dim : int, optional
        Number of spatial dimensions. If None, it is read from each image
        (`img.dim`, called if it is a method).
    mask : [sequence of] tensor, optional
        Mask(s) forwarded to `Image`.
    bound : str, default='dct2'
        Boundary condition forwarded to `Image`.
    extrapolate : bool, default=False
        Forwarded to `Image`.
    **backend
        Accepted for interface compatibility; not used by this method.
    """
    # super().__init__() is intentionally NOT called.
    nb_images = len(dat)

    if torch.is_tensor(affine):
        affine = [affine] * nb_images
    elif affine is None:
        affine = []
        for img in dat:
            # `dim` is a method on tensors but may be an attribute on
            # other containers -- support both.
            # NOTE(review): for a raw tensor, `img.dim()` counts *all*
            # dimensions (including any channel dimension) -- confirm.
            ndim = dim or img.dim
            if callable(ndim):
                ndim = ndim()
            if hasattr(img, 'affine'):
                affine.append(img.affine)
            else:
                spatial_shape = img.shape[-ndim:]
                affine.append(spatial.affine_default(
                    spatial_shape, **utils.backend(img)))

    affine = py.make_list(affine, nb_images)
    mask = py.make_list(mask, nb_images)

    self._dat = [
        img if isinstance(img, Image)
        else Image(img, aff, mask=msk, dim=dim, bound=bound,
                   extrapolate=extrapolate)
        for img, aff, msk in zip(dat, affine, mask)
    ]
def _map_image(self):
    """Resolve `self.image` into a mapped file or in-memory data.

    After the call:

    * if `self.image` is a path, `self._map` is the mapped file and
      `self._affine` / `self._shape` are read from it (`self._fdata` is
      None);
    * otherwise `self.image` is a tensor-like or a `(dat, [aff, ...])`
      sequence: `self._fdata` is `dat` as a tensor, `self._affine` is
      `aff` (or a default matrix if missing/None) and `self._map` is None.

    In both cases `self._shape` holds the last three dimensions.
    """
    img = self.image
    self._map = None
    self._fdata = None
    self._affine = None
    self._shape = None

    if isinstance(img, str):
        self._map = io.map(img)

    if self._map is not None:
        self._affine = self._map.affine
        self._shape = tuple(self._map.shape[-3:])
        return

    # Normalize to a list `[dat, aff, ...]`, padding a missing affine.
    parts = list(img) if isinstance(img, (list, tuple)) else [img]
    if len(parts) < 2:
        parts += [None, None]
    dat, aff = parts[0], parts[1]

    dat = torch.as_tensor(dat)
    spatial_shape = dat.shape[-3:]
    if aff is None:
        aff = spatial.affine_default(spatial_shape)
    self._fdata = dat
    self._affine = aff
    self._shape = tuple(spatial_shape)
def _get_default_native(affines, shapes):
    """Get default native space

    Parameters
    ----------
    affines : [sequence of] (4, 4) tensor_like or None
        Orientation matrices. Missing (None) matrices are replaced by the
        default matrix of the corresponding shape.
    shapes : [sequence of] (3,) tensor_like
        Spatial shapes.

    Returns
    -------
    affines : (N, 4, 4) tensor
    shapes : (N, 3) tensor
        N is the largest of the number of affines and the number of
        shapes; the shorter input is broadcast with `py.make_list`.
    """
    shapes = utils.as_tensor(shapes).reshape([-1, 3]).unbind(dim=0)
    if affines is None:
        # A single missing affine: one default matrix per shape.
        affines = [None]
    elif not isinstance(affines, (list, tuple)):
        # Array-like input (tensor, numpy, ...): one matrix per leading
        # index. Previously only torch tensors were handled here.
        affines = torch.as_tensor(affines).reshape([-1, 4, 4]).unbind(dim=0)

    nb = max(len(shapes), len(affines))
    shapes = py.make_list(shapes, nb)
    affines = py.make_list(affines, nb)

    # default affines
    affines = [spatial.affine_default(shape) if affine is None else affine
               for shape, affine in zip(shapes, affines)]
    affines = utils.as_tensor(affines)
    shapes = utils.as_tensor(shapes)
    affines, shapes = utils.to_max_device(affines, shapes)
    return affines, shapes
def _reset_origin(dat, mat, interpolation=1):
    """Reset affine matrix.

    The image is first resliced so that it covers its world field of view
    (via `_world_reslice`). A fresh default affine with the same voxel size
    is then built for the resliced data.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like, dtype=float32
        Image data.
    mat : (4, 4) tensor_like, dtype=float64
        Affine matrix.
    interpolation : int, default=1 (linear)
        Interpolation order.

    Returns
    -------
    dat : (X1, Y1, Z1) tensor_like, dtype=float32
        New image data.
    mat : (4, 4) tensor_like, dtype=float64
        New affine matrix.
    """
    device = dat.device
    # Reslice image data to world FOV
    dat, mat = _world_reslice(dat, mat, interpolation=interpolation)
    # Compute new, reset, affine matrix. When the linear part has a
    # negative determinant, the first voxel size is negated so that the
    # new default matrix is (presumably) mirrored along that axis too.
    vx = voxel_size(mat)
    if mat[:3, :3].det() < 0:
        vx[0] = -vx[0]
    mat = affine_default(dat.shape, vx.tolist(), dtype=torch.float64,
                         device=device)
    return dat, mat
def _affine(self):
    """Affine orientation matrix of a series+level

    The matrix is computed once and cached in `self._cache['_affine']`.
    Sources, by priority:

    1. GeoTiff `ModelTransformation` (used as-is, reshaped to (4, 4));
    2. GeoTiff `ModelTiepointTag` combined with pixel sizes: one matrix
       per tiepoint, stacked into (N, 4, 4), or (4, 4) when N == 1;
    3. otherwise a default matrix built from the pixel sizes (OME metadata
       first, then GeoTiff `ModelPixelScaleTag`, else unit sizes).
    """
    # TODO: I don't know yet how we should use GeoTiff to encode
    # affine matrices. In the matrix/zooms, their voxels are ordered
    # as [x, y, z] even though their dimensions in the returned array
    # are ordered as [Z, Y, X]. If we want to keep the same convention
    # as nitorch, I need to permute the matrix/zooms.
    if '_affine' not in self._cache:
        with self.tiffobj() as tiff:
            omexml = tiff.ome_metadata
            geotags = tiff.geotiff_metadata or {}
            zooms, units, axes = ome_zooms(omexml, self.series)

        if zooms:
            # Convert to mm and drop non-spatial zooms. `axes` is filtered
            # alongside `zooms` so that both stay aligned when they are
            # zipped together below.
            spatial_zooms, spatial_axes = [], []
            for zoom, unit, ax in zip(zooms, units, axes):
                factor, unit_type = parse_unit(unit)
                if unit_type in ('m', 'pixel'):
                    spatial_zooms.append(zoom * (factor / 1e-3))
                    spatial_axes.append(ax)
            zooms, axes = spatial_zooms, spatial_axes
            if 'ModelPixelScaleTag' in geotags:
                warn("Both OME and GeoTiff pixel scales are present: "
                     "{} vs {}. Using OME."
                     .format(zooms, geotags['ModelPixelScaleTag']))
        elif 'ModelPixelScaleTag' in geotags:
            zooms = geotags['ModelPixelScaleTag']
            axes = 'XYZ'
        else:
            zooms = 1.
            axes = [ax for ax in self._axes if ax in 'XYZ']

        if 'ModelTransformation' in geotags:
            aff = geotags['ModelTransformation']
            aff = torch.as_tensor(aff, dtype=torch.double).reshape(4, 4)
        elif 'ModelTiepointTag' in geotags:
            # copied from tifffile
            sx, sy, sz = py.make_list(zooms, n=3)
            # Tiepoints are groups of 6 values (i, j, k, x, y, z); the tag
            # may come flat (6*N,) or squeezed (6,), so always view it as
            # (N, 6) rows before iterating.
            tiepoints = torch.as_tensor(geotags['ModelTiepointTag'],
                                        dtype=torch.double).reshape(-1, 6)
            affines = []
            for i, j, k, x, y, z in tiepoints.tolist():
                affines.append(torch.as_tensor(
                    [[sx, 0.0, 0.0, x - i * sx],
                     [0.0, -sy, 0.0, y + j * sy],
                     [0.0, 0.0, sz, z - k * sz],
                     [0.0, 0.0, 0.0, 1.0]],
                    dtype=torch.double))
            aff = torch.stack(affines, dim=0)
            if len(affines) == 1:
                aff = aff[0]
        else:
            zooms = py.make_list(zooms, n=len(axes))
            ax2zoom = dict(zip(axes, zooms))
            axes = [ax for ax in self._axes if ax in 'XYZ']
            shape = [shp for shp, msk in zip(self._shape, self._spatial)
                     if msk]
            zooms = [ax2zoom.get(ax, 1.) for ax in axes]
            layout = ''.join('R' if ax == 'Z' else 'P' if ax == 'Y' else 'S'
                             for ax in axes)
            aff = affine_default(shape, zooms, layout=layout)
        self._cache['_affine'] = aff
    return self._cache['_affine']
def __init__(self, moving, fixed, loss, basis='CSO', dim=None,
             affine_moving=None, affine_fixed=None, verbose=True,
             plot=False, max_iter=100, bound='dct2', extrapolate=True,
             **prm):
    """Set up an affine registration problem.

    Parameters
    ----------
    moving : tensor
        Moving image.
    fixed : tensor
        Fixed image.
    loss : OptimizationLoss
        Similarity loss.
    basis : str, default='CSO'
        Affine basis.
    dim : int, optional
        Number of spatial dimensions. Default: inferred from
        `affine_fixed`, then `affine_moving`, then `fixed.dim() - 1`.
    affine_moving, affine_fixed : tensor, optional
        Orientation matrices. Default: built from the spatial shape of the
        corresponding image.
    verbose : bool, default=True
        Print progress.
    plot : bool, default=False
        Plot progress.
    max_iter : int, default=100
        Maximum number of iterations.
    bound : str, default='dct2'
    extrapolate : bool, default=True
    **prm
        Regularization parameters.
    """
    # Number of spatial dimensions: explicit, else from an affine matrix
    # (fixed first, then moving), else from the fixed image.
    if dim is None:
        for aff in (affine_fixed, affine_moving):
            if aff is not None:
                dim = aff.shape[-1] - 1
                break
    dim = dim or fixed.dim() - 1
    self.dim = dim

    # Problem definition
    self.moving = moving
    self.fixed = fixed
    self.loss = loss
    self.basis = basis
    self.prm = prm
    self.bound = bound
    self.extrapolate = extrapolate

    # Display options
    self.verbose = verbose
    self.plot = plot

    # Orientation matrices (defaults built from the spatial shapes)
    if affine_fixed is None:
        affine_fixed = spatial.affine_default(
            fixed.shape[-dim:], **utils.backend(fixed))
    if affine_moving is None:
        affine_moving = spatial.affine_default(
            moving.shape[-dim:], **utils.backend(moving))
    self.affine_fixed = affine_fixed
    self.affine_moving = affine_moving

    # Iteration bookkeeping / pretty printing
    self.max_iter = max_iter
    self.n_iter = 0
    self.ll_prev = None
    self.ll_max = 0
    self.id = None
def reset_attributes(self):
    """Reset attributes to their default values (when they have one).

    Only `affine` has a default here:

    * mapped volume: its affine (the first one if it holds a list/tuple);
    * tensor volume: a default matrix built from its first three
      dimensions;
    * anything else: `affine` is left untouched.

    Returns
    -------
    self
    """
    vol = self.volume
    if torch.is_tensor(vol):
        self.affine = affine_default(vol.shape[:3])
    elif isinstance(vol, io.MappedArray):
        self.affine = vol.affine
        # a mapped volume may carry several matrices: keep the first one
        if isinstance(vol.affine, (tuple, list)):
            self.affine = self.affine[0]
    return self
def prepare_one(inp):
    """Load one input and pair it with an orientation matrix.

    Parameters
    ----------
    inp : str or MappedArray or tensor_like or sequence
        * str: path to a volume; it is mapped and indexed with
          `[None, None]` (two new leading dimensions).
        * MappedArray: its data (`fdata(rand=True)`) and `affine[None]`.
        * tensor_like: converted with `torch.as_tensor`; the default affine
          is built from its *full* shape and given a leading dimension.
          NOTE(review): this includes any batch/channel dimensions --
          confirm this is intended.
        * list/tuple: `inp[0]` is prepared recursively, and `inp[1]`,
          when present, replaces the resulting affine.

    Returns
    -------
    dat, aff
        A list `[dat, aff]`, except for MappedArray inputs, which return
        a tuple.
    """
    if isinstance(inp, (list, tuple)):
        dat, aff = prepare_one(inp[0])
        if len(inp) > 1:
            aff = inp[1]
        return [dat, aff]

    if isinstance(inp, str):
        inp = io.map(inp)[None, None]
    if isinstance(inp, io.MappedArray):
        return inp.fdata(rand=True), inp.affine[None]

    dat = torch.as_tensor(inp)
    return [dat, spatial.affine_default(dat.shape)[None]]
def __init__(self, dat, affine=None, dim=None, **backend):
    """
    Parameters
    ----------
    dat : ([C], *spatial) tensor_like or str or MappedArray
        Image data, path to a volume file, or mapped volume.
    affine : tensor_like, optional
        Orientation matrix. Default: the mapped file's affine if any, else
        a default matrix built from `self.shape`.
    dim : int, default=`dat.dim() - 1`
        Number of spatial dimensions.
    **backend : dtype, device
        Data type and device of the image data. Applied both to data
        loaded from disk and to in-memory data.
    """
    if isinstance(dat, str):
        dat = io.map(dat)[None]
    if isinstance(dat, io.MappedArray):
        if affine is None:
            affine = dat.affine
        dat = dat.fdata(rand=True, **backend)
    else:
        # Honour `backend` for in-memory data too (previously ignored).
        # This is a no-op for a tensor that already matches.
        dat = torch.as_tensor(dat, **backend)
    self.dim = dim or dat.dim() - 1
    self.dat = dat
    if affine is None:
        affine = spatial.affine_default(self.shape, **utils.backend(dat))
    # Accept any tensor-like affine; move it to the data's device (its
    # dtype is kept).
    self.affine = torch.as_tensor(affine).to(utils.backend(self.dat)['device'])
def __init__(self, dat, levels=1, affine=None, dim=None, mask=None,
             preview=None, bound='dct2', extrapolate=False, method='gauss',
             **backend):
    """
    Parameters
    ----------
    dat : [list of] (..., *shape) tensor or Image or str or MappedArray
        Input image. A single image is turned into a pyramid; a list is
        used as-is (one entry per level).
    levels : int or list[int] or range, default=1
        If an int, it is the number of levels (`range(levels)`).
        If a range or list, they are the indices of levels to compute.
        `0` is the native resolution, `1` is half of it, etc.
    affine : [list of] tensor, optional
    dim : int, optional
    mask : [list of] tensor, optional
    preview : [list of] tensor, optional
    bound : str, default='dct2'
    extrapolate : bool, default=False
    method : {'gauss', 'average', 'median', 'stride'}, default='gauss'
    **backend : dtype, device
        Only used when data is loaded from a file / mapped array.
    """
    # I don't call super().__init__() on purpose
    self.method = method

    # --- an Image carries its own options, which take precedence --------
    if isinstance(dat, Image):
        if affine is None:
            affine = dat.affine
        dim = dat.dim
        mask = dat.mask
        preview = dat._preview
        bound = dat.bound
        extrapolate = dat.extrapolate
        dat = dat.dat

    # --- load from disk -------------------------------------------------
    if isinstance(dat, str):
        dat = io.map(dat)
    if isinstance(dat, io.MappedArray):
        if affine is None:
            affine = dat.affine
        # add a leading (presumably channel) dimension
        dat = dat.fdata(rand=True, **backend)[None]

    # --- normalize levels to a list of indices --------------------------
    if isinstance(levels, int):
        levels = range(levels)
    if isinstance(levels, range):
        levels = list(levels)

    # --- single image: compute the data pyramid -------------------------
    if not isinstance(dat, (list, tuple)):
        dim = dim or dat.dim() - 1
        if not levels:
            raise ValueError('levels required to compute pyramid')
        if affine is None:
            shape = dat.shape[-dim:]
            affine = spatial.affine_default(shape, **utils.backend(dat[0]))
        dat, mask, preview = self._build_pyramid(
            dat, levels, method, dim, bound, mask, preview)
    dat = list(dat)
    if not mask:
        mask = [None] * len(dat)
    if not preview:
        preview = [None] * len(dat)

    # --- already a list of Images: nothing left to build ----------------
    if all(isinstance(d, Image) for d in dat):
        self._dat = dat
        return

    dim = dim or (dat[0].dim() - 1)
    if affine is None:
        shape = dat[0].shape[-dim:]
        affine = spatial.affine_default(shape, **utils.backend(dat[0]))
    if not isinstance(affine, (list, tuple)):
        if not levels:
            raise ValueError('levels required to compute affine pyramid')
        # NOTE(review): `dat[0]` is the first *computed* level, which is
        # the native resolution only when `levels[0] == 0` -- confirm this
        # is the shape `_build_affine_pyramid` expects.
        shape = dat[0].shape[-dim:]
        affine = self._build_affine_pyramid(affine, shape, levels, method)
    self._dat = [Image(d, aff, dim=dim, mask=m, preview=p, bound=bound,
                       extrapolate=extrapolate)
                 for d, m, p, aff in zip(dat, mask, preview, affine)]
def align_tpm(dat, tpm=None, weights=None, spacing=(8, 4), device=None,
              basis='affine', joint=False, progressive=False, bins=256,
              fwhm=None, max_iter_gn=100, max_iter_em=32, max_line_search=6,
              verbose=1):
    r"""Align a Tissue Probability Map to an image

    Input Parameters
    ----------------
    dat : file(s) or tensor or (tensor, affine)
        Input image(s)
    tpm : file(s) or tensor or (tensor, affine), optional
        Input tissue probability map. Uses SPM's TPM by default.
    weights : file(s) or tensor
        Input mask or weight map
    device : torch.device, optional
        Specify device
    verbose : int, default=1
        0 = Write nothing
        1 = Outer loop
        2 = Line search
        3 = Inner loop

    Option Parameters
    -----------------
    spacing : float(s), default=(8, 4)
        Sampling distance in mm. If multiple values, coarse to fine fit.
        Larger is faster but less accurate. A value of 0 means that all
        voxels are used.
    basis : {'trans', 'rot', 'rigid', 'sim', 'aff'}, default='affine'
        Transformation model
    joint : bool, default=False
        Estimate a single affine for all images
    progressive : bool, default=False
        Fit parameters progressively (translation then rigid then affine)

    Optimization parameters
    -----------------------
    bins : int, default=256
        Number of bins to use to discretize the input image
    fwhm : float, default=bins/64
        Full-width at half-maximum used to smooth the joint histogram
    max_iter_gn : int, default=100
        Maximum number of Gauss-Newton iterations
    max_iter_em : int, default=32
        Maximum number of EM iterations
    max_line_search : int, default=6
        Maximum number of line search steps

    Returns
    -------
    aff : ([B], 4, 4) tensor
        Affine matrix.
        Can be applied to the TPM by `aff \ tpm.affine`
        or to the image by `aff @ dat.affine`.

    """

    # ------------------------------------------------------------------
    #       LOAD DATA
    # ------------------------------------------------------------------
    affine_dat = affine_tpm = None
    if isinstance(dat, (list, tuple)) and torch.is_tensor(dat[0]):
        affine_dat = dat[1] if len(dat) > 1 else None
        dat = dat[0]
    if isinstance(tpm, (list, tuple)) and torch.is_tensor(tpm[0]):
        affine_tpm = tpm[1] if len(tpm) > 1 else None
        tpm = tpm[0]
    backend = get_backend(dat, tpm, device)
    tpm, affine_tpm = get_prior(tpm, affine_tpm, **backend)
    dim = tpm.dim() - 1
    dat, weights, affine_dat = get_data(dat, weights, affine_dat, dim,
                                        **backend)
    if weights is None:
        weights = 1
    # never let non-finite voxels contribute
    weights = weights * torch.isfinite(dat)

    # ------------------------------------------------------------------
    #       DEFAULT ORIENTATION MATRICES
    # ------------------------------------------------------------------
    if affine_tpm is not None:
        affine_tpm = affine_tpm.to(dat.dtype)
    else:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:], **backend)
    if affine_dat is None:
        affine_dat = spatial.affine_default(dat.shape[-dim:], **backend)

    dat = dat.unsqueeze(1)          # [B, 1, *spatial]
    weights = weights.unsqueeze(1)  # [B, 1, *spatial]
    tpm = tpm.unsqueeze(0)          # [1, K, *spatial]

    # ------------------------------------------------------------------
    #       DISCRETIZE
    # ------------------------------------------------------------------
    dat = discretize(dat, nbins=bins, mask=weights)

    # ------------------------------------------------------------------
    #       OPTIONS
    # ------------------------------------------------------------------
    opt = dict(
        basis=basis,
        joint=joint,
        progressive=progressive,
        fwhm=fwhm,
        max_iter_gn=max_iter_gn,
        max_iter_em=max_iter_em,
        max_line_search=max_line_search,
        verbose=verbose,
    )

    # ------------------------------------------------------------------
    #       SPACING
    # ------------------------------------------------------------------
    spacing = py.make_list(spacing) or [0]
    dat0, affine_dat0, weights0 = dat, affine_dat, weights
    vx = spatial.voxel_size(affine_dat0).tolist()
    prm = None
    for sp in spacing:
        if sp:
            # keep one voxel every `sp` mm along each spatial dimension
            sp = [max(1, int(pymath.floor(sp / vx1))) for vx1 in vx]
            sp = [slice(None, None, sp1) for sp1 in sp]
            affine_dat, _ = spatial.affine_sub(affine_dat0,
                                               dat0.shape[-dim:], tuple(sp))
            dat = dat0[(Ellipsis, *sp)]
            if weights0 is not None:
                weights = weights0[(Ellipsis, *sp)]
        else:
            # Full resolution: restore the original arrays, which a
            # previous (coarser) iteration may have replaced.
            dat, affine_dat, weights = dat0, affine_dat0, weights0
        # parameters of the coarser fit initialize the finer one
        _, aff, prm = fit_affine_tpm(dat, tpm, affine_dat, affine_tpm,
                                     weights, **opt, prm=prm)

    return aff.squeeze()
def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
                   basis='affine', fwhm=None, joint=False, prm=None,
                   max_iter_gn=100, max_iter_em=32, max_line_search=6,
                   progressive=False, verbose=1):
    """Fit an affine transform that aligns a TPM to an image.

    Mutual information (computed by `em_prior`) is maximized with
    Gauss-Newton steps and a backtracking line search.

    Parameters
    ----------
    dat : (B, J|1, *spatial) tensor
    tpm : (B|1, K, *spatial) tensor
    affine : (4, 4) tensor, optional
        Orientation matrix of `dat`. Default: built from its shape.
    affine_tpm : (4, 4) tensor, optional
        Orientation matrix of `tpm`. Default: built from its shape.
    weights : (B, 1, *spatial) tensor
    basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
    fwhm : float, default=J/32
    joint : bool, default=False
        Fit a single set of parameters for all batch elements.
    prm : (B|1, F) tensor, optional
        Initial parameters. Default: zeros.
    max_iter_gn : int, default=100
    max_iter_em : int, default=32
    max_line_search : int, default=6
    progressive : bool, default=False
        Fit the next simpler basis of the chain Aff+ -> CSO -> SE -> T
        first (recursively) and use it to initialize this one.
    verbose : int, default=1
        0: silent, >=1: outer loop, >=2: line search, >=3: inner loop.

    Returns
    -------
    mi : (B,) tensor
    aff : (B, 4, 4) tensor
        Matrix associated with the returned `prm`.
    prm : (B, F) tensor
    """
    dim = dat.dim() - 2

    # ------------------------------------------------------------------
    #       RECURSIVE PROGRESSIVE FIT
    # ------------------------------------------------------------------
    if progressive:
        nb_se = dim * (dim + 1) // 2
        nb_aff = dim * (dim + 1)
        basis_recursion = {'Aff+': 'CSO', 'CSO': 'SE', 'SE': 'T'}
        basis_nb_feat = {'Aff+': nb_aff, 'CSO': nb_se + 1, 'SE': nb_se}
        basis = convert_basis(basis)
        next_basis = basis_recursion.get(basis, None)
        if next_basis:
            # Fit the simpler model first -- itself progressively, so that
            # the whole chain is walked -- forwarding every option
            # (including `verbose`).
            *_, prm = fit_affine_tpm(dat, tpm, affine, affine_tpm, weights,
                                     basis=next_basis, fwhm=fwhm,
                                     joint=joint, prm=prm,
                                     max_iter_gn=max_iter_gn,
                                     max_iter_em=max_iter_em,
                                     max_line_search=max_line_search,
                                     progressive=True, verbose=verbose)
            # Embed the simpler parameters into the current basis.
            B = len(dat)
            F = basis_nb_feat[basis]
            prm0 = prm
            prm = prm0.new_zeros([1 if joint else B, F])
            if basis == 'SE':
                # assumes the first `dim` SE parameters are the T ones
                prm[:, :dim] = prm0[:, :dim]
            else:
                nb_se = dim * (dim + 1) // 2
                prm[:, :nb_se] = prm0[:, :nb_se]
                if basis == 'Aff+':
                    # assumes CSO's extra parameter is an isotropic scale,
                    # spread over the `dim` Aff+ scaling parameters
                    prm[:, nb_se:nb_se + dim] = prm0[:, nb_se] * (dim**(-0.5))

    basis_name = basis

    # ------------------------------------------------------------------
    #       PREPARE
    # ------------------------------------------------------------------
    B = len(dat)
    if affine is None:
        affine = spatial.affine_default(dat.shape[-dim:])
    if affine_tpm is None:
        affine_tpm = spatial.affine_default(tpm.shape[-dim:])
    # Move the TPM next to the data *before* reading its backend, so that
    # both orientation matrices and the basis end up on the same device.
    tpm = tpm.to(dat.device)
    affine = affine.to(**utils.backend(tpm))
    affine_tpm = affine_tpm.to(**utils.backend(tpm))
    shape = dat.shape[-dim:]

    basis = make_basis(basis, dim, **utils.backend(tpm))
    F = len(basis)

    if prm is None:
        prm = tpm.new_zeros([1 if joint else B, F])
    aff, gaff = linalg._expm(prm, basis, grad_X=True)

    em_opt = dict(fwhm=fwhm, max_iter=max_iter_em, weights=weights,
                  verbose=verbose - 2)
    drv_opt = dict(weights=weights)
    pull_opt = dict(bound='replicate', extrapolate=True)

    # ------------------------------------------------------------------
    #       OPTIMIZE
    # ------------------------------------------------------------------
    prior = None
    mi = torch.as_tensor(-float('inf'))
    delta = torch.zeros_like(prm)
    for n_iter in range(max_iter_gn):

        # --------------------------------------------------------------
        #       LINE SEARCH
        # --------------------------------------------------------------
        # `aff` is saved with the rest of the state so that a rejected
        # step can be fully undone.
        prior0, prm0, mi0, aff0 = prior, prm, mi, aff
        armijo = 1
        success = False
        for n_ls in range(max_line_search):

            # --- take a step ------------------------------------------
            prm = prm0 - armijo * delta

            # --- build transformation field ---------------------------
            aff, gaff = linalg._expm(prm, basis, grad_X=True)
            phi = lmdiv(affine_tpm, mm(aff, affine))
            phi = spatial.affine_grid(phi, shape)

            # --- warp TPM ---------------------------------------------
            mov = spatial.grid_pull(tpm, phi, **pull_opt)

            # --- mutual info ------------------------------------------
            mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
            mi = mi / Nm

            success = mi.sum() > mi0.sum()
            if verbose >= 2:
                end = '\n' if verbose >= 3 else '\r'
                happy = ':D' if success else ':('
                print(f'(search) | {n_ls:02d} | {mi.mean():12.6g} | {happy}', end=end)
            if success:
                break
            armijo *= 0.5

        # --------------------------------------------------------------
        #       DID IT WORK?
        # --------------------------------------------------------------
        if not success:
            # Reject the step: restore the last accepted state, including
            # `aff`, so that the returned matrix matches the returned `prm`.
            prior, prm, mi, aff = prior0, prm0, mi0, aff0
            break

        space = ' ' * max(0, 6 - len(basis_name))
        if verbose >= 1:
            end = '\n' if verbose >= 2 else '\r'
            print(f'({basis_name[:6]}){space} | {n_iter:02d} | {mi.mean():12.6g}', end=end)

        if mi.mean() - mi0.mean() < 1e-5:
            break

        # --------------------------------------------------------------
        #       GAUSS-NEWTON
        # --------------------------------------------------------------
        # --- derivatives ----------------------------------------------
        g, h = derivatives_intensity(mov, dat, prior, **drv_opt)

        # --- chain rule -----------------------------------------------
        gmov = spatial.grid_grad(tpm, phi, **pull_opt)
        if joint and len(mov) == 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)
        else:
            gmov = gmov.expand([B, *gmov.shape[1:]])
        gaff = lmdiv(affine_tpm, mm(gaff, affine))
        g, h = chain_rule(g, h, gmov, gaff, maj=False)
        del gmov
        if joint and len(g) > 1:
            g = g.sum(0, keepdim=True)
            h = h.sum(0, keepdim=True)

        # --- Gauss-Newton ---------------------------------------------
        delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)

    if verbose == 1:
        print('')
    return mi, aff, prm
def vexp(inp, type='displacement', unit='voxel', inverse=False,
         bound='dft', steps=8, device=None, output=None):
    """Exponentiate a stationary velocity fields.

    Parameters
    ----------
    inp : str or tensor or (tensor, tensor)
        Either a path to a volume file, a tensor, or a tuple
        `(dat, affine)`, where the first element contains the volume data
        and the second contains the orientation matrix. The input data is
        never modified.
    type : {'displacement', 'transformation'}, default='displacement'
        Whether to return a displacement field (coord-to-shift) or a
        transformation field (coord-to-coord).
    unit : {'voxel', 'mm'}, default='voxel'
        Whether to return displacement/coordinates in voxel or in mm.
        If mm, the input orientation matrix is used to convert voxels to mm:
        coordinates are mapped by the full matrix, displacements (which
        are vectors) by its linear part only.
    inverse : bool, default=False
        Whether to return the inverse field.
    bound : str, default='dft'
        Boundary conditions.
    steps : int, default=8
        Number of scaling and squaring steps.
    device : str, optional
        Device to use.
    output : str, optional
        Output filename(s).
        If the input is not a path, the unstacked data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.vexp{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the output tensor and orientation matrix are returned.
    """
    dirname = ''
    base = ''
    ext = ''
    fname = None

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(device=device), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.vexp{ext}'
        dirname, base, ext = py.fileparts(fname)
    elif torch.is_tensor(inp):
        inp = (inp, spatial.affine_default(shape=inp.shape[:3]))

    dat, aff = inp
    # The exponentiation below works in place: copy any tensor that
    # belongs to the caller (data loaded from a file is already fresh).
    dat = dat.to(device=device, copy=not is_file)
    aff = aff.to(device=device)

    # exponentiate
    displacement = type.lower()[0] == 'd'
    dat = spatial.exp(dat[None], inverse=inverse, steps=steps, bound=bound,
                      inplace=True, displacement=displacement)[0]
    if unit == 'mm':
        # A @ v + t for coordinates; displacements are vectors, so the
        # translation t must not be added to them.
        dat = spatial.affine_matvec(aff, dat)
        if displacement:
            dat = dat - aff[:-1, -1].to(dat)

    if output:
        if is_file:
            output = output.format(dir=dirname or '.', base=base, ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff.cpu())
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff.cpu())

    return output if is_file else (dat, aff)
def affine(self, value):
    """Set the orientation matrix.

    Parameters
    ----------
    value : tensor or None
        New affine matrix. If None, a default matrix is built from
        `self.shape` (previously it was built and then immediately
        overwritten with None).
    """
    if value is None:
        value = spatial.affine_default(self.shape)
    self._affine = value
def get_orthogonal_oriented_slices(image, index=None, affine=None, space=None,
                                   bbox=None, interpolation=1,
                                   transpose_sagittal=False,
                                   return_index=False, return_mat=False):
    """Sample three orthogonal slices in a RAS system.

    Parameters
    ----------
    image : (..., *shape3)
        Input image.
    index : (3,) sequence or tensor, default=shape//2
        Coordinate (in voxel) of the slice to extract.
    affine : (4, 4) tensor, optional
        Orientation matrix of the image. Default: built from `shape3`.
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        Forwarded to `get_oriented_slice`.
    return_index, return_mat : bool, default=False
        Forwarded to `get_oriented_slice`. When set, its extra outputs are
        regrouped across the three slices (see Returns).

    Returns
    -------
    slices : tuple of (..., *shape2) tensor
        Slices in the visualisation space, one per dimension in
        (-1, -2, -3).
    If `return_index` and/or `return_mat` is set, a tuple is returned
    whose k-th element gathers the k-th output of `get_oriented_slice`
    for the three slices.
    """
    spatial_shape = image.shape[-3:]
    if affine is None:
        affine = spatial.affine_default(spatial_shape)
    affine = torch.as_tensor(affine)

    # default visualisation space (mn/mx are returned in voxels)
    all_affines = affine.reshape([-1, 4, 4])
    all_shapes = [spatial_shape] * len(all_affines)
    space, mn, mx = _get_default_space(all_affines, all_shapes, space, bbox)
    vx = spatial.voxel_size(space)
    mn *= vx
    mx *= vx

    options = dict(index=index, affine=affine, space=space, bbox=[mn, mx],
                   interpolation=interpolation,
                   transpose_sagittal=transpose_sagittal,
                   return_index=return_index, return_mat=return_mat)
    slices = tuple(get_oriented_slice(image, dim=d, **options)
                   for d in (-1, -2, -3))

    nb_outputs = 1 + int(bool(return_index)) + int(bool(return_mat))
    if nb_outputs == 1:
        return slices
    return tuple(tuple(sl[k] for sl in slices) for k in range(nb_outputs))
def orient(inp, layout=None, voxel_size=None, center=None, like=None,
           output=None):
    """Overwrite the orientation matrix

    Parameters
    ----------
    inp : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    layout : str or layout-like, default=None (= preserve)
        Target orientation. 'like' / 'self' use the layout of `like` /
        of the input.
    voxel_size : [sequence of] float, default=None (= preserve)
        Target voxel size. 'like' / 'self' (or None / empty) use the voxel
        size of `like` / of the input.
    center : [sequence of] float, default=None (= preserve)
        World coordinate of the center of the field of view. 'like' /
        'self' (or None / empty) use the center of `like` / of the input.
    like : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix. Default: the input itself.
    output : str, optional
        Output filename. Only supported when the input is a path (there
        is no data to write otherwise).
        If the input is a path, the default output filename is
        '{dir}/{base}.{layout}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tuple, tensor)
        If the input is a path, the output path is returned.
        Else, the new shape and orientation matrix are returned.

    Raises
    ------
    ValueError
        If `output` is given while the input is not a path.
    """

    def is_like(value):
        # None, 'like' and empty sequences all mean "same as `like`";
        # scalars (no len) are explicit values.
        if value is None:
            return True
        if isinstance(value, str):
            return value == 'like'
        try:
            return len(value) == 0
        except TypeError:
            return False

    def is_self(value):
        return isinstance(value, str) and value == 'self'

    dirname = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        dim = f.affine.shape[-1] - 1
        inp = (f.shape[:dim], f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{layout}{ext}'
        dirname, base, ext = py.fileparts(fname)
    elif output:
        raise ValueError('Cannot write an output file when the input is '
                         'not a path (there is no data to write).')

    if isinstance(like, str) and like:
        f = io.volumes.map(like)
        dim = f.affine.shape[-1] - 1
        like = (f.shape[:dim], f.affine)

    shape, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape_like, aff_like = like if like else (shape, aff0)

    # --- voxel size ---
    if is_like(voxel_size):
        voxel_size = spatial.voxel_size(aff_like)
    elif is_self(voxel_size):
        voxel_size = spatial.voxel_size(aff0)
    voxel_size = utils.make_vector(voxel_size, dim)

    # --- layout ---
    if not layout or layout == 'like':
        layout = spatial.affine_to_layout(aff_like)
    elif layout == 'self':
        layout = spatial.affine_to_layout(aff0)
    layout = spatial.volume_layout(layout)

    # --- center (previously tested `len(voxel_size)` by mistake) ---
    if is_like(center):
        center = torch.as_tensor(shape_like, dtype=torch.float) * 0.5
        center = spatial.affine_matvec(aff_like, center)
    elif is_self(center):
        center = torch.as_tensor(shape, dtype=torch.float) * 0.5
        center = spatial.affine_matvec(aff0, center)
    center = utils.make_vector(center, dim)

    aff = spatial.affine_default(shape, voxel_size=voxel_size, layout=layout,
                                 center=center, dtype=torch.double)

    if output:
        # only reachable for file inputs (checked above)
        dat = io.volumes.load(fname, numpy=True)
        layout = spatial.volume_layout_to_name(layout)
        output = output.format(dir=dirname or '.', base=base, ext=ext,
                               sep=os.path.sep, layout=layout)
        io.volumes.save(dat, output, like=fname, affine=aff)

    return output if is_file else (shape, aff)