def _get_default_space(affines, shapes, space=None, bbox=None):
    """Get default visualisation space.

    Parameters
    ----------
    affines : [sequence of] (4, 4) tensor_like
        Orientation matrices of the inputs.
    shapes : [sequence of] (3,) tensor_like
        Spatial shapes of the inputs.
    space : (4, 4) tensor_like, optional
        Visualisation space. By default, an identity matrix whose linear
        part is scaled by the smallest voxel size found in `affines`.
    bbox : (2, 3) tensor_like, optional
        Bounding box given as (lower corner, upper corner). It is divided
        by the voxel size of `space`. By default, the field of view of all
        inputs in `space`, as computed by `spatial.compute_fov`.

    Returns
    -------
    space : (4, 4) tensor
    mn : (3,) tensor
        Lower corner of the bounding box.
    mx : (3,) tensor
        Upper corner of the bounding box.
    """
    affines, shapes = _get_default_native(affines, shapes)
    if space is None:
        # isotropic space at the finest input resolution
        voxel_size = spatial.voxel_size(affines).min()
        space = torch.eye(4)
        space[:-1, :-1] *= voxel_size
    voxel_size = spatial.voxel_size(space)
    if bbox is None:
        shapes = torch.as_tensor(shapes)
        mn, mx = spatial.compute_fov(space, affines, shapes)
    else:
        mn, mx = utils.as_tensor(bbox)
        # Out-of-place division: `as_tensor` may return (views of) the
        # caller's tensor, which an in-place `/=` would silently modify.
        # It also works for integer-valued bounding boxes.
        mn = mn / voxel_size
        mx = mx / voxel_size
    return space, mn, mx
def build_from_target(target):
    """Compose all transformations, starting from the final orientation.

    A grid is built on the lattice of `target` (via `spatial.affine_grid`
    with `target.affine`). It is then pushed through
    `options.transformations` in reverse order.
    NOTE(review): `options`, `backend`, `Linear` and `helpers` are free
    names here — presumably provided by an enclosing scope; confirm.

    Parameters
    ----------
    target
        Object exposing `.affine` and `.shape`.

    Returns
    -------
    grid : tensor
        Composed sampling grid.
    """
    grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
    for trf in reversed(options.transformations):
        if isinstance(trf, Linear):
            # linear transform: apply its matrix to every grid point
            grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
        else:
            # non-linear transform: displacement field `trf.dat` with its
            # own orientation matrix `trf.affine`
            mat = trf.affine.to(**backend)
            if trf.inv:
                # resize the field by the ratio of its voxel size to the
                # target voxel size, then invert it
                vx0 = spatial.voxel_size(mat)
                vx1 = spatial.voxel_size(target.affine.to(**backend))
                factor = vx0 / vx1
                disp, mat = spatial.resize_grid(trf.dat[None], factor,
                                                affine=mat,
                                                interpolation=trf.spline)
                disp = spatial.grid_inv(disp[0], type='disp')
                # the inverted field is sampled with order-1 interpolation
                order = 1
            else:
                disp = trf.dat
                order = trf.spline
            # map into the field's voxel space, add the sampled
            # displacement, and map back with `mat`
            imat = spatial.affine_inv(mat)
            grid = spatial.affine_matvec(imat, grid)
            grid += helpers.pull_grid(disp, grid, interpolation=order)
            grid = spatial.affine_matvec(mat, grid)
    return grid
def _patch(patch, affine, shape, level):
    """Compute the patch size in voxels.

    Parameters
    ----------
    patch : [sequence of] float, optionally followed by a unit (str)
        Patch size. Units are case-insensitive:
        - starting with 'v': voxels (divided by ``2**level``);
        - 'um', 'mm', 'cm', 'm': metric size, converted with the voxel
          size of the RAS-reoriented `affine` (assumed to be in mm) and
          mapped back to the native layout;
        - starting with 'p' or '%': percentage of `shape` (default 'pct').
    affine : (D+1, D+1) tensor
        Orientation matrix.
    shape : sequence[int]
        Spatial shape.
    level : int
        Pyramid level.

    Returns
    -------
    patch : list[float]
        Patch size in voxels; sizes below 1e-3 are set to 0.

    Raises
    ------
    ValueError
        If the unit is empty or not recognised.
    """
    dim = affine.shape[-1] - 1
    patch = py.make_list(patch)
    unit = 'pct'
    if isinstance(patch[-1], str):
        *patch, unit = patch
    patch = py.make_list(patch, dim)
    unit = unit.lower()
    if not unit:
        # an empty unit would otherwise fail with an opaque IndexError
        raise ValueError(f'Unknown patch unit: {unit!r}')
    if unit[0] == 'v':
        # voxels
        patch = [float(p) / 2**level for p in patch]
    elif unit in ('m', 'mm', 'cm', 'um'):
        # assume RAS orientation; `factor` converts `unit` to mm
        factor = (1e-3 if unit == 'um' else
                  1e1 if unit == 'cm' else
                  1e3 if unit == 'm' else
                  1)
        affine_ras = spatial.affine_reorient(affine, layout='RAS')
        vx_ras = spatial.voxel_size(affine_ras).tolist()
        patch = [factor * p / v for p, v in zip(patch, vx_ras)]
        patch = _ras_to_layout(patch, affine)
    elif unit[0] in 'p%':
        # percentage of shape
        patch = [0.01 * p * s for p, s in zip(patch, shape)]
    else:
        raise ValueError(f'Unknown patch unit: {unit!r}')
    # round down to zero small patch sizes
    patch = [0 if p < 1e-3 else p for p in patch]
    return patch
