def _world_reslice(dat, mat, interpolation=1, vx=None):
    """Reslice image data to world space.

    The output field of view is the rounded world-space bounding box of the
    input volume's corners (computed with the first world axis negated),
    sampled at voxel size `vx`.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like, dtype=float32
        Image data.
    mat : (4, 4) tensor_like, dtype=float64
        Affine matrix. Converted to float64 on `dat.device`.
    interpolation : int, default=1 (linear)
        Interpolation order.
    vx : float | [float,] *3 | tensor, optional
        Output voxel size. A single value is used for all three axes.
        Defaults to the voxel size of `mat`.

    Returns
    -------
    dat : (X1, Y1, Z1) tensor_like, dtype=float32
        New image data.
    mat : (4, 4) tensor_like, dtype=float64
        New affine matrix.

    """
    device = dat.device
    # Make sure the affine lives on the same device/precision as the
    # float64 helper matrices built below (mm() requires matching dtypes).
    mat = torch.as_tensor(mat, dtype=torch.float64, device=device)
    one = torch.ones(1, dtype=torch.float64, device=device)

    # Output voxel size
    if vx is None:
        vx = voxel_size(mat).type(torch.float64).to(device)
    else:
        vx = torch.as_tensor(vx, dtype=torch.float64, device=device).flatten()
        if vx.numel() == 1:
            # Isotropic voxel size given as a scalar (or one-element list)
            vx = vx.repeat(3)

    # Corners of the input volume, one per column; the 4th (homogeneous)
    # row is required by the mat[:3, :4] product below.
    c = _get_corners_3d(dat.shape).type(torch.float64).to(device)
    c = c.t()
    # Corners in world space, with the first axis negated (compensated by
    # `flip` once the output affine is built)
    c_world = mat[:3, :4].mm(c)
    c_world[0, :] = -c_world[0, :]
    # Bounding box (rounded)
    mx = c_world.max(dim=1)[0].round()
    mn = c_world.min(dim=1)[0].round()

    # Output affine = translate(mn) @ scale(vx) @ translate(-1)
    # (the -1 shift presumably accounts for 1-based voxel indices)
    mat_mn = affine_matrix_classic(mn).type(torch.float64).to(device)
    mat_vx = torch.diag(torch.cat((vx, one)))
    mat_1 = affine_matrix_classic(
        -1 * torch.ones(3, dtype=torch.float64, device=device))
    mat_out = mat_mn.mm(mat_vx.mm(mat_1))

    # Compute output image dimensions: map the upper bounding-box corner
    # into output voxel space and round up.
    dim_out = mat_out.inverse().mm(torch.cat((mx, one))[:, None])
    dim_out = dim_out[:3].ceil().flatten().int().tolist()

    # Undo the axis negation in the output affine
    flip = torch.eye(4, dtype=torch.float64, device=device)
    flip[0, 0] = -1
    mat_out = flip.mm(mat_out)

    # Mapping from output voxels to input voxels: solve mat @ X = mat_out.
    # Tensor.solve / torch.solve were removed in PyTorch 2.0.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
        mat = torch.linalg.solve(mat, mat_out)
    else:  # older PyTorch without torch.linalg.solve
        mat = torch.solve(mat_out, mat)[0]

    # Reslice image data
    dat = _reslice_dat_3d(dat, mat, dim_out, interpolation=interpolation)

    return dat, mat_out
