def _get_default_native(affines, shapes):
    """Get default native space

    Parameters
    ----------
    affines : [sequence of] (4, 4) tensor_like or None
        Orientation matrices. Missing entries (None) are replaced by
        `spatial.affine_default(shape)`.
    shapes : [sequence of] (3,) tensor_like
        Spatial shapes of the corresponding volumes.

    Returns
    -------
    affines : (N, 4, 4) tensor
    shapes : (N, 3) tensor

    """
    shapes = utils.as_tensor(shapes).reshape([-1, 3])
    shapes = shapes.unbind(dim=0)
    if affines is None:
        # A single missing affine: wrap it so that `len()` below works
        # and a default matrix is built for every shape.
        affines = [None]
    elif torch.is_tensor(affines):
        affines = affines.reshape([-1, 4, 4])
        affines = affines.unbind(dim=0)
    # broadcast the shorter list against the longer one
    n = max(len(shapes), len(affines))
    shapes = py.make_list(shapes, n)
    affines = py.make_list(affines, n)

    # default affines
    affines = [spatial.affine_default(shape) if affine is None else affine
               for shape, affine in zip(shapes, affines)]
    affines = utils.as_tensor(affines)
    shapes = utils.as_tensor(shapes)
    affines, shapes = utils.to_max_device(affines, shapes)
    return affines, shapes
def bracket(self, f0, closure):
    """Bracket the minimum

    Parameters
    ----------
    f0 : Initial value
        Value of the function at `a = 0`.
    closure : callable(a) -> evaluate function at `a0 + a`

    Returns
    -------
    (a0, f0), (a1, f1), (a2, f2)
        Three points whose middle value `f1` is not larger than the
        outer values `f0` and `f2`.

    """
    a0, a1 = 0, self.lr
    f1 = closure(a1)
    # sort such that f1 < f0
    if f1 > f0:
        a0, f0, a1, f1 = a1, f1, a0, f0
    a2 = a1 + self.gold * (a1 - a0)
    f2 = closure(a2)

    while f1 > f2:
        # signed step along the search direction (may be negative)
        step = a2 - a1

        # fit quadratic polynomial (abscissae relative to a1)
        a = utils.as_tensor([a0 - a1, 0., step])
        quad = torch.stack([torch.ones_like(a), a, a.square()], -1)
        f = utils.as_tensor([f0, f1, f2]).unsqueeze(-1)
        quad = quad.pinverse().matmul(f).squeeze(-1)

        if quad[2] > 0:
            # There is a minimum: jump to the vertex of the parabola,
            # but never further than `(1 + gold) * step` away from a1.
            # The limit must be applied in the search direction.
            delta = -0.5 * quad[1] / quad[2].clamp_min(self.tiny)
            delta = delta.item()
            limit = (1 + self.gold) * step
            delta = min(delta, limit) if step > 0 else max(delta, limit)
            a = a1 + delta
        else:
            # No minimum -> we must go farther than a2
            a = a2 + self.gold * step

        # check progress and update bracket
        # f2 < f1 < f0 so (assuming unicity) the minimum is in
        # (a1, a2) or beyond a2
        f = closure(a)
        if a1 < a < a2 or a2 < a < a1:
            if f < f2:
                # minimum is in (a1, a2)
                (a0, f0), (a1, f1), (a2, f2) = (a1, f1), (a, f), (a2, f2)
                break
            elif f > f1:
                # f1 < f0 and f1 < f: minimum is in (a0, a)
                (a0, f0), (a1, f1), (a2, f2) = (a0, f0), (a1, f1), (a, f)
                break
            # f2 <= f <= f1: the parabolic probe did not help and cannot
            # be shifted in without breaking the ordering of the points.
            # Take a default golden step beyond a2 instead.
            a = a2 + self.gold * step
            f = closure(a)

        # shift by one point
        (a0, f0), (a1, f1), (a2, f2) = (a1, f1), (a2, f2), (a, f)

    return (a0, f0), (a1, f1), (a2, f2)