def _precond(x, y, rho, sett):
    """Compute CG preconditioner.

    The diagonal preconditioner is

        M = tau * At(A(1)) + 2 * rho * lam**2 * sum(1 / vx**2)

    where `vx` is the voxel size of the reconstruction `y`.

    Returns
    -------
    callable
        Function mapping a tensor `v` to `v / M`.

    Raises
    ------
    ValueError
        If `x` does not hold exactly one repeat.
    """
    if len(x) != 1:
        raise ValueError(
            'CG pre-conditioning only supports one repeat per contrast.')
    obs = x[0]
    inv_vx2 = voxel_size(y.mat).float().square().reciprocal().sum()

    # tau * At(A(1))
    ones = torch.ones(y.dim, device=sett.device, dtype=torch.float32)
    diag = obs.tau * _proj_apply(
        'AtA', ones[None, None, ...], obs.po, method=sett.method,
        bound=sett.bound, interpolation=sett.interpolation)
    # + 2 * rho * lam**2 * sum(1 / vx**2)
    diag += 2 * rho * y.lam**2 * inv_vx2
    diag = diag[0, 0, ...]

    def precond(v):
        return v / diag

    return precond
def _to_gradient_magnitudes(dat, mat, scl):
    """Compute squared gradient magnitudes (modulated with scaling and
    voxel size).

    Forward differences with a zero boundary are taken using the voxel
    size of `mat`. They are multiplied by `scl`, squared, and summed over
    the first (gradient) dimension.

    Parameters
    ----------
    dat : (X, Y, Z) tensor_like
        Image data.
    mat : (4, 4) tensor_like
        Affine matrix.
    scl : (N, ) tensor_like
        Gradient scaling parameter.

    Returns
    -------
    (X, Y, Z) tensor_like
        Squared gradient magnitudes (returned as the function result).
    """
    grad = im_gradient(dat, vx=voxel_size(mat), which='forward', bound='zero')
    return torch.sum((scl * grad)**2, dim=0)
def _reset_origin(dat, mat, interpolation=1):
    """Reset affine matrix.

    The image is first resliced to world space (`_world_reslice`). A
    default affine (`affine_default`) with the resliced voxel size is
    then built for it. When the resliced affine has a negative
    determinant, the first voxel size is negated (presumably to keep the
    handedness).

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like, dtype=float32
        Image data.
    mat : (4, 4) tensor_like, dtype=float64
        Affine matrix.
    interpolation : int, default=1 (linear)
        Interpolation order.

    Returns
    -------
    dat : (X1, Y1, Z1) tensor_like, dtype=float32
        New image data.
    mat : (4, 4) tensor_like, dtype=float64
        New affine matrix.
    """
    device = dat.device
    # Reslice image data to world FOV
    dat, mat = _world_reslice(dat, mat, interpolation=interpolation)
    # Compute new, reset, affine matrix
    vx = voxel_size(mat)
    if mat[:3, :3].det() < 0:
        vx[0] = -vx[0]
    mat = affine_default(dat.shape, vx.tolist(), dtype=torch.float64,
                         device=device)
    return dat, mat
def get_kernel(kernel, affine, shape, level):
    """Convert the provided kernel size (RAS mm or pct) to native voxels.

    Parameters
    ----------
    kernel : [sequence of] float, optionally followed by a unit (str)
        Kernel size. Units are case-insensitive:
        - starting with 'v': voxels (divided by ``2**level``);
        - 'um', 'mm', 'cm', 'm': metric size, converted with the voxel
          size of the RAS-reoriented `affine` (assumed to be in mm) and
          mapped back to the native layout;
        - starting with 'p' or '%': percentage of `shape` (default 'pct').
    affine : (D+1, D+1) tensor
        Orientation matrix.
    shape : sequence[int]
        Spatial shape.
    level : int
        Pyramid level.

    Returns
    -------
    kernel : list[int]
        Kernel size in voxels, rounded up and at least 2.

    Raises
    ------
    ValueError
        If the unit is empty or not recognised.
    """
    dim = affine.shape[-1] - 1
    kernel = py.make_list(kernel)
    unit = 'pct'
    if isinstance(kernel[-1], str):
        *kernel, unit = kernel
    kernel = py.make_list(kernel, dim)
    unit = unit.lower()
    if not unit:
        # an empty unit would otherwise fail with an opaque IndexError
        raise ValueError(f'Unknown kernel unit: {unit!r}')
    if unit[0] == 'v':
        # voxels
        kernel = [p / 2**level for p in kernel]
    elif unit in ('m', 'mm', 'cm', 'um'):
        # assume RAS orientation; `factor` converts `unit` to mm
        factor = (1e-3 if unit == 'um' else
                  1e1 if unit == 'cm' else
                  1e3 if unit == 'm' else
                  1)
        affine_ras = spatial.affine_reorient(affine, layout='RAS')
        vx_ras = spatial.voxel_size(affine_ras).tolist()
        kernel = [factor * p / v for p, v in zip(kernel, vx_ras)]
        kernel = ras_to_layout(kernel, affine)
    elif unit[0] in 'p%':
        # percentage of shape
        kernel = [0.01 * p * s for p, s in zip(kernel, shape)]
    else:
        raise ValueError(f'Unknown kernel unit: {unit!r}')
    # ensure kernel size is an integer >= 2 (else, no gradients)
    kernel = [max(int(pymath.ceil(k)), 2) for k in kernel]
    return kernel