def _bb_atlas(name, fov, dtype=torch.float64, device='cpu'):
    """Bounding-box NITorch atlas data to specific field-of-view.

    Parameters
    ----------
    name : str
        Name of nitorch data, available are:
        * atlas_t1: MRI T1w intensity atlas, 1 mm resolution.
        * atlas_t2: MRI T2w intensity atlas, 1 mm resolution.
        * atlas_pd: MRI PDw intensity atlas, 1 mm resolution.
        * atlas_t1_mni: MRI T1w intensity atlas, in MNI space, 1 mm resolution.
        * atlas_t2_mni: MRI T2w intensity atlas, in MNI space, 1 mm resolution.
        * atlas_pd_mni: MRI PDw intensity atlas, in MNI space, 1 mm resolution.
    fov : str
        Field-of-view, specific to 'name':
        * 'atlas_t1' | 'atlas_t2' | 'atlas_pd':
            * 'brain': Brain FOV.
            * 'head': Head FOV (crops fewer voxels from the start of the
              third axis than 'brain').
        Any other combination of `name`/`fov` leaves the atlas uncropped.
    dtype : torch.dtype, default=torch.float64
        Data type of the returned tensors.
    device : torch.device or str, default='cpu'
        Device of the returned tensors.

    Returns
    -------
    mat_mu : (4, 4) tensor
        Output affine matrix.
    dim_mu : (3, ) tensor
        Output dimensions.

    """
    # Get atlas information
    file_mu = map(fetch_data(name))
    dim_mu = file_mu.shape
    mat_mu = file_mu.affine.type(torch.float64).to(device)

    # Number of voxels to remove: o[0] from the start, o[1] from the end of
    # each axis.
    o = [[0, 0, 0], [0, 0, 0]]
    if name in ['atlas_t1', 'atlas_t2', 'atlas_pd'] and fov in ('brain', 'head'):
        o = [[18, 52, 120], [18, 48, 58]]
        if fov == 'head':
            o[0][2] = 25

    # Bounding box, in 1-based voxel indices (inclusive)
    bb = torch.tensor(
        [[1 + o[0][0], 1 + o[0][1], 1 + o[0][2]],
         [dim_mu[0] - o[1][0], dim_mu[1] - o[1][1], dim_mu[2] - o[1][2]]],
        dtype=torch.float64, device=device)
    # Output dimensions
    dim_mu = bb[1, ...] - bb[0, ...] + 1
    # Bounding-box atlas affine (shift to the first kept voxel)
    mat_bb = affine_matrix_classic(bb[0, ...] - 1)
    # Modulate atlas affine with bb affine
    mat_mu = mat_mu.mm(mat_bb)

    # Honour the requested dtype (previously ignored; default unchanged)
    return mat_mu.to(dtype), dim_mu.to(dtype)
def _imatrix(M):
    """Return the parameters for creating an affine transformation matrix.

    Port of SPM12's ``spm_imatrix``: decomposes an affine into 12
    parameters ordered as 3 translations, 3 rotations (radians), 3 zooms
    and 3 shears.

    Args:
        M (torch.tensor): Affine transformation matrix (4, 4).

    Returns:
        P (torch.tensor): Affine parameters (12,), same dtype/device as M.

    Authors:
        John Ashburner & Stefan Kiebel, as part of the SPM12 software.

    """
    device = M.device
    dtype = M.dtype
    one = torch.tensor(1.0, device=device, dtype=dtype)

    # Translations and Zooms
    R = M[:-1, :-1]
    C = cholesky(R.t().mm(R))
    C = C.t()  # upper-triangular factor, as MATLAB's chol()
    d = torch.diag(C)
    # Built directly on M's device/dtype (no per-element host round trip);
    # values are copied without autograd history, as torch.tensor did.
    P = torch.zeros(12, device=device, dtype=dtype)
    P[:3] = M[:3, 3].detach()
    P[6:9] = d.detach()
    if R.det() < 0:  # Fix for -ve determinants
        P[6] = -P[6]

    # Shears
    C = lmdiv(torch.diag(torch.diag(C)), C)
    P[9] = C[0, 1]
    P[10] = C[0, 2]
    P[11] = C[1, 2]

    # Zoom + shear part only (translations and rotations set to zero)
    prm0 = P.detach().clone()
    prm0[:6] = 0
    R0 = affine_matrix_classic(prm0).to(device)
    R0 = R0[:-1, :-1]
    R1 = R.mm(R0.inverse())  # This just leaves rotations in matrix R1

    # Clamp to [-1, 1] to correct rounding errors before asin/atan2
    rang = lambda x: torch.min(torch.max(x, -one), one)
    P[4] = torch.asin(rang(R1[0, 2]))
    if (torch.abs(P[4]) - pi / 2) ** 2 < 1e-9:
        # Gimbal lock: the first rotation is arbitrary, set it to zero
        P[3] = 0
        P[5] = torch.atan2(-rang(R1[1, 0]), rang(-R1[2, 0] / R1[0, 2]))
    else:
        c = torch.cos(P[4])
        P[3] = torch.atan2(rang(R1[1, 2] / c), rang(R1[2, 2] / c))
        P[5] = torch.atan2(rang(R1[0, 1] / c), rang(R1[0, 0] / c))

    return P