def search_in_bracket(self, bracket, closure):
    """Refine a bracketed minimum.

    Brent-style search. A parabola is fitted through the three best
    points, and its vertex is used as the next probe. If the step is not
    shrinking fast enough, falls outside the bracket, or the parabola is
    concave, a golden-section step (scaled by `self.igold`) is taken
    instead, towards the larger side of the bracket.

    Parameters
    ----------
    bracket : sequence of three (a, f) pairs
        Bracket of the minimum. The outer abscissae are `bracket[0][0]`
        and `bracket[2][0]` (presumably as returned by `bracket`).
    closure : callable(a) -> evaluate function at `a0 + a`

    Returns
    -------
    a, f
        Best abscissa found and its function value.

    """
    # outer bounds of the bracket, ordered so that b0 <= b1
    b0, b1 = (bracket[0][0], bracket[2][0])
    if b1 < b0:
        b0, b1 = b1, b0
    # sort by values
    (a0, f0), (a1, f1), (a2, f2) = self.sort_bracket(bracket)

    # d: last step taken; d0: the step before it
    d = d0 = float('inf')
    for n_iter in range(self.max_iter):

        # converged when both bracket ends are within 2 * tol of a0
        if abs(a0 - 0.5 * (b0 + b1)) + 0.5 * (b1 - b0) <= 2 * self.tol:
            return a0, f0

        # d1: step from two iterations ago (Brent's step-size criterion)
        d1, d0 = d0, d

        # fit quadratic polynomial (abscissae relative to a0)
        a = utils.as_tensor([0., a1 - a0, a2 - a0])
        quad = torch.stack([torch.ones_like(a), a, a.square()], -1)
        f = utils.as_tensor([f0, f1, f2]).unsqueeze(-1)
        quad = quad.pinverse().matmul(f).squeeze(-1)

        # candidate step: vertex of the parabola
        d = -0.5 * quad[1] / quad[2].clamp_min(self.tiny)
        d = d.item()
        a = a0 + d

        # reject the parabolic step if it does not shrink fast enough,
        # lands (almost) on the bracket bounds, or the parabola is concave
        tiny = self.tiny * (1 + 2 * abs(a0))
        if abs(d) > abs(d1) / 2 or not (b0 + tiny < a < b1 - tiny) or quad[-1] < 0:
            # golden-section step into the larger side of the bracket
            if a0 > 0.5 * (b0 + b1):
                d = self.igold * (b0 - a0)
            else:
                d = self.igold * (b1 - a0)
            a = a0 + d

        # check progress and update bracket
        f = closure(a)
        if f < f0:
            # f < f0 < f1 < f2
            b0, b1 = (b0, a0) if a < a0 else (a0, b1)
            (a0, f0), (a1, f1), (a2, f2) = (a, f), (a0, f0), (a1, f1)
        else:
            # a0 stays the best point: shrink the bracket towards it
            b0, b1 = (a, b1) if a < a0 else (b0, a)
            if f < f1:
                # f0 < f < f1 < f2
                (a0, f0), (a1, f1), (a2, f2) = (a0, f0), (a, f), (a1, f1)
            elif f < f2:
                # f0 < f1 < f < f2
                (a0, f0), (a1, f1), (a2, f2) = (a0, f0), (a1, f1), (a, f)

    # max_iter reached: return the best point found so far
    return a0, f0
def forward(self, x, noise=None, return_resolution=False):
    """Simulate a lower acquisition resolution, channel by channel.

    For every (batch, channel) pair, a resolution is drawn from
    `self.resolution(exp, scale).sample()` and clamped to at least 1.
    The image is then smoothed with a matching FWHM, optionally
    corrupted by `noise`, downsampled by that factor and resized back to
    the input spatial shape.

    Parameters
    ----------
    x : (batch, channel, *spatial) tensor
    noise : tensor, optional
        Expanded to the shape of `x` and added after smoothing.
    return_resolution : bool, default=False
        Also return the sampled resolutions.

    Returns
    -------
    out : tensor, same shape as `x`
    resolutions : tensor, if `return_resolution`
        One row of `dim` values per (batch, channel) pair, in
        batch-major order. Presumably (batch * channel, dim) when each
        sample is a scalar.

    """
    if noise is not None:
        noise = noise.expand(x.shape)

    ndim = x.dim() - 2
    nb_channels = x.shape[1]
    backend = utils.backend(x)
    exps = utils.make_vector(self.resolution_exp, nb_channels, **backend)
    scales = utils.make_vector(self.resolution_scale, nb_channels, **backend)

    sampled = []
    out = torch.empty_like(x)
    for batch_index in range(len(x)):
        for channel in range(nb_channels):
            res = self.resolution(exps[channel], scales[channel]).sample()
            res = res.clamp_min(1)
            smoothed = smooth(x[batch_index, channel], fwhm=[res] * ndim,
                              dim=ndim, padding='same', bound='dct2')
            if noise is not None:
                smoothed += noise[batch_index, channel]
            # resize needs explicit batch and channel dimensions
            low = resize(smoothed[None, None], factor=[1 / res] * ndim,
                         anchor='f')
            upsample = [res] * ndim
            sampled.append(upsample)
            high = resize(low, factor=upsample, shape=x.shape[2:], anchor='f')
            out[batch_index, channel] = high[0, 0]

    sampled = utils.as_tensor(sampled, **backend)
    if return_resolution:
        return out, sampled
    return out