def _world_reslice(dat, mat, interpolation=1, vx=None):
    """Reslice image data to world space.

    Parameters
    ----------
    dat : (X0, Y0, Z0) tensor_like, dtype=float32
        Image data.
    mat : (4, 4) tensor_like, dtype=float64
        Affine matrix.
    interpolation : int, default=1 (linear)
        Interpolation order.
    vx : float | [float,] *3, optional
        Output voxel size. Default: voxel size of `mat`.

    Returns
    -------
    dat : (X1, Y1, Z1) tensor_like, dtype=float32
        New image data.
    mat : (4, 4) tensor_like, dtype=float64
        New affine matrix.
    """
    device = dat.device
    # Get voxel size
    if vx is None:
        vx = voxel_size(mat).type(torch.float64).to(device)
    else:
        if not isinstance(vx, (list, tuple)):
            vx = (vx, ) * 3
        vx = torch.as_tensor(vx).type(torch.float64).to(device)
    # Get corners
    c = _get_corners_3d(dat.shape).type(torch.float64).to(device)
    c = c.t()
    # Corners in world space (first axis negated)
    c_world = mat[:3, :4].mm(c)
    c_world[0, :] = -c_world[0, :]
    # Get bounding box
    mx = c_world.max(dim=1)[0].round()
    mn = c_world.min(dim=1)[0].round()
    # Compute output affine
    mat_mn = affine_matrix_classic(mn).type(torch.float64).to(device)
    mat_vx = torch.diag(
        torch.cat((vx, torch.ones(1, dtype=torch.float64, device=device))))
    mat_1 = affine_matrix_classic(
        -1 * torch.ones(3, dtype=torch.float64, device=device))
    mat_out = mat_mn.mm(mat_vx.mm(mat_1))
    # Compute output image dimensions
    dim_out = mat_out.inverse().mm(
        torch.cat((mx, torch.ones(1, dtype=torch.float64,
                                  device=device)))[:, None])
    dim_out = dim_out[:3].ceil().flatten().int().tolist()
    I = torch.diag(torch.ones(4, dtype=torch.float64, device=device))
    I[0, 0] = -I[0, 0]
    mat_out = I.mm(mat_out)
    # Compute mapping from output to input: solve mat @ X = mat_out.
    # `Tensor.solve` was removed from recent PyTorch releases; keep a
    # fallback for versions that predate `torch.linalg.solve`.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
        mat = torch.linalg.solve(mat, mat_out)
    else:
        mat = mat_out.solve(mat)[0]
    # Reslice image data
    dat = _reslice_dat_3d(dat, mat, dim_out, interpolation=interpolation)
    return dat, mat_out
def _resample_inplane(x, sett):
    """Force in-plane resolution of observed data to be greater or equal
    to recon vx.

    This only runs when `sett.force_inplane_res` is set and
    `sett.max_iter > 0`. For each observation `x[c][n]`, every axis whose
    voxel size is below `sett.vx` is stretched so that it becomes
    `sett.vx`. The data (and label, if any) are resampled with order-0
    interpolation on the new lattice, and `dat`, `mat` and `dim` are
    updated. Observations that need no change are skipped.

    Returns
    -------
    x
        The same (modified) input.
    """
    if not (sett.force_inplane_res and sett.max_iter > 0):
        return x
    eye = torch.eye(4, device=sett.device, dtype=torch.float64)
    for obs_c in x:
        for obs in obs_c:
            dat = obs.dat[None, None, ...]
            mat_x = obs.mat
            dim_x = torch.as_tensor(obs.dim, device=sett.device,
                                    dtype=torch.float64)
            vx_x = voxel_size(mat_x)

            # Diagonal scaling: target / current voxel size, never below 1
            scale = eye.clone()
            for i in range(3):
                scale[i, i] = sett.vx / vx_x[i]
                if scale[i, i] < 1.0:
                    scale[i, i] = 1
            if float((eye - scale).abs().sum()) < 1e-4:
                continue

            mat_x = mat_x.matmul(scale)
            dim_x = (scale[:3, :3].inverse().mm(dim_x[:, None])
                     .floor().squeeze().cpu().int().tolist())
            grid = affine_grid(scale.type(dat.dtype), dim_x)

            dat = grid_pull(dat, grid[None, ...], bound='zero',
                            extrapolate=False, interpolation=0)
            if obs.label is not None:
                obs.label[0] = _warp_label(obs.label[0], grid)

            obs.dat = dat[0, 0, ...]
            obs.mat = mat_x
            obs.dim = dim_x
    return x
def set_affine(header, affine, shape=None):
    """Write an orientation matrix (and matching zooms) into a header.

    The header zooms are first updated with the voxel size of `affine`.
    Trailing zooms (e.g. a 4th time entry) are kept. The orientation is
    then stored according to the header type: MGH (needs `shape`),
    NIfTI-1 (sform and qform) or SPM99 Analyze (origin only). Other
    formats trigger a RuntimeWarning and the orientation is discarded.

    Parameters
    ----------
    header : MGHHeader or Nifti1Header or Spm99AnalyzeHeader or other
    affine : (3, 4) or (4, 4) tensor or array_like
    shape : sequence[int], optional
        Data shape; required for MGH headers.

    Returns
    -------
    header

    Raises
    ------
    ValueError
        If `header` is an MGH header and `affine` is not (3, 4) or (4, 4).
    """
    if torch.is_tensor(affine):
        affine = affine.detach().cpu()
    affine = np.asanyarray(affine)
    vx = np.asanyarray(voxel_size(affine))
    vx0 = header.get_zooms()
    vx = [vx[i] if i < len(vx) else vx0[i] for i in range(len(vx0))]
    header.set_zooms(vx)
    if isinstance(header, MGHHeader):
        if shape is None:
            warn('Cannot set the affine of a MGH file without '
                 'knowing the data shape', RuntimeWarning)
        elif affine.shape not in ((3, 4), (4, 4)):
            raise ValueError('Expected a (3, 4) or (4, 4) affine matrix. '
                             'Got {}'.format(affine.shape))
        else:
            # Only the three spatial zooms describe the MGH geometry; for
            # 4D files the zooms also hold the repetition time, which
            # would break the broadcasting below and corrupt 'delta'.
            vx_spatial = np.asarray(vx[:3], dtype=float)
            Mdc = affine[:3, :3] / vx_spatial
            shape = np.asarray(shape[:3])
            c_ras = affine.dot(np.hstack((shape / 2.0, [1])))[:3]

            # Assign after we've had a chance to raise exceptions
            header['delta'] = vx_spatial
            header['Mdc'] = Mdc.T
            header['Pxyz_c'] = c_ras
    elif isinstance(header, Nifti1Header):
        header.set_sform(affine)
        header.set_qform(affine)
    elif isinstance(header, Spm99AnalyzeHeader):
        header.set_origin_from_affine(affine)
    else:
        warn('Format {} does not accept orientation matrices. '
             'It will be discarded.'.format(type(header).__name__),
             RuntimeWarning)
    return header