def _subvol(dat, mat, bb=None):
    """Extract a sub-volume.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like
        Image volume.
    mat : (4, 4) tensor_like, dtype=float64
        Image affine matrix.
    bb : (2, 3) sequence or tensor, optional
        Bounding box, in 1-based voxel indices (inclusive). The two rows
        may be given in any order; values are rounded and clipped to the
        image extent. Defaults to the whole volume.

    Returns
    -------
    dat : (X1, Y1, Z1) tensor_like
        Image sub-volume.
    mat : (4, 4) tensor_like, dtype=float64
        Sub-volume affine matrix.

    """
    device = dat.device
    dim_in = dat.shape
    if bb is None:
        bb = [[1, 1, 1], dim_in]
    # Accept plain sequences as documented (bb.round() needs a tensor).
    # round()/sort() below return new tensors, so the caller's bb is not
    # modified.
    bb = torch.as_tensor(bb, dtype=torch.float64, device=device)

    # Process bounding-box: round, order rows (min first), clip to image
    bb = bb.round()
    bb = bb.sort(dim=0)[0]
    bb[0, ...] = torch.max(
        bb[0, ...], torch.ones(3, device=device, dtype=torch.float64))
    bb[1, ...] = torch.min(
        bb[1, ...], torch.tensor(dim_in, device=device, dtype=torch.float64))

    # Output dimensions
    # NOTE(review): dim_bb is a float64 tensor here, whereas _world_reslice
    # passes a list of ints to _reslice_dat_3d -- confirm both are accepted.
    dim_bb = bb[1, ...] - bb[0, ...] + 1

    # Bounding-box affine (shift to the first kept voxel)
    mat_bb = affine_matrix_classic(bb[0, ...] - 1)

    # Output data
    dat = _reslice_dat_3d(dat, mat_bb, dim_bb, interpolation='nearest',
                          bound='zero', extrapolate=False)

    # Output affine
    mat = mat.mm(mat_bb)

    return dat, mat
def forward(self, batch=1, **overload):
    """Sample a batch of random affine matrices.

    One sampler per parameter group (translation, rotation, zoom, shear) is
    built with ``self._make_sampler`` and drawn ``batch`` times. Rotations
    are converted from degrees to radians before the matrix is assembled.

    Parameters
    ----------
    batch : int, default=1
        Batch size

    Other Parameters
    ----------------
    dim : int, optional
    device : torch.device, optional
    dtype : torch.dtype, optional

    Returns
    -------
    affine : (batch, dim+1, dim+1) tensor
        Affine matrix

    """
    ndim = overload.get('dim', self.dim)
    opts = {
        'dtype': overload.get('dtype', self.dtype),
        'device': overload.get('device', self.device),
    }

    # One sampler per parameter group, in the order expected by
    # affine_matrix_classic.
    groups = ('translation', 'rotation', 'zoom', 'shear')
    samplers = [self._make_sampler(group, ndim, **opts) for group in groups]

    # Draw in the same order as the groups; rotations are in degrees.
    draws = [sampler.sample([batch]) for sampler in samplers]
    draws[1] = draws[1].mul_(math.pi / 180)

    return affine_matrix_classic(torch.cat(draws, dim=-1), dim=ndim)