def _get_default_space(affines, shapes, space=None, bbox=None):
    """Get default visualisation space

    Parameters
    ----------
    affines : [sequence of] (4, 4) tensor_like
    shapes : [sequence of] (3,) tensor_like
    space : (4, 4) tensor_like, optional
        Visualisation space. Default: identity scaled by the smallest
        voxel size across all inputs.
    bbox : (2, 3) tensor_like, optional
        Min and max corners, in millimetric visualisation space.
        Default: computed by `spatial.compute_fov` from all inputs.

    Returns
    -------
    space : (4, 4) tensor
    mn, mx : (3,) tensor
        Lower and upper corners of the bounding box, in voxels of `space`.

    """
    affines, shapes = _get_default_native(affines, shapes)
    if space is None:
        voxel_size = spatial.voxel_size(affines).min()
        space = torch.eye(4)
        space[:-1, :-1] *= voxel_size
    else:
        space = utils.as_tensor(space)
    voxel_size = spatial.voxel_size(space)
    if bbox is None:
        shapes = torch.as_tensor(shapes)
        mn, mx = spatial.compute_fov(space, affines, shapes)
    else:
        mn, mx = utils.as_tensor(bbox)
        # Out-of-place: `mn`/`mx` are views of `bbox`. An in-place
        # division would modify the caller's tensor, and it fails for
        # integer bounds.
        mn = mn / voxel_size
        mx = mx / voxel_size
    return space, mn, mx
def rotation(self):
    """Rotation matrix corresponding to the quaternion (r; i, j, k).

    `r` is the real part and `(i, j, k)` the imaginary part. The formula
    used assumes a unit quaternion. The two matrix dimensions are moved
    to the end, so the result has shape (..., 3, 3).
    """
    i, j, k, r = self.i, self.j, self.k, self.r
    first_row = [1 - 2 * (j**2 + k**2), 2 * (i * j - k * r), 2 * (i * k + j * r)]
    second_row = [2 * (i * j + k * r), 1 - 2 * (i**2 + k**2), 2 * (j * k - i * r)]
    third_row = [2 * (i * k - j * r), 2 * (j * k + i * r), 1 - 2 * (i**2 + j**2)]
    rot = utils.as_tensor([first_row, second_row, third_row])
    return utils.movedim(rot, [0, 1], [-2, -1])
def space(self, value):
    """Set the visualisation space.

    Parameters
    ----------
    value : (4, 4) tensor or int or None
        * tensor: used directly as the space matrix (must be 4x4);
        * int: the affine of `self.images[value]` is used;
        * None: identity scaled by the smallest voxel size of all images.

    Raises
    ------
    ValueError
        If `value` is a tensor that is not 4x4, or is of another type.
    """
    self._space = value

    if value is None:
        # isotropic space at the finest voxel size among all images
        all_affines = utils.as_tensor([img.affine for img in self.images])
        finest = spatial.voxel_size(all_affines).min()
        matrix = torch.eye(4)
        matrix[:-1, :-1] *= finest
        self._space_matrix = matrix
        return

    if torch.is_tensor(value):
        if value.shape != (4, 4):
            raise ValueError('Expected 4x4 matrix')
        self._space_matrix = value
        return

    if not isinstance(value, int):
        raise ValueError('Expected a 4x4 matrix or an int or None')
    self._space_matrix = [img.affine for img in self.images][value]
def _crop_bbox_param(dat, threshold, dim):
    """Size and center (in voxels) of the bounding box of `dat > threshold`.

    Parameters
    ----------
    dat : array_like
        In-memory data (numpy array or tensor). The first `dim`
        dimensions are spatial; trailing dimensions are collapsed.
    threshold : float
    dim : int
        Number of spatial dimensions.

    Returns
    -------
    size : (dim,) tensor[long]
    center : (dim,) tensor[float]

    Raises
    ------
    ValueError
        If no voxel is above the threshold.
    """
    mask = torch.as_tensor(dat > threshold)
    while mask.dim() > dim:
        mask = mask.any(dim=-1)
    if not mask.any():
        raise ValueError(f'No voxel above the bounding box threshold '
                         f'({threshold}): cannot crop.')
    mins = []
    maxs = []
    for d in range(dim):
        n = mask.shape[d]
        idx = utils.movedim(mask, d, 0).reshape([n, -1]).any(-1)
        idx = idx.nonzero(as_tuple=False)
        mins.append(idx.min())
        maxs.append(idx.max())
    mins = utils.as_tensor(mins)
    maxs = utils.as_tensor(maxs)
    size = maxs + 1 - mins
    # Voxel-center convention, as for the default FOV center (shape-1)/2.
    # The patch spans voxels [mins, maxs], so its center is (mins+maxs)/2,
    # which makes `first = center - (size-1)/2` land exactly on `mins`.
    center = (maxs + mins).float() / 2
    return size, center