def _autoreg(argv=None): """Autograd Registration This is a command-line utility. """ # parse arguments argv = argv or list(sys.argv) options = parse(list(argv)) if not options: return # add a couple of defaults for trf in options.transformations: if isinstance(trf, struct.NonLinear) and not trf.losses: trf.losses.append(struct.AbsoluteLoss(factor=0.0001)) trf.losses.append(struct.MembraneLoss(factor=0.001)) trf.losses.append(struct.BendingLoss(factor=0.2)) trf.losses.append(struct.LinearElasticLoss(factor=(0.05, 0.2))) trf.losses = [collapse_losses(trf.losses)] if not options.optimizers: options.optimizers.append(struct.Adam()) options.propagate_defaults() options.read_info() options.propagate_filenames() if options.verbose >= 2: print(repr(options)) load_data(options) load_transforms(options) print('Losses:') for loss in options.losses: print(f' - {loss.name}') for f, m in zip(loss.fixed.dat, loss.moving.dat): print(f' -| {list(m[0].shape)}, {spatial.voxel_size(m[1]).tolist()}') print(f' -> {list(f[0].shape)}, {spatial.voxel_size(f[1]).tolist()}') print('Transforms') for trf in options.transformations: print(f' - {trf.name}') if isinstance(trf, struct.NonLinear): pyramid0 = trf.pyramid[-1] for pyramid in reversed(trf.pyramid): factor = 2**(pyramid0 - pyramid) shape = [s*factor for s in trf.dat.shape] vx = spatial.voxel_size(trf.affine) / factor print(f' - {list(shape)}, {vx.tolist()}') while not all_optimized(options): add_freedom(options) init_optimizers(options) optimize(options) free_data(options) detach_transforms(options) write_transforms(options) write_data(options)
def _nonlin_rls(maps, lam=1., norm='jtv'):
    """Update the (L1) weights.

    Parameters
    ----------
    maps : (P, *shape) ParameterMaps
        Parameter maps (a single ParameterMap is only accepted for
        internal calls with ``norm='__internal__'``).
    lam : float or (P,) sequence[float], default=1
        Regularisation factor. A scalar is shared by all maps.
    norm : {'tv', 'jtv'}, default='jtv'

    Returns
    -------
    rls : ([P], *shape) tensor or None
        Weights from the reweighted least squares scheme: one map per
        parameter for 'tv', a single joint map for 'jtv'. None for any
        other norm.
    """
    import itertools
    import numbers

    if norm not in ('tv', 'jtv', '__internal__'):
        return None

    if isinstance(maps, ParameterMap):
        # single map
        # this should only be an internal call
        # -> we return the squared gradient map
        assert norm == '__internal__'
        vx = spatial.voxel_size(maps.affine)
        dat = maps.fdata()
        grad_fwd = spatial.diff(dat, dim=[0, 1, 2], voxel_size=vx, side='f')
        grad_bwd = spatial.diff(dat, dim=[0, 1, 2], voxel_size=vx, side='b')
        grad = grad_fwd.square_().sum(-1)
        grad += grad_bwd.square_().sum(-1)
        grad *= lam / 2.   # average across sides (2)
        return grad

    # multiple maps: a scalar lambda (the default) is shared by all maps;
    # zipping directly over a scalar would raise a TypeError
    if isinstance(lam, numbers.Number) or (torch.is_tensor(lam)
                                           and lam.dim() == 0):
        lam = itertools.repeat(lam)

    if norm == 'tv':
        rls = []
        for map1, lam1 in zip(maps, lam):
            rls.append(_nonlin_rls(map1, lam1, '__internal__').sqrt_())
        return torch.stack(rls, dim=0)
    else:
        assert norm == 'jtv'
        rls = 0
        for map1, lam1 in zip(maps, lam):
            rls += _nonlin_rls(map1, lam1, '__internal__')
        rls = rls.sqrt_()
        return rls
def upsample_vel(v, aff_in, aff_out, shape, readout):
    """
    Upsample a 1D displacement field (by a potentially non-integer factor)
    using second order spline interpolation. Scales the displacement field
    appropriately to take into account the change of voxel size.

    Parameters
    ----------
    v : tensor
        Displacement field (in voxels along `readout`).
    aff_in, aff_out : tensor
        Input and output orientation matrices.
    shape : sequence[int]
        Output shape. If it already equals `v.shape`, `v` is returned as is.
    readout : int
        Index of the readout axis.
    """
    # compare as tuples so that `shape` may be a list, tuple or torch.Size
    if tuple(v.shape) == tuple(shape):
        return v
    vx_down = spatial.voxel_size(aff_in)[readout]
    vx_up = spatial.voxel_size(aff_out)[readout]
    factor = vx_down / vx_up
    v = spatial.reslice(v, aff_in, aff_out, shape, bound='dct2',
                        interpolation=2, prefilter=True, extrapolate=True)
    # displacement expressed in (smaller) output voxels
    v *= factor
    return v
def forward(self, x, affine=None):
    """
    Parameters
    ----------
    x : (X, Y, Z) tensor or str
        Input image, or path to an image file (opened with `io.map`).
    affine : (4, 4) tensor, optional
        Orientation matrix. If None and `x` is a mapped file, the file's
        affine is used. If it stays None, no reorientation/resampling
        is performed.

    Returns
    -------
    seg : (32, oX, oY, oZ) tensor
        Segmentation
        NOTE(review): the code returns `self.relabel(s.argmax(0))`, which
        drops the channel dimension — presumably a label map of shape
        (oX, oY, oZ); confirm and update the shape above.
    resliced : (oX, oY, oZ) tensor
        Input resliced to 1 mm RAS
    affine : (4, 4) tensor
        Output orientation matrix (None if no affine was available)
    """
    if self.verbose:
        print('Preprocessing... ', end='', flush=True)
    if isinstance(x, str):
        x = io.map(x)
    if isinstance(x, io.MappedArray):
        if affine is None:
            affine = x.affine
        x = x.fdata()
        x = x.reshape(x.shape[:3])
    x = SynthPreproc.addnoise(x)
    if affine is not None:
        # reorient to RAS, then resample by the voxel size (-> 1 mm)
        affine, x = spatial.affine_reorient(affine, x, 'RAS')
        vx = spatial.voxel_size(affine)
        # smoothing only along axes whose voxel size is <= 1
        fwhm = 0.25 * vx.reciprocal()
        fwhm[vx > 1] = 0
        x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
        x, affine = spatial.resize(x[None, None], vx.tolist(), affine=affine)
        x = x[0, 0]
    # shape before cropping, used to pad the outputs back
    oshape = x.shape
    x, crop = SynthPreproc.crop(x)
    x = SynthPreproc.preproc(x)[None, None]
    if self.verbose:
        print('done.', flush=True)
        print('Segmenting... ', end='', flush=True)
    s, x = super().forward(x)[0], x[0, 0]
    if self.verbose:
        print('done.', flush=True)
        print('Postprocessing... ', end='', flush=True)
    s = self.relabel(s.argmax(0))
    x = SynthPreproc.pad(x, oshape, crop)
    s = SynthPreproc.pad(s, oshape, crop)
    if self.verbose:
        print('done.', flush=True)
    return s, x, affine