def _format_y(x, sett):
    """Construct algorithm output struct. See _output() dataclass.

    Also decides, from the inputs' orientation matrices, dimensions and
    voxel sizes, whether super-resolution and/or projection is needed, and
    records the decisions in `sett` (`do_proj`, `method`, and possibly
    `scaling`, `unified_rigid`, `clean_fov`).

    Args:
        x: Input observations, indexed as x[channel][observation]; each
            observation is read for its `.mu` and `.ct` attributes.
        sett: Algorithm settings; read and modified in place.

    Returns:
        y (list of _output()): Algorithm output struct, one per channel.
        sett: The (modified) settings.

    """
    # Requested output voxel size; 0 or None means "derive from the inputs"
    vx_y = sett.vx
    if vx_y == 0:
        vx_y = None
    if vx_y is not None:
        if isinstance(vx_y, int):
            vx_y = float(vx_y)
        if isinstance(vx_y, float):
            vx_y = (vx_y,) * 3
        vx_y = torch.tensor(vx_y, dtype=torch.float64, device=sett.device)

    # Get all orientation matrices and dimensions
    all_mat, all_dim, all_vx = _all_mat_dim_vx(x, sett)
    N = all_mat.shape[0]  # Total number of observations

    if N == 1:
        # Disable unified rigid registration
        sett.unified_rigid = False
        sett.clean_fov = True

    # Check if all input images have the same fov/vx (compared after
    # round(..., 3))
    mat_same = True
    dim_same = True
    vx_same = True
    for n in range(1, N):
        mat_same = mat_same & \
            torch.equal(round(all_mat[n - 1, ...], 3), round(all_mat[n, ...], 3))
        dim_same = dim_same & \
            torch.equal(round(all_dim[n - 1, ...], 3), round(all_dim[n, ...], 3))
        vx_same = vx_same & \
            torch.equal(round(all_vx[n - 1, ...], 3), round(all_vx[n, ...], 3))

    # Decide if super-resolving and/or projection is necessary
    # NOTE(review): block nesting below was reconstructed from collapsed
    # source -- confirm that crop/pow apply only when projecting.
    do_sr = True
    sett.do_proj = True
    if vx_y is None and ((N == 1) or vx_same):
        # Voxel size not given: use the (shared) input voxel size
        vx_y = all_vx[0, ...]
    if vx_same and (torch.abs(all_vx[0, ...] - vx_y) < 1e-3).all():
        # All input images have same voxel size, and output voxel size is
        # also the same
        do_sr = False
        if mat_same and dim_same and not sett.unified_rigid:
            # All input images have the same FOV
            mat = all_mat[0, ...]
            dim = all_dim[0, ...]
            sett.do_proj = False

    if do_sr or sett.do_proj:
        # Get FOV of mean space
        mat, dim, vx_y = _mean_space(all_mat, all_dim, vx_y)

        if sett.crop:
            # Crop output to atlas field-of-view
            vx_y = voxel_size(mat)
            mat_mu, dim = _bb_atlas('atlas_t1', fov=sett.fov,
                                    dtype=torch.float64, device=sett.device)
            # Modulate atlas with voxel size
            mat_vx = torch.diag(torch.cat((
                vx_y, torch.ones(1, dtype=torch.float64, device=sett.device))))
            mat = mat_mu.mm(mat_vx)
            dim = mat_vx[:3, :3].inverse().mm(dim[:, None]).floor().squeeze()

        if sett.pow:
            # Ensure output image dimensions are compatible with
            # encode/decode architecture
            dim2 = ceil_pow(dim, p=2.0, l=2.0, mx=256)
            dim3 = ceil_pow(dim, p=2.0, l=3.0, mx=256)
            ndim = dim2
            ndim[dim3 < ndim] = dim3[dim3 < ndim]
            # Modulate output affine: shift by half the size change so the
            # previous grid stays centred
            mat_bb = affine_matrix_classic(-((ndim - dim) / 2).round()) \
                .type(torch.float64).to(sett.device)
            mat = mat.mm(mat_bb)
            dim = ndim

    # Set method
    if do_sr:
        sett.method = 'super-resolution'
    else:
        sett.method = 'denoising'

    # Optimise even/odd scaling parameter?
    if sett.method == 'denoising' or (N == 1 and x[0][0].ct):
        sett.scaling = False

    dim = tuple(dim.int().tolist())
    _ = _print_info('mean-space', sett, dim, mat)

    # Assign output
    y = []
    for c in range(len(x)):
        y.append(_output())
        # Regularisation (lambda) for channel c
        mu_c = torch.zeros(len(x[c]), dtype=torch.float32, device=sett.device)
        for n in range(len(x[c])):
            mu_c[n] = x[c][n].mu
            if x[c][n].ct and sett.method == 'super-resolution':
                mu_c[n] /= 4
        # lam0 and lam are kept as separate tensors (lam may be rescaled)
        y[c].lam0 = math.sqrt(1 / len(x)) / torch.mean(mu_c)
        y[c].lam = math.sqrt(1 / len(x)) / torch.mean(mu_c)
        # Output image(s) dimension and orientation matrix
        y[c].dim = dim
        y[c].mat = mat.double().to(sett.device)

    return y, sett