def crop(inp, size=None, center=None, space='vx', like=None, bbox=False,
         output=None, transform=None):
    """Crop a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, optional
        Size of the patch to extract.
        Its unit and axes are defined by `space`.
    center : [sequence of] int, optional
        Coordinate of the center of the patch.
        Its unit and axes are defined by `space`.
        By default, the center of the FOV is used.
    space : [sequence of] {'vx', 'vox', 'ras'}, default='vx'
        The space in which the `size` and `center` parameters are expressed.
        If two values are given, the first applies to `size` and the
        second to `center`. Any value other than 'ras' (case-insensitive)
        means voxels.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    bbox : bool or float, default=False
        Crop at the bounding box of `inp > threshold`.
            If `bbox` is a float, it is the threshold to use.
            If `bbox` is `True`, the threshold is 0.
    output : str, optional
        Output filename.
        If the input is not a path, the cropped data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.crop{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    transform : str, optional
        Input or output filename of the corresponding transform.
        Not written by default.
        If a transform is provided and all other parameters
        (i.e., `size`, `like` and `bbox`) are None, the transform is
        considered as an input transform (LTA) whose destination space
        defines the patch.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and its orientation matrix are returned.

    Raises
    ------
    ValueError
        If more than one of `size`, `like`, `bbox` is used, if none of
        `size`, `like`, `bbox`, `transform` is provided, or if the
        bounding box threshold selects no voxel.
    TypeError
        If the input transform is not an LTA file.

    """
    dirname = ''
    base = ''
    ext = ''
    fname = None
    transform_in = False
    # `bbox=0.` is falsy but still requests a bounding-box crop
    use_bbox = bool(bbox or isinstance(bbox, float))

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        # thresholding needs the data in memory; otherwise keep it mapped
        inp = (f.data(numpy=True) if use_bbox else f, f.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dirname, base, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]
    layout0 = spatial.affine_to_layout(aff0)

    # save input space in case we reorient later
    aff00 = aff0
    shape00 = shape0

    if bool(size) + bool(like) + use_bbox > 1:
        raise ValueError('Can only use one of `size`, `like` and `bbox`.')

    # --- Open reference and compute size/center ---
    if like:
        if isinstance(like, str):
            f = io.volumes.map(like)
            like = (f.shape, f.affine)
        like_shape, like_aff = like
        like_layout = spatial.affine_to_layout(like_aff)
        if (layout0 != like_layout).any():
            aff0, dat = spatial.affine_reorient(aff0, dat, like_layout)
            shape0 = dat.shape[:dim]
        if torch.is_tensor(like_shape):
            like_shape = like_shape.shape
        size, center, _, _ = _crop_to_param(aff0, like_aff, like_shape)
        space = 'vox'

    # --- Bounding box of the thresholded input ---
    elif use_bbox:
        threshold = 0. if bbox is True else bbox
        size, center = _crop_bbox_param(dat, threshold, dim)
        space = 'vox'

    # --- Open transformation file and compute size/center ---
    elif not size:
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        t = io.transforms.map(transform)
        if not isinstance(t, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = t.destination_space()
        size, center, _, _ = _crop_to_param(aff0, like_aff, like_shape)

    # --- use center of the FOV ---
    if not torch.is_tensor(center) and not center:
        center = torch.as_tensor(shape0[:dim], dtype=torch.float)
        center = center.sub_(1).mul_(0.5)

    # --- convert size/center to voxels ---
    space_size, space_center = py.make_list(space, 2)
    size_in_ras = space_size.lower() == 'ras'
    # RAS sizes are in mm and may be fractional: do not truncate them
    size = utils.make_vector(size, dim,
                             dtype=torch.float if size_in_ras else torch.long)
    center = utils.make_vector(center, dim, dtype=torch.float)
    if space_center.lower() == 'ras':
        center = spatial.affine_matvec(spatial.affine_inv(aff0), center)
    if size_in_ras:
        # reorder RAS extents along voxel axes, then convert mm -> voxels
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()]
        size = size / spatial.voxel_size(aff0)

    # --- compute first/last ---
    center = center.float()
    size = (size.ceil() if size.dtype.is_floating_point else size).long()
    first = center - size.float().sub_(1).mul_(0.5)
    first = first.round().long()
    last = (first + size).tolist()
    first = [max(f, 0) for f in first.tolist()]
    last = [min(l, s) for l, s in zip(last, shape0[:dim])]
    verb = 'Cropping patch ['
    verb += ', '.join([f'{f}:{l}' for f, l in zip(first, last)])
    verb += f'] from volume with shape {shape0[:dim]}'
    print(verb)
    slicer = tuple(slice(f, l) for f, l in zip(first, last))

    # --- do crop ---
    # In bbox mode `dat` is already in memory (numpy array or tensor):
    # slicing is enough. Otherwise a mapped file only loads the patch.
    dat = dat[slicer]
    if not use_bbox and not torch.is_tensor(dat):
        dat = dat.data(numpy=True)
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    if output:
        if is_file:
            output = output.format(dir=dirname or '.', base=base, ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    if transform and not transform_in:
        if is_file:
            transform = transform.format(dir=dirname or '.', base=base,
                                         ext=ext, sep=os.path.sep)
        else:
            transform = transform.format(sep=os.path.sep)
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff00, shape00)
        trf.set_destination_space(aff, shape)
        trf.set_metadata({'src': {'filename': fname},
                          'dst': {'filename': output},
                          'type': 1})  # RAS_TO_RAS
        trf.set_fdata(torch.eye(4))
        trf.save()

    if is_file:
        return output
    return dat, aff