def resize(self):
    """Resample the maps (and RLS weights) to the current level.

    The new lattice is obtained by resizing (`self.affine0`, `self.shape0`)
    by a factor ``1 / 2**(self.level - 1)``. `self.lam_scale` is set to the
    ratio between the new and the original voxel volume. When RLS weights
    exist, they are resized too and `self.nll['rls']` is recomputed.
    """
    affine, shape = spatial.affine_resize(self.affine0, self.shape0,
                                          1 / (2**(self.level - 1)))

    vol0 = spatial.voxel_size(self.affine0).prod()
    self.lam_scale = spatial.voxel_size(affine).prod() / vol0

    for map1 in self.maps:
        map1.volume = spatial.resize(map1.volume[None, None, ...],
                                     shape=shape)[0, 0]
        map1.affine = affine
    self.maps.affine = affine
    if self.rls is not None:
        if self.rls.dim() == len(shape):
            # no leading channel dimension
            self.rls = spatial.resize(self.rls[None, None], shape=shape)[0, 0]
        else:
            # presumably one set of weights per map (leading dimension)
            self.rls = spatial.resize(self.rls[None], shape=shape)[0]
        self.nll['rls'] = self.rls.reciprocal().sum(dtype=torch.double)
def set_voxel_size(header, vx, shape=None):
    """Change the voxel size stored in an image header.

    The zooms are updated, then the header affine is modified so that the
    column norms of its linear part (its voxel size) become `vx`. For
    axis-aligned affines this is the same as scaling each row (including
    the translation) by ``vx / vx0``. For oblique or permuted affines the
    equivalent similarity ``A @ diag(vx / vx0) @ inv(A)`` is used instead.

    Parameters
    ----------
    header : image header (passed to `set_affine`)
    vx : sequence[float]
        New voxel size. Missing trailing entries are taken from the
        header's current zooms.
    shape : sequence[int], optional
        Data shape (required by `set_affine` for MGH headers).

    Returns
    -------
    header
    """
    vx0 = header.get_zooms()
    nb_dim = max(len(vx0), len(vx))
    vx = [vx[i] if i < len(vx) else vx0[i] for i in range(nb_dim)]
    header.set_zooms(vx)
    aff = torch.as_tensor(header.get_best_affine())
    # Only spatial zooms enter the affine (zooms may also hold e.g. a TR,
    # which previously caused a (4,) vs (3,) broadcasting error).
    nb_spatial = aff.shape[-1] - 1
    vx = torch.as_tensor(vx[:nb_spatial], dtype=aff.dtype, device=aff.device)
    vx0 = voxel_size(aff)
    lin = aff[:-1, :-1]
    # new = (A D A^-1) @ [A | t] = [A D | A D A^-1 t], with D = diag(vx/vx0)
    scale = lin.mm(torch.diag(vx / vx0)).mm(lin.inverse())
    aff[:-1, :] = scale.mm(aff[:-1, :])
    header = set_affine(header, aff, shape)
    return header
class SpatialTensor:
    """Base class for tensors with an orientation"""

    def __init__(self, dat, affine=None, dim=None, **backend):
        """
        Parameters
        ----------
        dat : ([C], *spatial) tensor or str or io.MappedArray
            Data, or a path / mapped array to load (a channel dimension
            is prepended to data loaded from a path).
        affine : tensor, optional
            Orientation matrix. Default: the mapped file's affine if any,
            else `spatial.affine_default` of the spatial shape.
        dim : int, default=`dat.dim() - 1`
            Number of spatial dimensions.
        **backend : dtype, device
            Used when loading mapped data, and applied to tensor inputs.
        """
        if isinstance(dat, str):
            dat = io.map(dat)[None]
        if isinstance(dat, io.MappedArray):
            if affine is None:
                affine = dat.affine
            dat = dat.fdata(rand=True, **backend)
        else:
            # Honour the requested dtype/device for in-memory tensors too
            # (previously `backend` was silently ignored for them).
            to_kwargs = {k: v for k, v in backend.items() if v is not None}
            if to_kwargs:
                dat = dat.to(**to_kwargs)
        self.dim = dim or dat.dim() - 1
        self.dat = dat
        if affine is None:
            affine = spatial.affine_default(self.shape, **utils.backend(dat))
        self.affine = affine.to(utils.backend(self.dat)['device'])

    def to(self, *args, **kwargs):
        """Return a shallow copy whose data/affine are moved with `Tensor.to`."""
        return copy.copy(self).to_(*args, **kwargs)

    def to_(self, *args, **kwargs):
        """Move/cast data and affine with `Tensor.to`, in place."""
        self.dat = self.dat.to(*args, **kwargs)
        self.affine = self.affine.to(*args, **kwargs)
        return self

    voxel_size = property(lambda self: spatial.voxel_size(self.affine))
    shape = property(lambda self: self.dat.shape[-self.dim:])
    dtype = property(lambda self: self.dat.dtype)
    device = property(lambda self: self.dat.device)

    def _prm_as_str(self):
        s = [f'shape={list(self.shape)}']
        v = [f'{vx:.2g}' for vx in self.voxel_size.tolist()]
        v = ', '.join(v)
        s += [f'voxel_size=[{v}]']
        if self.dtype != torch.float32:
            s += [f'dtype={self.dtype}']
        if self.device.type != 'cpu':
            s += [f'device={self.device}']
        return s

    def __repr__(self):
        s = ', '.join(self._prm_as_str())
        s = f'{self.__class__.__name__}({s})'
        return s

    __str__ = __repr__