def forward(self, batch=1, **overload):
    """Build a batch of affine matrices from (possibly random) parameters.

    Each entry of translation / rotation / zoom / shear may be:
    * a callable: called with ``[batch]`` to draw values;
    * ``True``: drawn from the matching ``self.default_*`` sampler;
    * ``None`` or ``False``: neutral value (0, or 1 for zoom);
    * a number or tensor: used as is.
    Rotations are expressed in degrees.

    Parameters
    ----------
    batch : int, default=1
        Batch size
    overload : dict
        All parameters defined at build time (dim, translation, rotation,
        zoom, shear, dtype, device) can be overridden at call time.

    Returns
    -------
    affine : (batch, dim[+1], dim+1) tensor
        Affine matrix

    """
    dim = overload.get('dim', self.dim)
    translation = make_list(overload.get('translation', self.translation))
    rotation = make_list(overload.get('rotation', self.rotation))
    zoom = make_list(overload.get('zoom', self.zoom))
    shear = make_list(overload.get('shear', self.shear))
    # dtype/device are single values (previously wrapped in lists and then
    # ignored in favour of self.dtype/self.device).
    dtype = overload.get('dtype', self.dtype)
    device = overload.get('device', self.device)

    # compute dimension
    dim = dim or max(len(translation), len(rotation), len(zoom), len(shear))
    translation = make_list(translation, dim)
    rotation = make_list(rotation, dim * (dim - 1) // 2)
    zoom = make_list(zoom, dim)
    shear = make_list(shear, dim * (dim - 1) // 2)

    # sample values if needed
    translation = [
        x([batch]) if callable(x) else
        self.default_translation([batch]) if x is True else
        0. if x is None or x is False else
        x
        for x in translation
    ]
    rotation = [
        x([batch]) if callable(x) else
        self.default_rotation([batch]) if x is True else
        0. if x is None or x is False else
        x
        for x in rotation
    ]
    zoom = [
        x([batch]) if callable(x) else
        self.default_zoom([batch]) if x is True else
        1. if x is None or x is False else
        x
        for x in zoom
    ]
    shear = [
        x([batch]) if callable(x) else
        self.default_shear([batch]) if x is True else
        0. if x is None or x is False else
        x
        for x in shear
    ]

    rotation = [x * math.pi / 180 for x in rotation]  # degree -> radian
    prm = [*translation, *rotation, *zoom, *shear]
    # Broadcast every parameter to length `batch`. 0-d tensors have no
    # shape[0] and previously raised IndexError.
    prm = [
        p.expand(batch) if torch.is_tensor(p) and (p.dim() == 0 or p.shape[0] != batch) else
        make_list(p, batch) if not torch.is_tensor(p) else
        p
        for p in prm
    ]
    prm = utils.as_tensor(prm)
    prm = prm.transpose(0, 1)  # (nb_prm, batch) -> (batch, nb_prm)

    # generate affine matrix, honouring dtype/device overloads
    mat = affine_matrix_classic(prm, dim=dim).type(dtype).to(device)

    return mat
def forward(self, prm, **overload):
    """Build affine matrices from the parameters of a Lie-group basis.

    Expected number of parameters per basis (last dimension of `prm`):
    * 'T'   : dim translations
    * 'SO'  : dim*(dim-1)//2 rotations
    * 'SE'  : translations + rotations
    * 'D'   : translations + 1 isotropic zoom
    * 'CSO' : translations + rotations + 1 isotropic zoom
    * 'GL+' : (dim-1)*(dim+1) values: rotations, zooms, then shears
    * 'Aff+': translations + rotations + zooms + shears
    Groups a basis does not use are filled with zeros, so that
    `affine_matrix_classic` sees them in its expected order.

    Parameters
    ----------
    prm : (batch, nb_prm) tensor
        Affine parameters, ordered as (*translations, *rotations, *zooms,
        *shears), restricted to the groups of the chosen basis.
    overload : dict
        All parameters of the module (dim, basis, logzooms) can be
        overridden at call time.

    Returns
    -------
    affine : (batch, dim+1, dim+1) tensor
        Affine matrix

    """
    dim = overload.get('dim', self.dim)
    basis = overload.get('basis', self.basis)
    logzooms = overload.get('logzooms', self.logzooms)

    def checkdim(expected, got):
        if got != expected:
            raise ValueError('Expected {} parameters for group {}({}) but '
                             'got {}.'.format(expected, basis, dim, got))

    batch_shape = prm.shape[:-1]

    def zeros(n):
        # Placeholder values for a parameter group the basis does not use
        return prm.new_zeros([*batch_shape, n])

    def positive(zooms):
        # Zooms must be strictly positive
        return zooms.exp() if logzooms else zooms.clamp_min(eps)

    nb_prm = prm.shape[-1]
    eps = core.constants.eps(prm.dtype)
    nb_rot = dim * (dim - 1) // 2

    if basis == 'T':
        checkdim(dim, nb_prm)
    elif basis == 'SO':
        checkdim(nb_rot, nb_prm)
        # Rotations come after the (zero) translations
        prm = torch.cat((zeros(dim), prm), dim=-1)
    elif basis == 'SE':
        checkdim(dim + nb_rot, nb_prm)
    elif basis == 'D':
        checkdim(dim + 1, nb_prm)
        translations = prm[..., :dim]
        # Keep a trailing singleton axis so expand() broadcasts the
        # isotropic zoom along the last axis, not the batch axis.
        zooms = positive(prm[..., -1:].expand([*batch_shape, dim]))
        # Zooms come after the (zero) rotations
        prm = torch.cat((translations, zeros(nb_rot), zooms), dim=-1)
    elif basis == 'CSO':
        checkdim(dim + nb_rot + 1, nb_prm)
        rigid = prm[..., :-1]
        zooms = positive(prm[..., -1:].expand([*batch_shape, dim]))
        prm = torch.cat((rigid, zooms), dim=-1)
    elif basis == 'GL+':
        # NOTE(review): (dim-1)*(dim+1) == dim**2 - 1 leaves one shear
        # fewer than rotations + zooms + shears (dim**2) -- confirm intent.
        checkdim((dim - 1) * (dim + 1), nb_prm)
        rotations = prm[..., :nb_rot]
        zooms = positive(prm[..., nb_rot:nb_rot + dim])
        strides = prm[..., nb_rot + dim:]
        # Rotations come after the (zero) translations
        prm = torch.cat((zeros(dim), rotations, zooms, strides), dim=-1)
    elif basis == 'Aff+':
        checkdim(dim * (dim + 1), nb_prm)
        rigid = prm[..., :dim + nb_rot]
        zooms = positive(prm[..., dim + nb_rot:2 * dim + nb_rot])
        strides = prm[..., 2 * dim + nb_rot:]
        prm = torch.cat((rigid, zooms, strides), dim=-1)
    else:
        raise ValueError(f'Unknown basis {basis}')

    return spatial.affine_matrix_classic(prm, dim=dim)
def forward(self, prm):
    """Build affine matrices from the parameters of a Lie-group basis.

    Uses ``self.dim``, ``self.basis`` and ``self.logzooms``. Expected number
    of parameters per basis (last dimension of `prm`):
    * 'T'   : dim translations
    * 'SO'  : dim*(dim-1)//2 rotations
    * 'SE'  : translations + rotations
    * 'D'   : translations + 1 isotropic zoom
    * 'CSO' : translations + rotations + 1 isotropic zoom
    * 'GL+' : (dim-1)*(dim+1) values: rotations, zooms, then shears
    * 'Aff+': translations + rotations + zooms + shears
    Groups a basis does not use are filled with zeros, so that
    `affine_matrix_classic` sees them in its expected order.

    Parameters
    ----------
    prm : (batch, nb_prm) tensor
        Affine parameters, ordered as (*translations, *rotations, *zooms,
        *shears), restricted to the groups of the chosen basis.

    Returns
    -------
    affine : (batch, dim+1, dim+1) tensor
        Affine matrix

    """
    def checkdim(expected, got):
        if got != expected:
            raise ValueError(f'Expected {expected} parameters for '
                             f'group {self.basis}({self.dim}) but '
                             f'got {got}.')

    dim = self.dim
    batch_shape = prm.shape[:-1]

    def zeros(n):
        # Placeholder values for a parameter group the basis does not use
        return prm.new_zeros([*batch_shape, n])

    def positive(zooms):
        # Zooms must be strictly positive
        return zooms.exp() if self.logzooms else zooms.clamp_min(eps)

    nb_prm = prm.shape[-1]
    eps = core.constants.eps(prm.dtype)
    nb_rot = dim * (dim - 1) // 2

    if self.basis == 'T':
        checkdim(dim, nb_prm)
    elif self.basis == 'SO':
        checkdim(nb_rot, nb_prm)
        # Rotations come after the (zero) translations
        prm = torch.cat((zeros(dim), prm), dim=-1)
    elif self.basis == 'SE':
        checkdim(dim + nb_rot, nb_prm)
    elif self.basis == 'D':
        checkdim(dim + 1, nb_prm)
        translations = prm[..., :dim]
        # Keep a trailing singleton axis so expand() broadcasts the
        # isotropic zoom along the last axis, not the batch axis.
        zooms = positive(prm[..., -1:].expand([*batch_shape, dim]))
        # Zooms come after the (zero) rotations
        prm = torch.cat((translations, zeros(nb_rot), zooms), dim=-1)
    elif self.basis == 'CSO':
        checkdim(dim + nb_rot + 1, nb_prm)
        rigid = prm[..., :-1]
        zooms = positive(prm[..., -1:].expand([*batch_shape, dim]))
        prm = torch.cat((rigid, zooms), dim=-1)
    elif self.basis == 'GL+':
        # NOTE(review): (dim-1)*(dim+1) == dim**2 - 1 leaves one shear
        # fewer than rotations + zooms + shears (dim**2) -- confirm intent.
        checkdim((dim - 1) * (dim + 1), nb_prm)
        rotations = prm[..., :nb_rot]
        zooms = positive(prm[..., nb_rot:nb_rot + dim])
        strides = prm[..., nb_rot + dim:]
        # Rotations come after the (zero) translations
        prm = torch.cat((zeros(dim), rotations, zooms, strides), dim=-1)
    elif self.basis == 'Aff+':
        checkdim(dim * (dim + 1), nb_prm)
        rigid = prm[..., :dim + nb_rot]
        zooms = positive(prm[..., dim + nb_rot:2 * dim + nb_rot])
        strides = prm[..., 2 * dim + nb_rot:]
        prm = torch.cat((rigid, zooms, strides), dim=-1)
    else:
        raise ValueError(f'Unknown basis {self.basis}')

    return spatial.affine_matrix_classic(prm, dim=self.dim)