def _orient_mode(value):
    """Classify an option of `orient` (`affine`, `layout`, `voxel_size`, `center`).

    Returns
    -------
    mode : {'like', 'self', 'standard'} or None
        The keyword selected by `value`. None, '' and empty sequences
        select 'like'. None is returned when `value` is an explicit
        value (a layout name, a vector, a scalar, a matrix, ...).
    """
    if value is None:
        return 'like'
    if isinstance(value, str):
        if value in ('', 'like'):
            return 'like'
        return value if value in ('self', 'standard') else None
    try:
        is_empty = len(value) == 0
    except TypeError:
        # Scalars (Python floats, 0-d tensors/arrays) have no length:
        # they are explicit values (e.g. an isotropic voxel size).
        is_empty = False
    return 'like' if is_empty else None


def orient(inp, affine=None, layout=None, voxel_size=None, center=None,
           like=None, output=None, output_transform=None):
    """Overwrite the orientation matrix

    Parameters
    ----------
    inp : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    affine : {'self', 'like', 'standard'} or (4, 4) tensor_like, default='like'
        Target affine matrix. 'standard' is the identity matrix.
        A matrix with dim*(dim+1) elements (e.g. 3x4) is also accepted.
    layout : {'self', 'like', 'standard'} or layout-like, default='like'
        Target orientation. 'standard' is 'RAS'.
    voxel_size : {'self', 'like', 'standard'} or [sequence of] float, default='like'
        Target voxel size. 'standard' is 1.
    center : {'self', 'like', 'standard'} or [sequence of] float, default='like'
        World coordinate of the center of the field of view.
        'standard' is the world origin.
    like : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    output : str, optional
        Output filename. Can only be used when the input is a path
        (otherwise there is no data to write).
        If the input is a path, the default output filename is
        '{dir}/{base}.{layout}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    output_transform : str, optional
        Filename of output transform.
        If the input is not a path, the transform is not written on disk
        by default.
        If the input is a path, the default output filename is
        '{dir}/{base}_to_{layout}.lta', where `dir` and `base`
        are the directory and base name of the input file.

    Returns
    -------
    output : str or (tuple, tensor)
        If the input is a path, the output path is returned.
        Else, the new shape and orientation matrix are returned.

    Raises
    ------
    ValueError
        If `output` is set while the input is not a path, or if the
        target affine does not have dim*(dim+1) or (dim+1)**2 elements.
    """
    dirname = base = ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        dim = f.affine.shape[-1] - 1
        inp = (f.shape[:dim], f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{layout}{ext}'
        if output_transform is None:
            output_transform = '{dir}{sep}{base}_to_{layout}.lta'
        dirname, base, ext = py.fileparts(fname)
    elif output:
        # Only the header (shape, affine) is known: there is no data to save.
        raise ValueError('Cannot write reoriented data when the input is '
                         'not a file. Use `output_transform` to save the '
                         'transform instead.')

    if isinstance(like, str) and like:
        f = io.volumes.map(like)
        dim = f.affine.shape[-1] - 1
        like = (f.shape[:dim], f.affine)

    shape, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape_like, aff_like = like if like else (shape, aff0)

    # --- target voxel size -------------------------------------------
    mode = _orient_mode(voxel_size)
    if mode == 'like':
        voxel_size = spatial.voxel_size(aff_like)
    elif mode == 'self':
        voxel_size = spatial.voxel_size(aff0)
    elif mode == 'standard':
        voxel_size = 1.
    voxel_size = utils.make_vector(voxel_size, dim)

    # --- target layout -----------------------------------------------
    mode = _orient_mode(layout)
    if mode == 'like':
        layout = spatial.affine_to_layout(aff_like)
    elif mode == 'self':
        layout = spatial.affine_to_layout(aff0)
    elif mode == 'standard':
        layout = 'RAS'
    layout = spatial.volume_layout(layout)

    # --- target center of the field of view (world coordinates) -------
    mode = _orient_mode(center)
    if mode == 'like':
        center = (torch.as_tensor(shape_like, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff_like, center)
    elif mode == 'self':
        center = (torch.as_tensor(shape, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff0, center)
    elif mode == 'standard':
        center = 0.
    center = utils.make_vector(center, dim)

    # --- base affine, then modified by voxel size / layout / center ---
    mode = _orient_mode(affine)
    if mode == 'like':
        affine = aff_like
    elif mode == 'self':
        affine = aff0
    elif mode == 'standard':
        affine = torch.eye(dim + 1, dim + 1)
    affine = torch.as_tensor(affine, dtype=torch.float)
    if affine.numel() == dim * (dim + 1):
        affine = spatial.affine_make_rect(affine.reshape(dim, dim + 1))
    elif affine.numel() == (dim + 1) ** 2:
        affine = affine.reshape(dim + 1, dim + 1)
    else:
        raise ValueError(f'Input affine should have {dim*(dim+1)} or '
                         f'{(dim+1)**2} element but got {affine.numel()}.')
    affine = spatial.affine_modify(affine, shape, voxel_size=voxel_size,
                                   layout=layout, center=center)
    affine = affine.double()

    # Both filename templates use the layout *name*. Convert it once so the
    # transform filename is correct even when no volume is written.
    layout_name = (spatial.volume_layout_to_name(layout)
                   if (output or output_transform) else None)

    if output:
        # Only reachable when the input is a file (checked above).
        dat = io.volumes.load(fname, numpy=True)
        output = output.format(dir=dirname or '.', base=base, ext=ext,
                               sep=os.path.sep, layout=layout_name)
        io.volumes.save(dat, output, like=fname, affine=affine)

    if output_transform:
        # Match the (double) target affine so the division is well-typed.
        aff0 = torch.as_tensor(aff0, dtype=affine.dtype, device=affine.device)
        transform = spatial.affine_rmdiv(affine, aff0)
        output_transform = output_transform.format(
            dir=dirname or '.', base=base, sep=os.path.sep, layout=layout_name)
        io.transforms.savef(transform.cpu(), output_transform, type=2)

    return output if is_file else (shape, affine)
def forward(self, grid, **overload):
    """Least-squares affine fit of a displacement field.

    Parameters
    ----------
    grid : (N, *spatial, dim) tensor_like
        Displacement grid.
    overload : dict
        Per-call option overrides. The only key read here is `shift`
        (falls back to `self.shift`).

    Returns
    -------
    aff : (N, dim+1, dim+1) tensor
        Affine matrix that is closest to grid in the least square sense.
    """
    use_shift = overload.get('shift', self.shift)
    grid = torch.as_tensor(grid)
    backend = dict(dtype=grid.dtype, device=grid.device)
    ndim = grid.shape[-1]
    spatial_shape = grid.shape[1:-1]

    # Model
    # -----
    # phi(x) = M \ A * M * x, with phi a *transformation* field, M the
    # shift matrix and A the affine matrix. Writing phi(x) = x + d(x),
    # with d a *displacement* field:
    #     d(x) = (M \ A * M - I) * x =: B * x
    # Stacking homogeneous voxel coordinates in G and displacements in D
    # (one row per voxel) gives D = G * B', whose least-squares solution is
    #     B' = inv(G' * G) * (G' * D)
    # and finally A = M * (B + I) / M.

    def homogeneous_cross(lhs, rhs):
        # [lhs, 1]' * [rhs, 1], assembled block-wise instead of appending
        # a column of ones:
        #     [[lhs' * rhs, lhs' * 1],
        #      [1'   * rhs, 1'   * 1]]     with 1' * 1 = number of voxels
        lhs = lhs.reshape([lhs.shape[0], -1, ndim])
        rhs = rhs.reshape([rhs.shape[0], -1, ndim])
        nb_vox = torch.as_tensor([[[lhs.shape[1]]]], **backend)
        left = torch.cat((torch.matmul(lhs.transpose(-1, -2), rhs),
                          rhs.sum(dim=1, keepdim=True)), dim=1)
        right = torch.cat((lhs.sum(dim=1, keepdim=True).transpose(-1, -2),
                           nb_vox), dim=1)
        # broadcast the (batch-1) column over the batch of `rhs`
        right = right.expand([rhs.shape[0], right.shape[1], right.shape[2]])
        return torch.cat((left, right), dim=2)

    identity = spatial.identity_grid(spatial_shape, **backend)[None, ...]
    gram_inv = homogeneous_cross(identity, identity).inverse()
    b_transposed = torch.matmul(gram_inv, homogeneous_cross(identity, grid))

    # B + I, where I covers only the linear block (homogeneous entry left
    # at zero); the last row is discarded right after.
    eye_pad = torch.zeros([ndim + 1, ndim + 1], **backend)
    eye_pad[:ndim, :ndim] = torch.eye(ndim, **backend)
    affine = (b_transposed.transpose(-1, -2) + eye_pad)[..., :-1, :]

    if use_shift:
        # M: homogeneous translation by -shape/2
        offset = -torch.as_tensor(spatial_shape, **backend)[:, None] / 2
        shift_mat = spatial.as_euclidean(
            torch.cat((torch.eye(ndim, **backend), offset), dim=1))
        # A = M * (B + I) / M
        affine = spatial.affine_matmul(shift_mat, spatial.as_euclidean(affine))
        affine = spatial.affine_rmdiv(spatial.as_euclidean(affine), shift_mat)
    return spatial.affine_make_square(affine)
def get_oriented_slice(image, dim=-1, index=None, affine=None, space=None,
                       bbox=None, interpolation=1, transpose_sagittal=False,
                       return_index=False, return_mat=False):
    """Sample a slice in a RAS system

    Parameters
    ----------
    image : (..., *shape3) tensor
        Input volume. The last three dimensions are spatial; leading
        dimensions are kept as channels.
    dim : int or str, default=-1
        Index of spatial dimension to sample in the visualization space.
        If RAS: -1 = axial / -2 = coronal / -3 = sagittal.
        Non-negative indices 0, 1, 2 are accepted (same as -3, -2, -1).
        Strings are matched on their first letter (case-insensitive):
        'a[xial]', 'c[oronal]', 's[agittal]'.
    index : (3,) tensor_like, optional
        World (millimetric) coordinate of the cursor the slice passes
        through. It is mapped to voxels of the visualisation space.
        Default: center of the bounding box.
    affine : (4, 4) tensor, optional
        Orientation matrix of the image
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        For sagittal slices, swap the two in-plane axes and flip the
        second spatial axis (anterior-posterior in RAS).
    return_index : bool, default=False
        Also return the cursor position in slice voxel coordinates.
    return_mat : bool, default=False
        Also return the orientation matrix of the sampled slice.

    Returns
    -------
    slice : (..., *shape2) tensor
        Slice in the visualisation space.
    index : 2-tuple, if `return_index`
        Cursor position in slice voxel coordinates.
    mat : (4, 4) tensor, if `return_mat`
        Orientation matrix of the slice.

    Raises
    ------
    ValueError
        If `dim` does not designate one of the three spatial dimensions.
    """
    # preproc dim: accept names and non-negative indices
    if isinstance(dim, str):
        dim = {'a': -1, 'c': -2, 's': -3}.get(dim.lower()[:1], dim)
    elif dim in (0, 1, 2):
        dim = dim - 3

    backend = utils.backend(image)

    # compute default space (mn/mx are in voxels)
    affine, shape = _get_default_native(affine, image.shape[-3:])
    space, mn, mx = _get_default_space(affine, [shape], space, bbox)
    affine, shape = affine[0], shape[0]

    # compute default cursor (in voxels of the visualisation space)
    if index is None:
        index = (mx + mn) / 2
    else:
        # `index` is a world coordinate: map it to visualisation voxels.
        # Cast to the matrix dtype/device so integer cursors are accepted.
        index = torch.as_tensor(index, dtype=space.dtype, device=space.device)
        index = spatial.affine_matvec(spatial.affine_inv(space), index)

    # include slice to volume matrix: `shift` maps visualisation voxels to
    # slice voxels; its third row puts the cursor slice at coordinate 0.
    shape = tuple(((mx - mn) + 1).round().int().tolist())
    if dim == -1:
        # axial
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [0, 0, 1, - index[2]],
                 [0, 0, 0, 1]]
        shape = shape[:-1]
        index = (index[0] - mn[0] + 1, index[1] - mn[1] + 1)
    elif dim == -2:
        # coronal
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - index[1]],
                 [0, 0, 0, 1]]
        shape = (shape[0], shape[2])
        index = (index[0] - mn[0] + 1, index[2] - mn[2] + 1)
    elif dim == -3 and not transpose_sagittal:
        # sagittal
        shift = [[0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [1, 0, 0, - index[0]],
                 [0, 0, 0, 1]]
        shape = (shape[2], shape[1])
        index = (index[2] - mn[2] + 1, index[1] - mn[1] + 1)
    elif dim == -3:
        # sagittal, transposed (first in-plane axis flipped)
        # NOTE(review): the other branches map slice voxel 0 to mn - 1,
        # whereas this one maps it to mx + 1 (not mx - 1); confirm the
        # intended offset.
        shift = [[0, -1, 0, mx[1] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [1, 0, 0, - index[0]],
                 [0, 0, 0, 1]]
        shape = (shape[1], shape[2])
        index = (mx[1] + 1 - index[1], index[2] - mn[2] + 1)
    else:
        raise ValueError(f'Unknown dimension {dim}')
    shift = utils.as_tensor(shift, **backend)

    # sample
    space = spatial.affine_rmdiv(space, shift)
    affine = spatial.affine_lmdiv(affine, space).to(**backend)
    grid = spatial.affine_grid(affine, [*shape, 1])
    *channel, s0, s1, s2 = image.shape
    image = image.reshape([1, -1, s0, s1, s2])
    image = spatial.grid_pull(image, grid[None], interpolation=interpolation,
                              bound='dct2', extrapolate=False)
    image = image.reshape([*channel, *shape])

    outputs = [image]
    if return_index:
        outputs.append(index)
    if return_mat:
        outputs.append(space)
    return outputs[0] if len(outputs) == 1 else tuple(outputs)