def is_inside(points, vertices, faces=None):
    """Test if a point is inside a polygon/surface.

    The polygon or surface *must* be closed.

    Parameters
    ----------
    points : (..., dim) tensor
        Coordinates of points to test (dim must be 2 or 3).
    vertices : (nv, dim) tensor
        Vertex coordinates
    faces : (nf, dim) tensor[int]
        Faces are encoded by the indices of its vertices.
        By default, assume that vertices are ordered and define a closed curve

    Returns
    -------
    check : (...) tensor[bool]

    Raises
    ------
    ValueError
        If `dim` is not 2 or 3.

    """
    # This function uses a ray-tracing technique:
    #
    #   A half-line is started in each point. If it crosses an odd
    #   number of faces, it is inside the shape. If it crosses an even
    #   number of faces, it is not.
    #
    # In practice, we loop through faces (as we expect there are much
    # less vertices than voxels) and compute intersection points between
    # all lines and each face in a batched fashion. We only want to
    # send these rays in one direction, so we keep aside points whose
    # intersection have a positive coordinate along the ray.

    points = torch.as_tensor(points)
    vertices = torch.as_tensor(vertices)
    if faces is None:
        faces = [(i, i + 1) for i in range(len(vertices) - 1)]
        faces += [(len(vertices) - 1, 0)]
    faces = utils.as_tensor(faces, dtype=torch.long)

    points, vertices = utils.to_max_dtype(points, vertices)
    if not points.dtype.is_floating_point:
        # integer coordinates: the in-place divisions below need floats
        dtype = torch.get_default_dtype()
        points, vertices = points.to(dtype), vertices.to(dtype)
    points, vertices, faces = utils.to_max_device(points, vertices, faces)
    backend = utils.backend(points)
    batch = points.shape[:-1]
    dim = points.shape[-1]
    if dim not in (2, 3):
        raise ValueError(f'Only 2D and 3D shapes are supported, got dim={dim}')
    eps = constants.eps(points.dtype)
    cross = points.new_zeros(batch, dtype=torch.long)

    ray = torch.randn(dim, **backend)

    for face in faces:
        face = vertices[face]

        # compute normal vector
        origin = face[0]
        u = face[1] - face[0]
        if dim == 3:
            v = face[2] - face[0]
            norm = torch.stack([
                u[1] * v[2] - u[2] * v[1],
                u[2] * v[0] - u[0] * v[2],
                u[0] * v[1] - u[1] * v[0]
            ])
        else:
            norm = torch.stack([-u[1], u[0]])

        # skip degenerate faces (null normal) and faces parallel to the ray
        denom = ray.norm() * norm.norm()
        if denom == 0:
            continue
        colinear = linalg.dot(ray, norm).abs() / denom < eps
        if colinear:
            continue

        # compute intersection between ray and plane
        #   plane: <norm, x - origin> = 0
        #   line:  x = p + t*ray
        #   => t = -<norm, p - origin> / <norm, ray>
        # `intersection` holds -t; keeping -t >= 0 shoots all rays along -ray.
        intersection = linalg.dot(norm, points - origin)
        intersection /= linalg.dot(norm, ray)
        halfmask = intersection >= 0  # we only want to shoot in one direction
        intersection = intersection[halfmask]
        halfpoints = points[halfmask]
        intersection = intersection[..., None] * (-ray)
        intersection += halfpoints

        # check if the intersection is inside the face
        # first, we express it in the frame of dimension `dim-1`
        # defined by (origin, (u, v))
        intersection -= origin
        if dim == 3:
            # Barycentric coordinates: solve the 2x2 Gram system
            #   [uu uv; uv vv] [alpha; beta] = [<x, u>; <x, v>]
            # Raw dot products are only correct for orthonormal (u, v).
            uu = linalg.dot(u, u)
            vv = linalg.dot(v, v)
            uv = linalg.dot(u, v)
            det = uu * vv - uv * uv  # == |u x v|^2 > 0 (non-degenerate)
            xu = linalg.dot(intersection, u)
            xv = linalg.dot(intersection, v)
            alpha = (vv * xu - uv * xv) / det
            beta = (uu * xv - uv * xu) / det
            intersection = (alpha >= 0) & (beta > 0) & (alpha + beta < 1)
        else:
            intersection = linalg.dot(intersection, u)
            intersection /= u.norm().square_()
            intersection = (intersection >= 0) & (intersection < 1)
        cross[halfmask] += intersection

    # inside <=> odd number of crossings
    cross = cross.bitwise_and_(1).bool()
    return cross