def resize(cls, x, affine, target_vx=1):
    """Resample a 3D volume to a target voxel size.

    Before resizing by ``factor = voxel_size(affine) / target_vx``, the
    volume is smoothed with FWHM ``0.25 / factor``. Axes with
    ``factor > 1`` get no smoothing.

    Returns
    -------
    x : tensor
        Resized volume.
    affine : tensor
        Corresponding orientation matrix.
    """
    target = utils.make_vector(target_vx, x.dim(), **utils.backend(affine))
    factor = spatial.voxel_size(affine) / target

    fwhm = 0.25 * factor.reciprocal()
    fwhm[factor > 1] = 0
    smoothed = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)

    resized, affine = spatial.resize(smoothed[None, None], factor.tolist(),
                                     affine=affine)
    return resized[0, 0], affine
def downsample(x, aff_in, vx_out):
    """
    Downsample an image (by an integer factor) to approximately match
    a target voxel size.

    The per-axis factor is ``floor(vx_out / vx_in)``, clamped to at least
    1. If all factors are 1, the inputs are returned unchanged; otherwise
    `spatial.pool` is used.

    Returns
    -------
    x : tensor
    aff_out : tensor
    """
    vx_in = spatial.voxel_size(aff_in)
    dim = len(vx_in)
    # Build the target on the same dtype/device as `vx_in`, otherwise the
    # division below fails for affines that live on a GPU.
    vx_out = utils.make_vector(vx_out, dim, **utils.backend(vx_in))
    factor = (vx_out / vx_in).clamp_min(1).floor().long()
    if (factor == 1).all():
        return x, aff_in
    x, aff_out = spatial.pool(dim, x, factor.tolist(), affine=aff_in)
    return x, aff_out
def space(self, value):
    """Set the visualisation space.

    `value` may be:
    - a (4, 4) tensor: used directly as the space matrix;
    - an int: the affine of ``self.images[value]`` is used;
    - None: identity whose linear part is scaled by the smallest voxel
      size over all images.

    Raises
    ------
    ValueError
        For a tensor that is not 4x4, or any other type.
    """
    self._space = value
    if value is None:
        all_affines = [img.affine for img in self.images]
        smallest_vx = spatial.voxel_size(utils.as_tensor(all_affines)).min()
        matrix = torch.eye(4)
        matrix[:-1, :-1] *= smallest_vx
        self._space_matrix = matrix
    elif torch.is_tensor(value):
        if value.shape != (4, 4):
            raise ValueError('Expected 4x4 matrix')
        self._space_matrix = value
    elif isinstance(value, int):
        self._space_matrix = [img.affine for img in self.images][value]
    else:
        raise ValueError('Expected a 4x4 matrix or an int or None')
def load(self, x, affine=None):
    """Load an image and, if an affine is known, resample it to 1 mm RAS.

    Returns
    -------
    x : tensor
        (Reoriented and resized) data.
    affine : tensor or None
        Output orientation matrix.
    x_original : torch.Size
        Shape of the data before reorientation/resizing.
    affine_original : tensor or None
        Orientation matrix before reorientation.
    """
    if isinstance(x, str):
        x = io.map(x)
    if isinstance(x, io.MappedArray):
        affine = x.affine if affine is None else affine
        x = x.fdata()
        x = x.reshape(x.shape[:3])
    affine_original, shape_original = affine, x.shape
    if affine is None:
        return x, affine, shape_original, affine_original
    affine, x = spatial.affine_reorient(affine, x, 'RAS')
    vx = spatial.voxel_size(affine)
    x, affine = spatial.resize(x[None, None], vx.tolist(), affine=affine)
    return x[0, 0], affine, shape_original, affine_original
def read_info(options):
    """Load affine transforms and space info of other volumes.

    Modifies `options` in place:
    - `options.files` becomes a list of `struct.FileWithInfo` records;
    - each transformation gets its `affine` (and `shape` for fields);
    - `options.target`, if set, is read. If `options.voxel_size` is also
      set, the target lattice is resized to that voxel size.
    """

    def read_file(fname):
        info = struct.FileWithInfo()
        info.fname = fname
        info.dir = os.path.dirname(fname) or '.'
        base, ext = os.path.splitext(os.path.basename(fname))
        if ext in ('.gz', '.bz2'):
            # compressed: keep the double extension (e.g. '.nii.gz')
            compressed_ext = ext
            base, ext = os.path.splitext(base)
            ext += compressed_ext
        info.base, info.ext = base, ext
        mapped = io.volumes.map(fname)
        info.float = nitype(mapped.dtype).is_floating_point
        full_shape = squeeze_to_nd(mapped.shape, dim=3, channels=1)
        info.channels = full_shape[-1]
        info.shape = full_shape[:3]
        info.affine = mapped.affine.float()
        return info

    def read_affine(fname):
        return squeeze_to_nd(io.transforms.loadf(fname).float(), 0, 2)

    def read_field(fname):
        mapped = io.volumes.map(fname)
        return mapped.affine.float(), mapped.shape[:3]

    options.files = [read_file(fname) for fname in options.files]
    for trf in options.transformations:
        if isinstance(trf, struct.Linear):
            trf.affine = read_affine(trf.file)
        else:
            trf.affine, trf.shape = read_field(trf.file)

    if not options.target:
        return
    options.target = read_file(options.target)
    if not options.voxel_size:
        return
    target = options.target
    options.voxel_size = utils.make_vector(
        options.voxel_size, 3, dtype=target.affine.dtype)
    factor = spatial.voxel_size(target.affine) / options.voxel_size
    target.affine, target.shape = spatial.affine_resize(
        target.affine, target.shape, factor=factor, anchor='f')