def _max_fov(self):
    """Bounds of the field of view covering all images, in the current space.

    Delegates to `spatial.compute_fov` with the current space matrix and
    the stacked affines and shapes of `self.images`.
    """
    images = self.images
    affine_list = [img.affine for img in images]
    shape_list = [img.shape for img in images]
    return spatial.compute_fov(self._space_matrix,
                               utils.as_tensor(affine_list),
                               utils.as_tensor(shape_list))
def _index_from_cursor(self, x, y, image, n_ax):
    """Set `self.index` from a 2D cursor position on view `n_ax`.

    The point (x, y, 0) is mapped through `image._mats[n_ax]`.
    """
    cursor = utils.as_tensor([x, y, 0])
    self.index = spatial.affine_matvec(image._mats[n_ax], cursor)
def forward(self, batch=1, **overload):
    """Sample random affine matrices.

    Parameters
    ----------
    batch : int, default=1
        Batch size
    overload : dict
        All parameters defined at build time can be overridden at call time
        (`dim`, `translation`, `rotation`, `zoom`, `shear`, `dtype`,
        `device`).

    Returns
    -------
    affine : (batch, dim[+1], dim+1) tensor
        Affine matrices

    """

    dim = overload.get('dim', self.dim)
    translation = make_list(overload.get('translation', self.translation))
    rotation = make_list(overload.get('rotation', self.rotation))
    zoom = make_list(overload.get('zoom', self.zoom))
    shear = make_list(overload.get('shear', self.shear))
    # honour call-time overloads of the output type / location
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)

    # compute dimension
    dim = dim or max(len(translation), len(rotation), len(zoom), len(shear))
    translation = make_list(translation, dim)
    rotation = make_list(rotation, dim * (dim - 1) // 2)
    zoom = make_list(zoom, dim)
    shear = make_list(shear, dim * (dim - 1) // 2)

    def sample(x, default_name, neutral):
        # callable -> sample it; True -> sample the default sampler
        # (looked up lazily); None/False -> neutral value; else as is
        if callable(x):
            return x([batch])
        if x is True:
            return getattr(self, default_name)([batch])
        if x is None or x is False:
            return neutral
        return x

    # sample values if needed
    translation = [sample(x, 'default_translation', 0.) for x in translation]
    rotation = [sample(x, 'default_rotation', 0.) for x in rotation]
    zoom = [sample(x, 'default_zoom', 1.) for x in zoom]
    shear = [sample(x, 'default_shear', 0.) for x in shear]
    rotation = [x * math.pi / 180 for x in rotation]  # degree -> radian

    def broadcast(p):
        # every parameter must be a vector of length `batch`
        if torch.is_tensor(p):
            if p.dim() == 0 or p.shape[0] != batch:
                return p.expand(batch)
            return p
        return make_list(p, batch)

    prm = [broadcast(p) for p in [*translation, *rotation, *zoom, *shear]]
    prm = utils.as_tensor(prm)
    prm = prm.transpose(0, 1)

    # generate affine matrix
    mat = affine_matrix_classic(prm, dim=dim).to(dtype=dtype, device=device)
    return mat
def graph_atlas(velocities, nodes=None, latent=None): """Infer template-to-image SVFs from pairwise SVFs.. References ---------- ..[1] "Robust joint registration of multiple stains and MRI for multimodal 3D histology reconstruction: Application to the Allen human brain atlas" Adrià Casamitjana, Marco Lorenzi, Sebastiano Ferraris, Loic Peter, Marc Modat, Allison Stevens, Bruce Fischl, Tom Vercauteren, Juan Eugenio Iglesias https://arxiv.org/abs/2104.14873 Parameters ---------- velocities : (n, *spatial, dim) tensor nodes : n-sequence of (int, int), default=[(1, 2), (2, 3), ..., (N, 1)] latent : k-sequence of (int, int) tensor, default=[(0, 1), ..., (0, N)] Returns ------- latent : sequence of (k, *spatial, dim) tensor """ def default_nodes(n): nodes = [] for n in range(1, n): nodes += [(n, n + 1)] nodes += [(n, 1)] return nodes def default_latent(n): nodes = [(0, n) for n in range(1, n + 1)] return nodes velocities = utils.as_tensor(velocities) backend = utils.backend(velocities) # defaults observed_nodes = list(nodes or default_nodes(len(velocities))) latent_nodes = list(latent or default_latent(len(velocities))) # compute W connections = _build_matrix(observed_nodes, latent_nodes) coordinates = [(latent_nodes.index(j), i) for i, c in enumerate(connections) for j in c] values = [v for c in connections for v in c.values()] values = torch.as_tensor(values, dtype=torch.float) coordinates = torch.as_tensor(coordinates, dtype=torch.long).T w = torch.sparse_coo_tensor( coordinates, values, [len(latent_nodes), len(observed_nodes)], **backend) # re-parameterise last velocity to enforce `sum of velocities = 0` w = w.to_dense() wlast = w[-1:, :] w = w[:-1, :] w -= wlast velocities = utils.movedim(velocities, 0, -1)[..., None] latent = w.transpose(-1, -2).pinverse().matmul(velocities)[..., 0] latent = utils.movedim(latent, -1, 0) latent = torch.cat([latent, -latent.sum(0, keepdim=True)], dim=0) return latent
def get_oriented_slice(image, dim=-1, index=None, affine=None,
                       space=None, bbox=None, interpolation=1,
                       transpose_sagittal=False, return_index=False,
                       return_mat=False):
    """Sample a slice in a RAS system

    Parameters
    ----------
    image : (..., *shape3)
    dim : int or str, default=-1
        Index of spatial dimension to sample in the visualization space.
        If RAS: -1 = axial / -2 = coronal / -3 = sagittal.
        Non-negative indices 0, 1, 2 are equivalent to -3, -2, -1.
        Strings are matched on their first letter ('a', 'c', 's').
    index : (3,) tensor_like, optional
        Cursor position. It is mapped to visualisation-space voxels
        through `inv(space)`, so it is presumably in millimetric world
        coordinates. Default: center of the bounding box.
    affine : (4, 4) tensor, optional
        Orientation matrix of the image.
        Default: `spatial.affine_default` of the image shape.
    space : (4, 4) tensor, optional
        Orientation matrix of the visualisation space.
        Default: RAS with minimum voxel size of all inputs.
    bbox : (2, D) tensor_like, optional
        Bounding box: min and max coordinates (in millimetric
        visualisation space). Default: bounding box of the input image.
    interpolation : {0, 1}, default=1
        Interpolation order.
    transpose_sagittal : bool, default=False
        For sagittal slices, swap the two in-plane axes (the second
        visualisation axis is also flipped).
    return_index : bool, default=False
        Also return the cursor coordinates within the slice.
    return_mat : bool, default=False
        Also return the orientation matrix of the slice.

    Returns
    -------
    slice : (..., *shape2) tensor
        Slice in the visualisation space.
    index : (2,) tuple, if `return_index`
    mat : (4, 4) tensor, if `return_mat`

    """
    # preproc dim
    if isinstance(dim, str):
        dim = dim.lower()[0]
        if dim == 'a':
            dim = -1
        if dim == 'c':
            dim = -2
        if dim == 's':
            dim = -3
    elif isinstance(dim, int) and 0 <= dim < 3:
        # non-negative spatial index -> equivalent negative index
        dim = dim - 3

    backend = utils.backend(image)

    # compute default space (mn/mx are in voxels)
    # A missing affine is passed as a one-element list so that a default
    # orientation matrix is built for it.
    affine, shape = _get_default_native(
        [None] if affine is None else affine, image.shape[-3:])
    space, mn, mx = _get_default_space(affine, [shape], space, bbox)
    affine, shape = (affine[0], shape[0])

    # compute default cursor (in voxels)
    if index is None:
        index = (mx + mn) / 2
    else:
        # match the dtype/device of `space` (integer cursors would make
        # the matrix-vector product fail)
        index = torch.as_tensor(index, dtype=space.dtype, device=space.device)
        index = spatial.affine_matvec(spatial.affine_inv(space), index)

    # include slice to volume matrix
    shape = tuple(((mx-mn) + 1).round().int().tolist())
    if dim == -1:
        # axial
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 1, 0, - mn[1] + 1],
                 [0, 0, 1, - index[2]],
                 [0, 0, 0, 1]]
        shift = utils.as_tensor(shift, **backend)
        shape = shape[:-1]
        index = (index[0] - mn[0] + 1, index[1] - mn[1] + 1)

    elif dim == -2:
        # coronal
        shift = [[1, 0, 0, - mn[0] + 1],
                 [0, 0, 1, - mn[2] + 1],
                 [0, 1, 0, - index[1]],
                 [0, 0, 0, 1]]
        shift = utils.as_tensor(shift, **backend)
        shape = (shape[0], shape[2])
        index = (index[0] - mn[0] + 1, index[2] - mn[2] + 1)

    elif dim == -3:
        # sagittal
        if not transpose_sagittal:
            shift = [[0, 0, 1, - mn[2] + 1],
                     [0, 1, 0, - mn[1] + 1],
                     [1, 0, 0, - index[0]],
                     [0, 0, 0, 1]]
            shift = utils.as_tensor(shift, **backend)
            shape = (shape[2], shape[1])
            index = (index[2] - mn[2] + 1, index[1] - mn[1] + 1)
        else:
            shift = [[0, -1, 0, mx[1] + 1],
                     [0, 0, 1, - mn[2] + 1],
                     [1, 0, 0, - index[0]],
                     [0, 0, 0, 1]]
            shift = utils.as_tensor(shift, **backend)
            shape = (shape[1], shape[2])
            index = (mx[1] + 1 - index[1], index[2] - mn[2] + 1)

    else:
        raise ValueError(f'Unknown dimension {dim}')

    # sample
    space = spatial.affine_rmdiv(space, shift)
    affine = spatial.affine_lmdiv(affine, space)
    affine = affine.to(**backend)
    grid = spatial.affine_grid(affine, [*shape, 1])
    *channel, s0, s1, s2 = image.shape
    imshape = (s0, s1, s2)
    image = image.reshape([1, -1, *imshape])
    image = spatial.grid_pull(image, grid[None], interpolation=interpolation,
                              bound='dct2', extrapolate=False)
    image = image.reshape([*channel, *shape])
    return ((image, index, space) if return_index and return_mat else
            (image, index) if return_index else
            (image, space) if return_mat else image)
def kernel_fit(calib, kernel_size, patterns, lam=0.01):
    """Learn one GRAPPA kernel per sampling pattern (regularized least squares).

    All batch elements should have the same sampling pattern.

    Parameters
    ----------
    calib : ([*batch], coils, *freq) tensor
        Fully-sampled calibration data
    kernel_size : sequence[int]
        GRAPPA kernel size (its length sets the number of frequency dims)
    patterns : (N,) tensor[int]
        Codes of the patterns for which to learn a kernel.
        See `pattern_to_code`. Boolean inputs are converted with
        `pattern_to_code`.
    lam : float, default=0.01
        Tikhonov regularization. The diagonal of the normal matrix is
        loaded with `lam * max|diag|` and then with `lam`.

    Returns
    -------
    kernels : dict of int -> ([*batch], coils, coils, nb_elem) tensor
        GRAPPA kernels. Patterns that contain the kernel center, or that
        are empty, are skipped.

    """
    kernel_size = py.make_list(kernel_size)
    ndim = len(kernel_size)
    coils, *_freq = calib.shape[-ndim - 1:]
    batch = calib.shape[:-ndim - 1]

    # normalise the requested patterns to a flat vector of integer codes
    patterns = utils.as_tensor(patterns, device=calib.device)
    if patterns.dtype is torch.bool:
        patterns = pattern_to_code(patterns, ndim)
    patterns = patterns.flatten()

    # extract every calibration patch
    patches = utils.unfold(calib, kernel_size, collapse=True)  # [*B, C, N, *K]
    patches = utils.movedim(patches, -ndim - 1, -ndim - 2)     # [*B, N, C, *K]

    # index of the kernel center, applied to the trailing kernel dims
    center_idx = (Ellipsis, *[(k - 1) // 2 for k in kernel_size])

    kernels = {}
    for code in patterns:
        if code_has_center(code, kernel_size):
            continue
        mask = code_to_pattern(code, kernel_size, device=patches.device)
        nb_src = mask.sum()
        if nb_src == 0:
            continue

        target = patches[center_idx]  # [*B, N, C]
        source = patches[..., mask]   # [*B, N, C, P]
        source = source.reshape([*batch, target.shape[-2], nb_src * coils])  # [*B, N, C*P]

        # regularized normal equations
        source_h = source.transpose(-1, -2).conj()  # [*B, C*P, N]
        hessian = source_h.matmul(source)           # [*B, C*P, C*P]
        hdiag = hessian.diagonal(0, -1, -2)         # view into `hessian`
        hdiag.add_(lam * hdiag.abs().max(-1, keepdim=True).values)
        hdiag.add_(lam)
        rhs = source_h.matmul(target)               # [*B, C*P, C]
        kernel = linalg.lmdiv(hessian, rhs).transpose(-1, -2)  # [*B, C, C*P]
        kernels[code.item()] = kernel.reshape([*batch, coils, coils, nb_src])  # [*B, C, C, P]

    return kernels