def _smooth_for_reg(dat, mat, samp):
    """Smoothing for image registration. FWHM is computed from voxel size
    and sub-sampling amount.

    The Gaussian FWHM (in voxels) is ``sqrt(max(samp**2 - vx**2, 0)) / vx``.
    Smoothing is applied by three separable 3D convolutions on the padded
    volume.

    Parameters
    ----------
    dat : (X, Y, Z) tensor_like
        3D image volume.
    mat : (4, 4) tensor_like
        Affine matrix.
    samp : float
        Amount of sub-sampling (in mm). If <= 0, `dat` is returned as is.

    Returns
    -------
    dat : (Nx, Ny, Nz) tensor_like
        Smoothed 3D image volume.
    """
    if samp <= 0:
        return dat
    backend = dict(dtype=dat.dtype, device=dat.device)
    samp = torch.tensor((samp, ) * 3, **backend)
    vx = voxel_size(mat).to(dat.device).type(dat.dtype)
    zero = torch.zeros(3, **backend)
    fwhm = torch.sqrt(torch.max(samp**2 - vx**2, zero)) / vx
    kernels = smooth(('gauss', ) * 3, fwhm=fwhm, sep=True, **backend)

    # symmetric padding by the half-width of each separable kernel
    widths = torch.tensor((kernels[0].shape[2],
                           kernels[1].shape[3],
                           kernels[2].shape[4]))
    size_pad = tuple(((widths - 1) // 2).int().tolist())

    out = pad(dat, size_pad, side='both')[None, None, ...]
    for kernel in (kernels[0], kernels[1], kernels[2]):
        out = F.conv3d(out, kernel)
    return out[0, 0, ...]
def slice_to(self, stack, cache_result=False, recompute=True):
    """Resample the volume into the layout of `stack` and apply a
    slice-profile smoothing along its last axis.

    Parameters
    ----------
    stack
        Object exposing `.layout` and `.slice_width`.
    cache_result : bool, default=False
        Store the result in `self._sliced` (also forwarded to `self.exp`).
    recompute : bool, default=True
        If False and a cached result exists, return it instead of
        recomputing (also forwarded to `self.exp`).

    Returns
    -------
    sliced : tensor
    """
    aff = self.exp(cache_result=cache_result, recompute=recompute)
    if not recompute and hasattr(self, '_sliced'):
        # previously this path reached `return sliced` with `sliced`
        # unbound (UnboundLocalError)
        return self._sliced

    aff = spatial.affine_matmul(aff, self.affine)
    aff_reorient = spatial.affine_reorient(self.affine, self.shape,
                                           stack.layout)
    aff = spatial.affine_lmdiv(aff_reorient, aff)
    grid = spatial.affine_grid(aff, self.shape)
    sliced = spatial.grid_pull(self.dat, grid, bound=self.bound,
                               extrapolate=self.extrapolate)
    # smoothing only along the last (slice) axis
    fwhm = [0] * self.dim
    fwhm[-1] = stack.slice_width / spatial.voxel_size(aff_reorient)[-1]
    sliced = spatial.smooth(sliced, fwhm, dim=self.dim, bound=self.bound)

    # TODO: per-slice extraction over `stack.slices` is not implemented.
    # The former placeholder loop called `affine_matmul`/`affine_lmdiv`
    # with a missing argument (TypeError) and discarded its results.

    if cache_result:
        self._sliced = sliced
    return sliced
def _estimate_hyperpar(x, sett):
    """ Estimate noise precision (tau) and mean brain intensity (mu) of
    each observed image.

    CT images (``x[c][n].ct``): the noise sd is taken from
    ``estimate_fwhm(...)[1]`` and the intensities are fixed to
    mu_bg = 0, mu_fg = 4096.
    Other images: a two-class `estimate_noise` fit provides the background
    sd and foreground mean; the background mean is then set to 0.

    Args:
        x (_input()): Input data, indexed as ``x[c][n]``.
        sett: Settings (``sett.show_hyperpar`` is forwarded as `show_fit`).

    Returns:
        x (_input()): The same object, with ``sd``, ``tau = 1 / sd**2`` and
            ``mu = |mu_fg - mu_bg|`` set on every observation.
    """
    # Print info to screen
    t0 = _print_info('hyper_par', sett)
    # Do estimation
    cnt = 0  # running index over all observations (figure numbers)
    for c in range(len(x)):
        for n in range(len(x[c])):
            # Get data
            dat = x[c][n].dat
            if x[c][n].ct:
                # Estimate noise sd from estimate of FWHM
                sd_bg = estimate_fwhm(dat, voxel_size(x[c][n].mat),
                                      mn=20, mx=50)[1]
                mu_bg = torch.tensor(0.0, device=dat.device, dtype=dat.dtype)
                mu_fg = torch.tensor(4096, device=dat.device, dtype=dat.dtype)
            else:
                # Get noise and foreground statistics
                sd_bg, sd_fg, mu_bg, mu_fg = estimate_noise(
                    dat, num_class=2, show_fit=sett.show_hyperpar,
                    fig_num=100 + cnt)
                mu_bg = torch.tensor(0.0, device=dat.device, dtype=dat.dtype)
            # Set values
            x[c][n].sd = sd_bg.float()
            x[c][n].tau = 1 / sd_bg.float() ** 2
            x[c][n].mu = torch.abs(mu_fg.float() - mu_bg.float())
            cnt += 1
    # Print info to screen
    _print_info('hyper_par', sett, x, t0)
    return x
def _crop_y(y, sett):
    """ Crop output images FOV to a fixed dimension

    The cropped FOV comes from the 'atlas_t1' bounding box (`sett.fov`).
    It is rescaled by the output voxel size, and every output image (and
    label, if any) is resampled onto it with order-0 interpolation.

    Args:
        y (_output()): _output data.
        sett: Settings (uses ``crop``, ``device`` and ``fov``).

    Returns:
        y (_output()): Cropped output data (no-op if ``sett.crop`` is off).
    """
    if not sett.crop:
        return y
    device = sett.device
    # Output image information
    mat_y = y[0].mat
    vx_y = voxel_size(mat_y)
    # Define cropped FOV
    mat_mu, dim_mu = _bb_atlas('atlas_t1', fov=sett.fov,
                               dtype=torch.float64, device=device)
    # Modulate atlas with voxel size
    mat_vx = torch.diag(torch.cat((
        vx_y, torch.ones(1, dtype=torch.float64, device=device))))
    mat_mu = mat_mu.mm(mat_vx)
    dim_mu = mat_vx[:3, :3].inverse().mm(dim_mu[:, None]).floor().squeeze()
    # Make output grid: solve mat_y @ M = mat_mu.
    # `Tensor.solve` was removed from recent PyTorch releases; keep a
    # fallback for versions that predate `torch.linalg.solve`.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'solve'):
        M = torch.linalg.solve(mat_y, mat_mu)
    else:
        M = mat_mu.solve(mat_y)[0]
    M = M.type(y[0].dat.dtype)
    grid = affine_grid(M, dim_mu)[None, ...]
    # Crop
    for c in range(len(y)):
        y[c].dat = grid_pull(y[c].dat[None, None, ...], grid, bound='zero',
                             extrapolate=False, interpolation=0)[0, 0, ...]
        # Do labels?
        if y[c].label is not None:
            y[c].label = grid_pull(y[c].label[None, None, ...], grid,
                                   bound='zero', extrapolate=False,
                                   interpolation=0)[0, 0, ...]
        y[c].mat = mat_mu
        y[c].dim = tuple(dim_mu.int().tolist())
    return y
def _cli(args):
    """Command-line interface for `smooth` without exception handling

    A FWHM whose last element is a string takes that string as its unit.
    The default unit is 'mm'; any other unit is used as voxels.
    Each input file is smoothed and saved to the matching output pattern,
    which is formatted with `dir`, `base`, `ext` and `sep`.
    """
    args = args or sys.argv[1:]
    options = parser(args)
    if options.help:
        print(help)
        return

    fwhm = options.fwhm
    unit = 'mm'
    if isinstance(fwhm[-1], str):
        *fwhm, unit = fwhm
    fwhm = make_list(fwhm, 3)

    options.output = make_list(options.output, len(options.files))
    for fname, ofname in zip(options.files, options.output):
        mapped = io.map(fname)
        vx = voxel_size(mapped.affine).tolist()
        dim = len(vx)
        if unit == 'mm':
            # convert the FWHM from mm to voxels
            fwhm_vox = [width / size for width, size in zip(fwhm, vx)]
        else:
            fwhm_vox = fwhm[:dim]

        dat = movedim_front2back(mapped.fdata(), dim)
        dat = smooth(dat, type=options.method, fwhm=fwhm_vox,
                     basis=options.basis, bound=options.padding, dim=dim)
        dat = movedim_back2front(dat, dim)

        folder, base, ext = fileparts(fname)
        ofname = ofname.format(dir=folder or '.', base=base, ext=ext,
                               sep=os.path.sep)
        io.savef(dat, ofname, like=mapped)
def _compute_nll(x, y, sett, rho, sum_dtype=torch.float64):
    """ Compute negative model log-likelihood.

    The data term is ``0.5 * tau * ||x - A(y)||^2``, summed over the
    nonzero voxels of each observation. The prior term is the sum over
    voxels of the square root of the squared, `lam`-weighted gradients
    accumulated over channels.

    Args:
        rho (torch.Tensor): ADMM step size.
        sum_dtype (torch.dtype): Defaults to torch.float64.

    Returns:
        nll_yx (torch.tensor()): Negative log-posterior
        nll_xy (torch.tensor()): Negative log-likelihood.
        nll_y (torch.tensor()): Negative log-prior.
    """
    vx_y = voxel_size(y[0].mat).float()
    nll_xy = torch.tensor(0, device=sett.device, dtype=torch.float64)
    for c, x_c in enumerate(x):
        y_c = y[c]
        # Neg. log-likelihood term
        for n, x_cn in enumerate(x_c):
            msk = x_cn.dat != 0
            Ay = _proj('A', y_c.dat, x_c, y_c, n=n, method=sett.method,
                       do=sett.do_proj, bound=sett.bound,
                       interpolation=sett.interpolation)
            resid = x_cn.dat[msk] - Ay[msk]
            nll_xy += 0.5 * x_cn.tau * torch.sum(resid**2, dtype=sum_dtype)
        # Neg. log-prior term (accumulated over channels)
        Dy = y_c.lam * im_gradient(y_c.dat, vx=vx_y, bound=sett.bound,
                                   which=sett.diff)
        sq_grad = torch.sum(Dy**2, dim=0)
        if c == 0:
            nll_y = sq_grad
        else:
            nll_y += sq_grad
    nll_y = torch.sum(torch.sqrt(nll_y), dtype=sum_dtype)
    return nll_xy + nll_y, nll_xy, nll_y
def _all_mat_dim_vx(x, sett): """ Get all images affine matrices, dimensions and voxel sizes (as numpy arrays). Returns: all_mat (torch.tensor): Image orientation matrices (N, 4, 4). Dim (torch.tensor): Image dimensions (N, 3). all_vx (torch.tensor): Image voxel sizes (N, 3). """ N = sum([len(xn) for xn in x]) all_mat = torch.zeros((N, 4, 4), device=sett.device, dtype=torch.float64) all_dim = torch.zeros((N, 3), device=sett.device, dtype=torch.float64) all_vx = torch.zeros((N, 3), device=sett.device, dtype=torch.float64) cnt = 0 for c in range(len(x)): for n in range(len(x[c])): all_mat[cnt, ...] = x[c][n].mat all_dim[cnt, ...] = torch.tensor(x[c][n].dim, device=sett.device, dtype=torch.float64) all_vx[cnt, ...] = voxel_size(x[c][n].mat) cnt += 1 return all_mat, all_dim, all_vx
def __repr__(self):
    """Return ``<ClassName>(shape=<shape>, vx=<voxel size list>)``."""
    cls_name = type(self).__name__
    vx = spatial.voxel_size(self.affine).tolist()
    return '{}(shape={}, vx={})'.format(cls_name, self.shape, vx)