def _get_image_data(pths, device=None, dtype=None):
    """Format image data (on disk) to a tensor compatible with `denoise_mri`.

    OBS: Assumes that all input images are 3D volumes with the same dimensions.

    Parameters
    ----------
    pths : [nchannels] sequence
        Paths to image data
    device : torch.device, optional
        Torch device
    dtype : torch.dtype, optional
        Torch data type

    Returns
    -------
    dat : (dmx, dmy, dmz, nchannels) tensor
        Image tensor
    """
    # Read every channel first, then concatenate once: repeatedly calling
    # torch.cat inside the loop reallocates and copies the growing tensor
    # on each iteration (quadratic in the number of channels).
    channels = [map(p).fdata(dtype=dtype, device=device, rand=False)[..., None]
                for p in pths]
    return torch.cat(channels, dim=-1)
def _map_image(self):
    """Resolve `self.image` into either a mapped file or an in-memory tensor.

    Sets exactly one of two representations:
      * `self._map`   : an `io.MappedArray` when `self.image` is a path.
      * `self._fdata` : a tensor (with `self._affine`) otherwise.
    `self._shape` always holds the last three (spatial) dimensions.
    """
    image = self.image
    # Reset all cached representations before re-mapping.
    self._map = None
    self._fdata = None
    self._affine = None
    self._shape = None
    if isinstance(image, str):
        self._map = io.map(image)
    else:
        self._map = None
    if self._map is None:
        # In-memory input: accept `tensor` or `(tensor, affine)` pairs.
        if not isinstance(image, (list, tuple)):
            image = [image]
        if len(image) < 2:
            # Pad so that `dat, aff, *_` unpacking below always works.
            image = [*image, None, None]
        dat, aff, *_ = image
        dat = torch.as_tensor(dat)
        if aff is None:
            # No affine provided: fall back to a default voxel-to-world map.
            aff = spatial.affine_default(dat.shape[-3:])
        self._fdata = dat
        self._affine = aff
        self._shape = tuple(dat.shape[-3:])
    else:
        # File-backed input: metadata comes from the mapped array.
        self._affine = self._map.affine
        self._shape = tuple(self._map.shape[-3:])
def get_spm_prior(**backend):
    """Load the SPM tissue prior as `(dat, affine)`, dropping the background class."""
    prior = io.map(path_spm_prior()).movedim(-1, 0)[:-1]  # drop background
    dat = prior.fdata(**backend)
    aff = prior.affine.to(**utils.backend(dat))
    return dat, aff
def _map_image(fnames, dim=None):
    """
    Load a N-D image from disk.

    Parameters
    ----------
    fnames : sequence of str
        Paths to map (each becomes one or more channels).
    dim : int, optional
        Number of spatial dimensions; inferred from the first file's
        affine (`affine.shape[-1] - 1`) when not given.

    Returns:
        image : (C, *spatial) MappedTensor
        affine: (D+1, D+1) tensor
    """
    affine = None
    imgs = []
    for fname in fnames:
        img = io.map(fname)
        if affine is None:
            # The first file's affine is used for all inputs.
            affine = img.affine
        if dim is None:
            dim = img.affine.shape[-1] - 1
        # img = img.fdata(rand=True, device=device)
        if img.dim > dim:
            # Extra (channel) dimension present: move it to the front.
            img = img.movedim(-1, 0)
        else:
            img = img[None]
        # Pad trailing singleton dims up to `dim + 1` total dimensions.
        # NOTE(review): assumes MappedArray.unsqueeze accepts a count as
        # second argument — confirm against the io API.
        img = img.unsqueeze(-1, dim + 1 - img.dim)
        if img.dim > dim + 1:
            raise ValueError(f'Don\'t know how to deal with an image of '
                             f'shape {tuple(img.shape)}')
        imgs.append(img)
        del img
    imgs = io.cat(imgs, dim=0)
    return imgs, affine
def _format_input(img, device='cpu', rand=False, cutoff=None):
    """Format preprocessing input data.

    Accepts a path, a list of paths, a `[tensor, affine]` pair, or a list
    of such pairs, and returns parallel lists `(dat, mat, file)`.
    """
    # Normalise to a list of inputs.
    if isinstance(img, str):
        img = [img]
    if isinstance(img, list) and isinstance(img[0], torch.Tensor):
        img = [img]
    file = []
    dat = []
    mat = []
    for entry in img:
        if isinstance(entry, str):
            # Input is a nitorch.io compatible path.
            mapped = map(entry)
            file.append(mapped)
            dat.append(mapped.fdata(dtype=torch.float32, device=device,
                                    rand=rand, cutoff=cutoff))
            mat.append(mapped.affine.to(dtype=torch.float64, device=device))
        else:
            # Input is a (tensor, affine) pair; clone so the caller's
            # data is never modified.
            file.append(None)
            dat.append(entry[0].clone().to(dtype=torch.float32, device=device))
            mat.append(entry[1].clone().to(dtype=torch.float64, device=device))
    return dat, mat, file
def set_dat(self, dat, affine=None, **backend):
    """Replace the displacement data; returns `self` for chaining."""
    if isinstance(dat, str):
        dat = io.map(dat)
        # NOTE(review): the affine is taken from the *current* data
        # (`self.dat`), not from the freshly mapped file — presumably
        # intentional (keep the object's grid), but confirm; the mapped
        # file's own affine (`dat.affine`) is ignored here.
        if affine is None:
            affine = self.dat.affine
    if isinstance(dat, SpatialTensor) and affine is None:
        affine = dat.affine
    self.dat = Displacement.make(dat, affine=affine, **backend)
    return self
def _bb_atlas(name, fov, dtype=torch.float64, device='cpu'): """Bounding-box NITorch atlas data to specific field-of-view. Parameters ---------- name : str Name of nitorch data, available are: * atlas_t1: MRI T1w intensity atlas, 1 mm resolution. * atlas_t2: MRI T2w intensity atlas, 1 mm resolution. * atlas_pd: MRI PDw intensity atlas, 1 mm resolution. * atlas_t1_mni: MRI T1w intensity atlas, in MNI space, 1 mm resolution. * atlas_t2_mni: MRI T2w intensity atlas, in MNI space, 1 mm resolution. * atlas_pd_mni: MRI PDw intensity atlas, in MNI space, 1 mm resolution. fov : str Field-of-view, specific to 'name': * 'atlas_t1' | 'atlas_t2' | 'atlas_pd': * 'brain': Head FOV. * 'head': Brain FOV. Returns ---------- mat_mu : (4, 4) tensor, dtype=float64 Output affine matrix. dim_mu : (3, ) tensor, dtype=float64 Output dimensions. """ # Get atlas information file_mu = map(fetch_data(name)) dim_mu = file_mu.shape mat_mu = file_mu.affine.type(torch.float64).to(device) # Get bounding box o = [[0, 0, 0], [0, 0, 0]] if name in ['atlas_t1', 'atlas_t2', 'atlas_pd']: if fov == 'brain' or fov == 'head': o[0][0] = 18 o[0][1] = 52 o[0][2] = 120 o[1][0] = 18 o[1][1] = 48 o[1][2] = 58 if fov == 'head': o[0][2] = 25 # Get bounding box bb = torch.tensor( [[1 + o[0][0], 1 + o[0][1], 1 + o[0][2]], [dim_mu[0] - o[1][0], dim_mu[1] - o[1][1], dim_mu[2] - o[1][2]]]) bb = bb.type(torch.float64).to(device) # Output dimensions dim_mu = bb[1, ...] - bb[0, ...] + 1 # Bounding-box atlas affine mat_bb = affine_matrix_classic(bb[0, ...] - 1) # Modulate atlas affine with bb affine mat_mu = mat_mu.mm(mat_bb) return mat_mu, dim_mu
def get_spm_prior(**backend):
    """Download (if not cached) and load the SPM12 tissue probability map."""
    url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
    fname = os.path.join(cache_dir, 'SPM12_TPM.nii')
    if not os.path.exists(fname):
        # First use: fetch the TPM into the cache directory.
        os.makedirs(cache_dir, exist_ok=True)
        fname = download(url, fname)
    tpm = io.map(fname).movedim(-1, 0)  # [:-1]  # drop background
    dat = tpm.fdata(**backend)
    aff = tpm.affine.to(**utils.backend(dat))
    return dat, aff
def forward(self, x, affine=None):
    """Segment a volume: preprocess to 1 mm RAS, run the network, undo crops.

    Parameters
    ----------
    x : (X, Y, Z) tensor or str
    affine : (4, 4) tensor, optional

    Returns
    -------
    seg : (32, oX, oY, oZ) tensor
        Segmentation
    resliced : (oX, oY, oZ) tensor
        Input resliced to 1 mm RAS
    affine : (4, 4) tensor
        Output orientation matrix
    """
    if self.verbose:
        print('Preprocessing... ', end='', flush=True)
    if isinstance(x, str):
        x = io.map(x)
    if isinstance(x, io.MappedArray):
        if affine is None:
            affine = x.affine
        x = x.fdata()
        # Keep only the three spatial dimensions.
        x = x.reshape(x.shape[:3])
    x = SynthPreproc.addnoise(x)
    if affine is not None:
        # Reorient to RAS, lightly smooth (only where voxels are small),
        # then resample to ~1 mm isotropic.
        affine, x = spatial.affine_reorient(affine, x, 'RAS')
        vx = spatial.voxel_size(affine)
        fwhm = 0.25 * vx.reciprocal()
        fwhm[vx > 1] = 0
        x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
        x, affine = spatial.resize(x[None, None], vx.tolist(), affine=affine)
        x = x[0, 0]
    oshape = x.shape
    x, crop = SynthPreproc.crop(x)
    x = SynthPreproc.preproc(x)[None, None]
    if self.verbose:
        print('done.', flush=True)
        print('Segmenting... ', end='', flush=True)
    # Network forward pass; also strip the batch/channel dims from x.
    s, x = super().forward(x)[0], x[0, 0]
    if self.verbose:
        print('done.', flush=True)
        print('Postprocessing... ', end='', flush=True)
    s = self.relabel(s.argmax(0))
    # Pad both outputs back to the pre-crop shape.
    x = SynthPreproc.pad(x, oshape, crop)
    s = SynthPreproc.pad(s, oshape, crop)
    if self.verbose:
        print('done.', flush=True)
    return s, x, affine
def get_data(x, w, affine, dim, **backend):
    """Load image data `x` and optional weight map `w` as `(C, *spatial)` tensors.

    Parameters
    ----------
    x : tensor or str or sequence of str
        Image data, a path, or a list of paths (stacked as channels).
    w : tensor or str or None
        Optional weight map (loaded with `io.loadf` when a path).
    affine : (D+1, D+1) tensor or None
        Orientation matrix; read from file when None.
    dim : int
        Number of spatial dimensions.
    **backend : dtype, device

    Returns
    -------
    x : (C, *spatial) tensor
    w : tensor or None
    affine : tensor
    """
    if not torch.is_tensor(x):
        if isinstance(x, str):
            f = io.map(x)
            if affine is None:
                affine = f.affine
            if f.dim > dim:
                # Drop a singleton spatial+1 dimension if present.
                if f.shape[dim] == 1:
                    f = f.squeeze(dim)
            if f.dim > dim + 1:
                raise ValueError('Too many dimensions')
            if f.dim > dim:
                # Channel dimension last on disk -> first in memory.
                f = f.movedim(-1, 0)
            else:
                f = f[None]
            x = f.fdata(**backend, rand=True, missing=0)
        else:
            # Sequence of paths: one channel per file.
            f = io.stack([io.map(x1) for x1 in x])
            if affine is None:
                affine = f.affine[0]
            x = f.fdata(**backend, rand=True, missing=0)
    if x.dim() > dim + 1:
        # FIX: was `x.unsqeeze(-1)` — a typo (AttributeError) and the wrong
        # operation: to reduce dimensionality we must drop a trailing
        # singleton, not add one.
        x = x.squeeze(-1)
    if x.dim() > dim + 1:
        raise ValueError('Too many dimensions')
    if x.dim() == dim:
        x = x[None]
    if not torch.is_tensor(w) and w is not None:
        w = io.loadf(w, **backend)
        # FIX: these checks previously tested `x.dim()`, which is already
        # normalised above and would raise unconditionally; they are meant
        # to normalise/validate the weight map `w`.
        if w.dim() > dim:
            w = w.squeeze(-1)
        if w.dim() > dim:
            raise ValueError('Too many dimensions')
    x = x.contiguous()
    if w is not None:
        w = w.contiguous()
    return x, w, affine
def get_prior(prior, affine_prior, **backend):
    """Resolve the tissue prior (default SPM, path, or tensor) and its affine."""
    if prior is None:
        # Fall back to the bundled SPM prior.
        prior, default_affine = get_spm_prior(**backend)
        if affine_prior is None:
            affine_prior = default_affine
    elif isinstance(prior, str):
        # Path on disk: channels last on file -> first in memory.
        mapped = io.map(prior).movedim(-1, 0)
        if affine_prior is None:
            affine_prior = mapped.affine
        prior = mapped.fdata(**backend)
    else:
        prior = prior.to(**backend)
    return prior, affine_prior
def _read_label(x, pth, sett):
    """Read labels and attach them to the input struct `x`."""
    label_file = map(pth)
    label_dat = label_file.fdata(dtype=torch.float32, device=sett.device)
    # Labels must have exactly the same dimensions as the image.
    if not torch.equal(torch.as_tensor(x.dim),
                       torch.as_tensor(label_dat.shape)):
        raise ValueError('Incorrect label dimensions.')
    x.label = [label_dat, label_file]
    return x
def _init_from_fname(new, fnames, permission='r', keep_open=False, **attributes):
    """Initialize `new` from one or more file names, concatenated as channels."""
    mapped = []
    for fname in py.make_list(fnames):
        vol = io.map(fname, permission=permission, keep_open=keep_open)
        # Pad to at least 4 dimensions so all volumes concatenate cleanly.
        while vol.dim < 4:
            vol = vol.unsqueeze(-1)
        mapped.append(vol)
    # Concatenate along the last axis, then move it to the front (channels first).
    stacked = io.cat(mapped, -1).permute([-1, 0, 1, 2])
    new._init_from_mapped(stacked, **attributes)
def from_fname(cls, fname, permission='r', keep_open=False, **attributes):
    """Build an MRI object from a file name.

    We accept paths of the form 'path/to/file.nii,1,2', which mean
    that only the subvolume `[:, :, :, 1, 2]` should be read. The
    first three (spatial) dimensions are always read.
    """
    fname, *index = str(fname).split(',')
    mapped = io.map(fname, permission=permission, keep_open=keep_open)
    if index:
        # Keep the full spatial extent; index the remaining dimensions.
        subvol = (slice(None),) * 3 + tuple(int(i) for i in index)
        mapped = mapped[subvol]
    return cls.from_mapped(mapped, **attributes)
def prepare_one(inp):
    """Normalize one input into a `[dat, affine]` pair.

    Accepts a path, a MappedArray, a tensor, or a `(dat[, affine])`
    sequence; always returns a two-element list.
    """
    if isinstance(inp, (list, tuple)):
        has_aff = len(inp) > 1
        if has_aff:
            aff0 = inp[1]
        inp, aff = prepare_one(inp[0])
        if has_aff:
            # An explicitly provided affine overrides the inferred one.
            aff = aff0
        return [inp, aff]
    if isinstance(inp, str):
        inp = io.map(inp)[None, None]
    if isinstance(inp, io.MappedArray):
        # FIX: previously returned a bare tuple here while every other
        # branch returns a list; unified for consistency (unpacking
        # callers are unaffected).
        return [inp.fdata(rand=True), inp.affine[None]]
    inp = torch.as_tensor(inp)
    aff = spatial.affine_default(inp.shape)[None]
    return [inp, aff]
def load(self, x, affine=None):
    """Load a volume (path, mapped array, or tensor) and resample it to RAS
    at its native voxel size, keeping the original shape/affine for later
    un-resampling."""
    if isinstance(x, str):
        x = io.map(x)
    if isinstance(x, io.MappedArray):
        if affine is None:
            affine = x.affine
        x = x.fdata()
        x = x.reshape(x.shape[:3])
    # Remember the pre-resampling geometry.
    affine_original = affine
    x_original = x.shape
    if affine is not None:
        affine, x = spatial.affine_reorient(affine, x, 'RAS')
        voxel_dims = spatial.voxel_size(affine)
        x, affine = spatial.resize(x[None, None], voxel_dims.tolist(),
                                   affine=affine)
        x = x[0, 0]
    return x, affine, x_original, affine_original
def do_apply(fnames, phi, jac):
    """Correct files with a given polarity.

    Closure: reads `options`, `device` and `readout` from the enclosing
    scope. Applies the 1D deformation `phi` (and optional Jacobian
    modulation `jac`) along the readout axis of each file and saves
    the result.
    """
    for fname in fnames:
        dir, base, ext = py.fileparts(fname)
        ofname = options.output
        ofname = ofname.format(dir=dir or '.', sep=os.sep, base=base, ext=ext)
        if options.verbose:
            print(f'unwarp {fname} \n'
                  f' -> {ofname}')
        f = io.map(fname)
        d = f.fdata(device=device)
        # Work with the readout axis last, deform, then restore the layout.
        d = utils.movedim(d, readout, -1)
        d = _deform1d(d, phi)
        if jac is not None:
            # Jacobian modulation preserves intensity under the deformation.
            d *= jac
        d = utils.movedim(d, -1, readout)
        io.savef(d, ofname, like=fname)
def __init__(self, dat, affine=None, dim=None, **backend):
    """
    Parameters
    ----------
    dat : ([C], *spatial) tensor or str or io.MappedArray
    affine : tensor, optional
    dim : int, default=`dat.dim() - 1`
    **backend : dtype, device
    """
    if isinstance(dat, str):
        dat = io.map(dat)[None]
    if isinstance(dat, io.MappedArray):
        if affine is None:
            affine = dat.affine
        dat = dat.fdata(rand=True, **backend)
    # NOTE(review): `dim or ...` treats an explicit `dim=0` as "unset" —
    # presumably never passed here, but worth confirming.
    self.dim = dim or dat.dim() - 1
    self.dat = dat
    if affine is None:
        affine = spatial.affine_default(self.shape, **utils.backend(dat))
    # Keep the affine on the same device as the data.
    self.affine = affine.to(utils.backend(self.dat)['device'])
def _cli(args):
    """Command-line interface for `smooth` without exception handling"""
    args = args or sys.argv[1:]
    options = parser(args)
    if options.help:
        print(help)
        return

    # The FWHM list may carry a trailing unit string ('mm' or voxels).
    fwhm = options.fwhm
    unit = 'mm'
    if isinstance(fwhm[-1], str):
        *fwhm, unit = fwhm
    fwhm = make_list(fwhm, 3)

    options.output = make_list(options.output, len(options.files))
    for fname, ofname in zip(options.files, options.output):
        vol = io.map(fname)
        vx = voxel_size(vol.affine).tolist()
        dim = len(vx)
        # Convert the FWHM to voxel units if it was given in mm.
        if unit == 'mm':
            fwhm_vox = [f / v for f, v in zip(fwhm, vx)]
        else:
            fwhm_vox = fwhm[:len(vx)]
        dat = vol.fdata()
        dat = movedim_front2back(dat, dim)
        dat = smooth(dat, type=options.method, fwhm=fwhm_vox,
                     basis=options.basis, bound=options.padding, dim=dim)
        dat = movedim_back2front(dat, dim)
        folder, base, ext = fileparts(fname)
        ofname = ofname.format(dir=folder or '.', base=base, ext=ext,
                               sep=os.path.sep)
        io.savef(dat, ofname, like=vol)
def _get_image_data(pths, device=None, dtype=None):
    """Format image data (on disk) to a tensor compatible with `denoise_mri`.

    OBS: Assumes that all input images are 3D volumes with the same dimensions.

    Parameters
    ----------
    pths : [nchannels] sequence
        Paths to image data
    device : torch.device, optional
        Torch device
    dtype : torch.dtype, optional
        Torch data type

    Returns
    -------
    dat : (nchannels, dmx, dmy, dmz) tensor
        Image tensor
    affine : (4, 4) tensor
        Affine matrix (assumed the same across channels)
    nii : nitorch.BabelArray
        BabelArray object (of the last channel read).
    """
    channels = []
    shape = first = nii = None
    for i, p in enumerate(pths):
        nii = map(p)
        if i == 0:
            shape = nii.shape
            first = nii  # affine comes from the first channel
        elif not torch.equal(torch.as_tensor(shape),
                             torch.as_tensor(nii.shape)):
            raise ValueError('All images must have same dimensions!')
        channels.append(nii.fdata(dtype=dtype, device=device, rand=False)[None])
    # Concatenate once instead of growing the tensor inside the loop
    # (the previous per-iteration torch.cat copied all prior channels).
    dat = torch.cat(channels, dim=0)
    affine = first.affine.type(dat.dtype).to(dat.device)
    return dat, affine, nii
def estimate_noise(dat, show_fit=False, fig_num=1, num_class=2,
                   mu_noise=None, max_iter=10000, verbose=0, bins=1024,
                   chi=False):
    """Estimate the noise distribution in an image by fitting either a
    Gaussian, Rician or noncentral Chi mixture model to the image's
    intensity histogram. The Gaussian model is only used if negative
    values are found in the image (e.g., if it is a CT scan).

    Parameters
    ----------
    dat : str or tensor
        Tensor or path to nifti file.
    show_fit : bool, default=False
        Show a plot of the histogram fit at the end
    fig_num : int, default=1
        ID of matplotlib figure to use
    num_class : int, default=2
        Number of mixture classes (only for GMM).
    mu_noise : float, optional
        Mean of noise class. If provided, the class with mean value closest
        to `mu_noise` is assumed to be the background class. Otherwise, it
        is the class with smallest standard deviation.
    max_iter : int, default=10000
        Maximum number of EM iterations.
    verbose int, defualt=0:
        Display progress. Defaults to 0.
        * 0: None.
        * 1: Print summary when finished.
        * 2: 1 + Log-likelihood plot.
        * 3: 1 + 2 + print convergence.
    bins : int, default=1024
        Number of histogram bins.
    chi : bool, default=False
        Fit a noncentral Chi rather than a Rice model.

    Returns
    -------
    prm_noise : dict
        Parameters of the distribution of the background (noise) class
        With fields 'sd', 'mean' and (if `chi`) 'dof'
    prm_not_noise : dict
        Parameters of the distribution of the foreground (tissue) class
        With fields 'sd', 'mean' and (if `chi`) 'dof'
    """
    DTYPE = torch.double  # use double for accuracy (maybe single would work?)
    slope = None
    if isinstance(dat, str):
        dat = io.map(dat)
    if isinstance(dat, io.MappedArray):
        slope = dat.slope
        # NOTE(review): `if_floating_point` looks like a typo for
        # `is_floating_point` — confirm against the `dtype_info` API;
        # as written this branch would raise AttributeError when hit.
        if not slope and not dtype_info(dat.dtype).if_floating_point:
            slope = 1
        dat = dat.fdata(rand=True, missing=0, dtype=DTYPE)
    dat = torch.as_tensor(dat, dtype=DTYPE).flatten()
    device = dat.device
    # Integer-typed in-memory tensors also get a unit slope.
    if not slope and not dat.dtype.is_floating_point:
        slope = 1
    # exclude missing values
    dat = dat[torch.isfinite(dat)]
    dat = dat[dat != 0]
    # Mask and get min/max (extreme values are dropped as outliers)
    mn = dat.min()
    mx = dat.max()
    dat = dat[dat != mn]
    dat = dat[dat != mx]
    mn = mn.round()
    mx = mx.round()
    if slope:
        # ensure bin width aligns with integer width
        width = (mx - mn) / bins
        width = (width / slope).ceil() * slope
        mx = mn + bins * width
    # Histogram bin data
    dat = torch.histc(dat, bins=bins, min=mn, max=mx).to(DTYPE)
    x = torch.linspace(mn, mx, steps=bins, device=device, dtype=DTYPE)
    # fit mixture model
    if mn < 0:
        # Make GMM model (negative values => Gaussian, e.g. CT)
        model = GMM(num_class=num_class)
    elif chi:
        model = CMM(num_class=num_class)
    else:
        # Make RMM model
        model = RMM(num_class=num_class)
    # Fit GMM/RMM/CMM using Numpy
    model.fit(x, W=dat, verbose=verbose, max_iter=max_iter,
              show_fit=show_fit, fig_num=fig_num)
    # Get means and mixing proportions
    mu, _ = model.get_means_variances()
    mu = mu.squeeze()
    mp = model.mp
    if mn < 0:
        # GMM
        sd = torch.sqrt(model.Cov).squeeze()
    else:
        # RMM/CMM
        sd = model.sig.squeeze()
    # Get std and mean of noise class
    if mu_noise:
        # Closest to mu_bg
        _, ix_noise = torch.min(torch.abs(mu - mu_noise), dim=0)
    else:
        # With smallest sd
        _, ix_noise = torch.min(sd, dim=0)
    mu_noise = mu[ix_noise]
    sd_noise = sd[ix_noise]
    if chi:
        dof_noise = model.dof[ix_noise]
    # Get std and mean of other classes (means and sds weighted by mps)
    rng = list(range(num_class))
    del rng[ix_noise]
    mu = mu[rng]
    sd = sd[rng]
    w = mp[rng]
    w = w / torch.sum(w)
    mu_not_noise = sum(w * mu)
    sd_not_noise = sum(w * sd)
    if chi:
        dof = model.dof[rng]
        dof_not_noise = sum(w * dof)
    # return dictionaries of parameters
    prm_noise = dict(sd=sd_noise, mean=mu_noise)
    prm_not_noise = dict(sd=sd_not_noise, mean=mu_not_noise)
    if chi:
        prm_noise['dof'] = dof_noise
        prm_not_noise['dof'] = dof_not_noise
    return prm_noise, prm_not_noise
def __init__(self, dat, levels=1, affine=None, dim=None,
             mask=None, preview=None, bound='dct2', extrapolate=False,
             method='gauss', **backend):
    """Build a multi-resolution pyramid of `Image` objects.

    Parameters
    ----------
    dat : [list of] (..., *shape) tensor or Image
    levels : int or list[int] or range, default=0
        If an int, it is the number of levels.
        If a range or list, they are the indices of levels to compute.
        `0` is the native resolution, `1` is half of it, etc.
    affine : [list of] tensor, optional
    dim : int, optional
    bound : str, default='dct2'
    extrapolate : bool, default=True
    method : {'gauss', 'average', 'median', 'stride'}, default='gauss'
    """
    # I don't call super().__init__() on purpose
    self.method = method
    if isinstance(dat, Image):
        # Inherit every attribute from an existing Image.
        if affine is None:
            affine = dat.affine
        dim = dat.dim
        mask = dat.mask
        preview = dat._preview
        bound = dat.bound
        extrapolate = dat.extrapolate
        dat = dat.dat
    if isinstance(dat, str):
        dat = io.map(dat)
    if isinstance(dat, io.MappedArray):
        if affine is None:
            affine = dat.affine
        dat = dat.fdata(rand=True, **backend)[None]
    # Normalise `levels` to a list of level indices.
    if isinstance(levels, int):
        levels = range(levels)
    if isinstance(levels, range):
        levels = list(levels)
    if not isinstance(dat, (list, tuple)):
        # Single tensor: the pyramid levels still need to be computed.
        dim = dim or dat.dim() - 1
        if not levels:
            raise ValueError('levels required to compute pyramid')
        if isinstance(levels, int):
            levels = range(1, levels+1)
        if affine is None:
            shape = dat.shape[-dim:]
            affine = spatial.affine_default(shape, **utils.backend(dat[0]))
        dat, mask, preview = self._build_pyramid(
            dat, levels, method, dim, bound, mask, preview)
    dat = list(dat)
    if not mask:
        mask = [None] * len(dat)
    if not preview:
        preview = [None] * len(dat)
    if all(isinstance(d, Image) for d in dat):
        # Already a list of Image objects: nothing left to wrap.
        self._dat = dat
        return
    dim = dim or (dat[0].dim() - 1)
    if affine is None:
        shape = dat[0].shape[-dim:]
        affine = spatial.affine_default(shape, **utils.backend(dat[0]))
    if not isinstance(affine, (list, tuple)):
        # One affine for level 0: derive the per-level affines.
        if not levels:
            raise ValueError('levels required to compute affine pyramid')
        shape = dat[0].shape[-dim:]
        affine = self._build_affine_pyramid(affine, shape, levels, method)
    self._dat = [Image(d, aff, dim=dim, mask=m, preview=p,
                       bound=bound, extrapolate=extrapolate)
                 for d, m, p, aff in zip(dat, mask, preview, affine)]
def _read_data(data, sett):
    """ Parse input data into algorithm input struct(s).

    Args:
        data : str or array-like or [channels][repeats] nested list
            Path to a (possibly 4D) nifti, an array (requires `sett.mat`),
            or a nested list of paths/pairs.
        sett : settings object (reads `mat`, `device`, `ct`, `label`).

    Returns:
        x (_input()): Algorithm input struct(s), indexed [channel][repeat].
    """
    # Sanity check
    mat_vol = sett.mat
    if isinstance(data, str):
        file = map(data)
        dim = file.shape
        if len(dim) > 3:
            # Data is path to 4D nifti
            data = file.fdata()
            mat_vol = file.affine
    try:
        # Array-like input (has .shape): normalise to 4D channels-last.
        # NOTE(review): a 3D array would hit an IndexError (not
        # AttributeError) on the 5-index line below — confirm intended
        # inputs are always 4D here.
        data.shape
        data = data[..., None]
        data = data[:, :, :, :, 0]
        if mat_vol is None:
            raise ValueError('Image data given as array, please also provide affine matrix in sett.mat!')
    except AttributeError:
        pass
    if isinstance(data, str):
        data = [data]
    # Number of channels
    if mat_vol is not None:
        C = data.shape[3]
    else:
        C = len(data)
    x = []
    for c in range(C):  # Loop over channels
        x.append([])
        x[c] = []
        if mat_vol is None and isinstance(data[c], list) \
                and (isinstance(data[c][0], str) or isinstance(data[c][0], list)):
            # Possibly multiple repeats per channel
            for n in range(len(data[c])):  # Loop over observations of channel c
                x[c].append(_input())
                # Get data
                dat, dim, mat, fname, direc, nam, file, ct = \
                    _read_image(data[c][n], sett.device, could_be_ct=sett.ct)
                # Assign
                x[c][n].dat = dat
                x[c][n].dim = dim
                x[c][n].mat = mat
                x[c][n].fname = fname
                x[c][n].direc = direc
                x[c][n].nam = nam
                x[c][n].file = file
                x[c][n].ct = ct
        else:
            # One repeat per channel
            n = 0
            x[c].append(_input())
            # Get data
            if mat_vol is not None:
                dat, dim, mat, fname, direc, nam, file, ct = \
                    _read_image([data[..., c], mat_vol], sett.device,
                                could_be_ct=sett.ct)
            else:
                dat, dim, mat, fname, direc, nam, file, ct = \
                    _read_image(data[c], sett.device, could_be_ct=sett.ct)
            # Assign
            x[c][n].dat = dat
            x[c][n].dim = dim
            x[c][n].mat = mat
            x[c][n].fname = fname
            x[c][n].direc = direc
            x[c][n].nam = nam
            x[c][n].file = file
            x[c][n].ct = ct
    # Add labels (if given)
    if sett.label is not None:
        pth_label = sett.label[0]
        ix_cr = sett.label[1]  # Index channel and repeat
        for c in range(len(x)):
            for n in range(len(x[c])):
                if c == ix_cr[0] and n == ix_cr[1]:
                    x[c][n] = _read_label(x[c][n], pth_label, sett)
    # Print to screen
    _print_info('filenames', sett, x)
    return x
def _main(options):
    """Run the ESTATICS fit from parsed CLI options and write all outputs."""
    # --- device selection ---
    if isinstance(options.gpu, str):
        device = torch.device(options.gpu)
    else:
        assert isinstance(options.gpu, int)
        device = torch.device(f'cuda:{options.gpu}')
    if not torch.cuda.is_available():
        device = 'cpu'

    # prepare options
    estatics_opt = ESTATICSOptions()
    estatics_opt.likelihood = options.likelihood
    estatics_opt.verbose = options.verbose >= 1
    estatics_opt.plot = options.verbose >= 2
    estatics_opt.recon.space = options.space
    if isinstance(options.space, str) and options.space != 'mean':
        # Named reconstruction space: resolve to the contrast index.
        for c, contrast in enumerate(options.contrast):
            if contrast.name == options.space:
                estatics_opt.recon.space = c
                break
    estatics_opt.backend.device = device
    estatics_opt.optim.nb_levels = options.levels
    estatics_opt.optim.max_iter_rls = options.iter
    estatics_opt.optim.tolerance = options.tol
    estatics_opt.regularization.norm = options.regularization
    estatics_opt.regularization.factor = \
        [*options.lam_intercept, options.lam_decay]
    estatics_opt.distortion.enable = options.meetup
    estatics_opt.distortion.bending = options.lam_meetup
    estatics_opt.preproc.register = options.register

    # prepare files
    contrasts = []
    distortion = []
    for i, c in enumerate(options.contrast):
        # read meta-parameters
        meta = {}
        if c.te:
            # Echo times, optionally with a trailing unit string.
            te, unit = c.te, ''
            if isinstance(te[-1], str):
                *te, unit = te
            if unit:
                if unit == 'ms':
                    te = [t * 1e-3 for t in te]
                elif unit not in ('s', 'sec'):
                    raise ValueError(f'TE unit: {unit}')
            if c.echo_spacing:
                # First TE + constant spacing => full echo train.
                delta, *unit = c.echo_spacing
                unit = unit[0] if unit else ''
                if unit == 'ms':
                    delta = delta * 1e-3
                elif unit not in ('s', 'sec'):
                    raise ValueError(f'echo spacing unit: {unit}')
                ne = sum(io.map(f).unsqueeze(-1).shape[3] for f in c.echoes)
                te = [te[0] + e*delta for e in range(ne)]
            meta['te'] = te
        # map volumes
        contrasts.append(qio.GradientEchoMulti.from_fname(c.echoes, **meta))
        if c.readout:
            # Convert the named readout axis to a negative axis index.
            layout = spatial.affine_to_layout(contrasts[-1].affine)
            layout = spatial.volume_layout_to_name(layout)
            readout = None
            for j, l in enumerate(layout):
                if l.lower() in c.readout.lower():
                    readout = j - 3
            contrasts[-1].readout = readout
        if c.b0:
            bw = c.bandwidth
            b0, *unit = c.b0
            unit = unit[-1] if unit else 'vx'
            # FIX: was `fb0 = b0.map(b0)` — `b0` is a filename string at
            # this point and has no `.map`; the mapper lives in `io`.
            fb0 = io.map(b0)
            b0 = fb0.fdata(device=device)
            b0 = spatial.reslice(b0, fb0.affine, contrasts[-1][0].affine,
                                 contrasts[-1][0].shape)
            if unit.lower() == 'hz':
                if not bw:
                    raise ValueError('Bandwidth required to convert fieldmap'
                                     'from Hz to voxel')
                b0 /= bw
            b0 = DenseDistortion(b0)
            distortion.append(b0)
        else:
            distortion.append(None)

    # run algorithm
    [te0, r2s, *b0] = estatics(contrasts, distortion, opt=estatics_opt)

    # write results
    # --- intercepts ---
    odir0 = options.odir
    for i, te1 in enumerate(te0):
        ifname = contrasts[i].echo(0).volume.fname
        odir, obase, oext = py.fileparts(ifname)
        odir = odir0 or odir
        obase = obase + '_TE0'
        ofname = os.path.join(odir, obase + oext)
        io.savef(te1.volume, ofname, affine=te1.affine, like=ifname,
                 te=0, dtype='float32')
    # --- decay ---
    ifname = contrasts[0].echo(0).volume.fname
    odir, obase, oext = py.fileparts(ifname)
    odir = odir0 or odir
    io.savef(r2s.volume, os.path.join(odir, 'R2star' + oext),
             affine=r2s.affine, dtype='float32')
    # --- fieldmap + undistorted ---
    if b0:
        b0 = b0[0]
        for i, b01 in enumerate(b0):
            ifname = contrasts[i].echo(0).volume.fname
            odir, obase, oext = py.fileparts(ifname)
            odir = odir0 or odir
            obase = obase + '_B0'
            ofname = os.path.join(odir, obase + oext)
            io.savef(b01.volume, ofname, affine=b01.affine, like=ifname,
                     te=0, dtype='float32')
        for i, (c, b) in enumerate(zip(contrasts, b0)):
            readout = c.readout
            grid_up, grid_down, jac_up, jac_down = b.exp2(
                add_identity=True, jacobian=True)
            for j, e in enumerate(c):
                # Alternate blip direction by default (odd/even echoes).
                blip = e.blip or (2*(j % 2) - 1)
                grid_blip = grid_down if blip > 0 else grid_up  # inverse of
                jac_blip = jac_down if blip > 0 else jac_up     # forward model
                ifname = e.volume.fname
                odir, obase, oext = py.fileparts(ifname)
                odir = odir0 or odir
                obase = obase + '_unwrapped'
                ofname = os.path.join(odir, obase + oext)
                d = e.fdata(device=device)
                d, _ = pull1d(d, grid_blip, readout)
                d *= jac_blip
                io.savef(d, ofname, affine=e.affine, like=ifname)
                del d
            del grid_up, grid_down, jac_up, jac_down
    if options.register:
        for i, c in enumerate(contrasts):
            for j, e in enumerate(c):
                ifname = e.volume.fname
                odir, obase, oext = py.fileparts(ifname)
                odir = odir0 or odir
                obase = obase + '_registered'
                ofname = os.path.join(odir, obase + oext)
                io.save(e.volume, ofname, affine=e.affine)
def main(options):
    """Run the EPIC distortion-correction fit and save the corrected echoes."""
    # find readout direction
    f = io.map(options.echoes[0])
    affine, shape = f.affine, f.shape
    readout = get_readout(options.direction, affine, shape, options.verbose)

    # Reversed-polarity echoes: explicit files, or synthesised ones.
    if not options.reversed:
        reversed_echoes = options.synth
    else:
        reversed_echoes = options.reversed

    # do EPIC
    fit = epic(options.echoes,
               reverse_echoes=reversed_echoes,
               fieldmap=options.fieldmap,
               extrapolate=options.extrapolate,
               bandwidth=options.bandwidth,
               polarity=options.polarity,
               readout=readout,
               slicewise=options.slicewise,
               lam=options.penalty,
               max_iter=options.maxiter,
               tol=options.tolerance,
               verbose=options.verbose,
               device=get_device(options.gpu))

    # save volumes
    # Three layouts are accepted: one output file (4D or templated),
    # one per input file, or one per fitted echo.
    input, output = options.echoes, options.output
    if len(output) != len(input):
        if len(output) == 1:
            if '{base}' in output[0]:
                # Template: expand to one output per input.
                output = [output[0]] * len(input)
        elif len(output) != len(fit):
            raise ValueError(f'There should be either one output file, '
                             f'or as many output files as input files, '
                             f'or as many output files as echoes. Got '
                             f'{len(output)} output files, {len(input)} '
                             f'input files, and {len(fit)} echoes.')
    if len(output) == 1:
        dir, base, ext = py.fileparts(input[0])
        output = output[0]
        if '{n}' in output:
            # One file per echo.
            for n, echo in enumerate(fit):
                out = output.format(dir=dir, sep=os.sep, base=base,
                                    ext=ext, n=n)
                io.savef(echo, out, like=input[0])
        else:
            # Single 4D file, echoes last.
            output = output.format(dir=dir, sep=os.sep, base=base, ext=ext)
            io.savef(torch.movedim(fit, 0, -1), output, like=input[0])
    elif len(output) == len(input):
        # One file per input: consume as many echoes as the input holds.
        for i, (inp, out) in enumerate(zip(input, output)):
            dir, base, ext = py.fileparts(inp)
            out = out.format(dir=dir, sep=os.sep, base=base, ext=ext, n=i)
            ne = [*io.map(inp).shape, 1][3]
            io.savef(fit[:ne].movedim(0, -1), out, like=inp)
            fit = fit[ne:]
    else:
        # One file per echo.
        assert len(output) == len(fit)
        dir, base, ext = py.fileparts(input[0])
        for n, (echo, out) in enumerate(zip(fit, output)):
            out = out.format(dir=dir, sep=os.sep, base=base, ext=ext, n=n)
            io.savef(echo, out, like=input[0])
def _read_image(data, device='cpu', could_be_ct=False):
    """ Reads image data.

    Args:
        data (string|list): Path to file, or list with image data and affine matrix.
        device (string, optional): PyTorch on CPU or GPU? Defaults to 'cpu'.
        could_be_ct (bool, optional): Could the image be a CT scan?

    Returns:
        dat (torch.tensor()): Image data.
        dim (tuple(int)): Image dimensions.
        mat (torch.tensor(double)): Affine matrix.
        fname (string): File path
        direc (string): File directory path
        nam (string): Filename
        file (io.BabelArray)
        ct (bool): Is data CT
    """
    if isinstance(data, str):
        # =================================
        # Load from file
        # =================================
        file = map(data)
        dat = file.fdata(dtype=torch.float32, device=device,
                         rand=False, cutoff=None)
        mat = file.affine.to(device).type(torch.float64)
        fname = file.filename()
        direc, nam = os.path.split(os.path.abspath(fname))
    else:
        # =================================
        # Data and matrix given as list
        # =================================
        # Image data
        dat = data[0]
        if not isinstance(dat, torch.Tensor):
            dat = torch.tensor(dat)
        dat = dat.float()
        dat = dat.to(device)
        dat[~torch.isfinite(dat)] = 0
        # Add some random noise (uniform in [-0.5, 0.5], only to positive
        # voxels); fixed seed keeps runs reproducible.
        torch.manual_seed(0)
        dat[dat > 0] += torch.rand_like(dat[dat > 0]) - 1 / 2
        # Affine matrix
        mat = data[1]
        if not isinstance(mat, torch.Tensor):
            mat = torch.tensor(mat)
        mat = mat.double().to(device)
        file = None
        fname = None
        direc = None
        nam = None
    # Get dimensions
    dim = tuple(dat.shape)
    # CT?
    if could_be_ct and _is_ct(dat):
        ct = True
    else:
        ct = False
    # Mask
    dat[~torch.isfinite(dat)] = 0.0
    return dat, dim, mat, fname, direc, nam, file, ct
def estimate_fwhm(dat, vx=None, verbose=0, mn=-inf, mx=inf):
    """Estimates full width at half maximum (FWHM) and noise standard
    deviation (sd) of a 2D or 3D image.

    It is assumed that the image has been generated as:
        dat = Ky + n,
    where K is Gaussian smoothing with some FWHM and n is additive Gaussian
    noise. FWHM and n are estimated.

    Parameters
    ----------
    dat : str or (*spatial) tensor
        Image data or path to nifti file
    vx : [sequence of] float, default=1
        Voxel size
    verbose : {0, 1, 2}, default=0
        Verbosity level:
            * 0: No verbosity
            * 1: Print FWHM and sd to screen
            * 2: 1 + show mask
    mn : float, optional
        Exclude values below
    mx : float, optional
        Exclude values above

    Returns
    -------
    fwhm : (dim,) tensor
        Estimated FWHM
    sd : scalar tensor
        Estimated noise standard deviation.

    References
    ----------
    ..[1] "Linked independent component analysis for multimodal data fusion."
          Appendix A
          Groves AR, Beckmann CF, Smith SM, Woolrich MW.
          Neuroimage. 2011 Feb 1;54(3):2198-217.
    """
    if isinstance(dat, str):
        dat = io.map(dat)
    if isinstance(dat, io.MappedArray):
        if vx is None:
            vx = get_voxel_size(dat.affine)
        dat = dat.fdata(rand=True, missing=0)
    dat = torch.as_tensor(dat)
    dim = dat.dim()
    if vx is None:
        vx = 1
    vx = utils.make_vector(vx, dim)
    backend = utils.backend(dat)
    # Make mask
    msk = (dat > mn).bitwise_and_(dat <= mx)
    dat = dat.masked_fill(~msk, 0)
    # TODO: we should erode the mask so that only voxels whose neighbours
    # are in the mask are considered when computing gradients.
    if verbose >= 2:
        show_slices(msk)
    # Compute image gradient
    g = diff(dat, dim=range(dim), side='central', voxel_size=vx,
             bound='dft').abs_()
    slicer = (slice(1, -1), ) * dim
    g = g[(*slicer, None)]
    # NOTE(review): this zeroes gradients at *in-mask* voxels; the intent
    # (per the TODO above) looks like zeroing *out-of-mask* gradients
    # (`~msk`) instead — confirm against the reference implementation.
    g[msk[slicer], :] = 0
    g = g.reshape([-1, dim]).sum(0, dtype=torch.double)
    # Make dat have zero mean
    dat = dat[slicer]
    dat = dat[msk[slicer]]
    x0 = dat - dat.mean()
    # Compute FWHM (ratio of absolute sums, cf. [1] Appendix A)
    fwhm = pymath.sqrt(4 * pymath.log(2)) * x0.abs().sum(dtype=torch.double)
    fwhm = fwhm / g
    if verbose >= 1:
        print(f'FWHM={fwhm.tolist()}')
    # Compute noise standard deviation
    sx = smooth('gauss', fwhm[0], x=0, **backend)[0][0, 0, 0]
    sy = smooth('gauss', fwhm[1], x=0, **backend)[0][0, 0, 0]
    sz = 1.0
    if dim == 3:
        sz = smooth('gauss', fwhm[2], x=0, **backend)[0][0, 0, 0]
    # Correction factor for smoothing-induced correlation.
    sc = (sx * sy * sz) / dim
    sc.clamp_min_(1)
    sd = torch.sqrt(x0.square().sum(dtype=torch.double) / (x0.numel() * sc))
    if verbose >= 1:
        print(f'sd={sd.tolist()}')
    return fwhm, sd
def write_outputs(z, prm, options):
    """Write segmentation outputs to disk (native, MNI and warped spaces).

    Parameters
    ----------
    z : (K, *spatial) tensor
        Class responsibilities (channels first).
    prm : dict
        Fitted parameters; keys used here: 'bias', 'warp', 'affine'.
    options : object
        Parsed command-line options (input/output paths and flags).
    """
    # prepare filenames
    ref_native = options.input[0]
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    format_dict = get_format_dict(ref_native, options.output)

    # move channels to back
    backend = utils.backend(z)
    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp
            or options.all_nat or options.all_mni or options.all_wrp):
        dat, _, affine = get_data(options.input, options.mask, None, 3,
                                  **backend)

    # --- native space -------------------------------------------------
    if options.prob_nat or options.all_nat:
        fname = options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('prob.nat ->', fname)
        io.savef(torch.movedim(z, 0, -1), fname, like=ref_native,
                 dtype='float32')

    if options.labels_nat or options.all_nat:
        fname = options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('labels.nat ->', fname)
        io.save(z.argmax(0), fname, like=ref_native, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        bias = prm['bias']
        fname = options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}'
        if len(options.input) == 1:
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('bias.nat ->', fname)
            io.savef(torch.movedim(bias, 0, -1), fname, like=ref_native,
                     dtype='float32')
        else:
            for c, (bias1, ref1) in enumerate(zip(bias, options.input)):
                format_dict1 = get_format_dict(ref1, options.output)
                fname = fname.format(**format_dict1)
                if options.verbose > 0:
                    print(f'bias.nat.{c+1} ->', fname)
                io.savef(bias1, fname, like=ref1, dtype='float32')
        del bias

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        fname = options.nobias_nat or '{dir}{sep}{base}.nobias.nat{ext}'
        if len(options.input) == 1:
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('nobias.nat ->', fname)
            io.savef(torch.movedim(nobias, 0, -1), fname, like=ref_native)
        else:
            # BUG FIX: iterate over `nobias` — the original zipped `bias`,
            # which was deleted just above (NameError at runtime).
            for c, (nobias1, ref1) in enumerate(zip(nobias, options.input)):
                format_dict1 = get_format_dict(ref1, options.output)
                fname = fname.format(**format_dict1)
                if options.verbose > 0:
                    print(f'nobias.nat.{c+1} ->', fname)
                io.savef(nobias1, fname, like=ref1)
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        warp = prm['warp']
        fname = options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('warp.nat ->', fname)
        io.savef(warp, fname, like=ref_native, dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    # compose the estimated rigid/affine with the native orientation matrix
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(
            mni_affine, mni_shape, scl, anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        # BUG FIX: include `all_mni` in the inner conditions — the outer
        # condition computes z_mni for it, but nothing was saved.
        if options.prob_mni or options.all_mni:
            fname = options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('prob.mni ->', fname)
            io.savef(torch.movedim(z_mni, 0, -1), fname, like=ref_native,
                     affine=mni_affine, dtype='float32')
        if options.labels_mni or options.all_mni:
            fname = options.labels_mni or '{dir}{sep}{base}.labels.mni{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('labels.mni ->', fname)
            io.save(z_mni.argmax(0), fname, like=ref_native,
                    affine=mni_affine, dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni
                         or options.all_mni):
        bias = spatial.reslice(prm['bias'], dat_affine, mni_affine, mni_shape,
                               interpolation=3, prefilter=False, bound='dct2')

        if options.bias_mni or options.all_mni:
            fname = options.bias_mni or '{dir}{sep}{base}.bias.mni{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('bias.mni ->', fname)
                io.savef(torch.movedim(bias, 0, -1), fname, like=ref_native,
                         affine=mni_affine, dtype='float32')
            else:
                for c, (bias1, ref1) in enumerate(zip(bias, options.input)):
                    format_dict1 = get_format_dict(ref1, options.output)
                    fname = fname.format(**format_dict1)
                    if options.verbose > 0:
                        print(f'bias.mni.{c+1} ->', fname)
                    io.savef(bias1, fname, like=ref1, affine=mni_affine,
                             dtype='float32')

        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            # BUG FIX: read the `nobias_mni` option (was `bias_mni`).
            fname = options.nobias_mni or '{dir}{sep}{base}.nobias.mni{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('nobias.mni ->', fname)
                io.savef(torch.movedim(nobias, 0, -1), fname,
                         like=ref_native, affine=mni_affine)
            else:
                # BUG FIX: iterate over `nobias`, not `bias`.
                for c, (nobias1, ref1) in enumerate(
                        zip(nobias, options.input)):
                    format_dict1 = get_format_dict(ref1, options.output)
                    fname = fname.format(**format_dict1)
                    if options.verbose > 0:
                        print(f'nobias.mni.{c+1} ->', fname)
                    io.savef(nobias1, fname, like=ref1, affine=mni_affine)
            del nobias
        del bias

    need_iwarp = (options.warp_mni or options.prob_wrp or options.labels_wrp
                  or options.bias_wrp or options.nobias_wrp
                  or options.all_mni or options.all_wrp)
    need_iwarp = need_iwarp and options.warp
    if not need_iwarp:
        return

    # invert the native->MNI displacement and reslice it to MNI space
    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp, dat_affine, mni_affine, mni_shape,
                            interpolation=2, bound='dft', extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if (options.warp_mni or options.all_mni) and options.warp:
        fname = options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('warp.mni ->', fname)
        io.savef(iwarp, fname, like=ref_native, affine=mni_affine,
                 dtype='float32')

    # --- Warped space -------------------------------------------------
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_mni = spatial.grid_pull(z, iwarp)
        # BUG FIX: test/read the *.wrp options — the original reused the
        # *_mni options here (copy-paste from the MNI section).
        if options.prob_wrp or options.all_wrp:
            fname = options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('prob.wrp ->', fname)
            io.savef(torch.movedim(z_mni, 0, -1), fname, like=ref_native,
                     affine=mni_affine, dtype='float32')
        if options.labels_wrp or options.all_wrp:
            fname = options.labels_wrp or '{dir}{sep}{base}.labels.wrp{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('labels.wrp ->', fname)
            io.save(z_mni.argmax(0), fname, like=ref_native,
                    affine=mni_affine, dtype='int16')
        del z_mni

    if options.bias and (options.bias_wrp or options.nobias_wrp
                         or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'], iwarp, interpolation=3,
                                 prefilter=False, bound='dct2')
        if options.bias_wrp or options.all_wrp:
            fname = options.bias_wrp or '{dir}{sep}{base}.bias.wrp{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('bias.wrp ->', fname)
                io.savef(torch.movedim(bias, 0, -1), fname, like=ref_native,
                         affine=mni_affine, dtype='float32')
            else:
                for c, (bias1, ref1) in enumerate(zip(bias, options.input)):
                    format_dict1 = get_format_dict(ref1, options.output)
                    fname = fname.format(**format_dict1)
                    if options.verbose > 0:
                        print(f'bias.wrp.{c+1} ->', fname)
                    io.savef(bias1, fname, like=ref1, affine=mni_affine,
                             dtype='float32')

        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            fname = options.nobias_wrp or '{dir}{sep}{base}.nobias.wrp{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('nobias.wrp ->', fname)
                io.savef(torch.movedim(nobias, 0, -1), fname,
                         like=ref_native, affine=mni_affine)
            else:
                # BUG FIX: iterate over `nobias`, not `bias`.
                for c, (nobias1, ref1) in enumerate(
                        zip(nobias, options.input)):
                    format_dict1 = get_format_dict(ref1, options.output)
                    fname = fname.format(**format_dict1)
                    if options.verbose > 0:
                        print(f'nobias.wrp.{c+1} ->', fname)
                    io.savef(nobias1, fname, like=ref1, affine=mni_affine)
            del nobias
        del bias
def main_apply(options):
    """
    Unwarp distorted images using a pre-computed 1d displacement field.
    """
    device = get_device(options.gpu)

    # detect readout direction
    # Use the first positive-polarity file as reference if available,
    # otherwise the first negative-polarity file.
    if options.file_pos:
        f0 = io.map(options.file_pos[0])
    else:
        f0 = io.map(options.file_neg[0])
    # spatial dimensionality inferred from the homogeneous affine (D+1, D+1)
    dim = f0.affine.shape[-1] - 1
    readout = get_readout(options.readout, f0.affine, f0.shape[-dim:])

    def do_apply(fnames, phi, jac):
        """Correct files with a given polarity"""
        for fname in fnames:
            dir, base, ext = py.fileparts(fname)
            ofname = options.output
            ofname = ofname.format(dir=dir or '.', sep=os.sep, base=base,
                                   ext=ext)
            if options.verbose:
                print(f'unwarp {fname} \n'
                      f' -> {ofname}')
            f = io.map(fname)
            d = f.fdata(device=device)
            # move readout axis last, deform along it, then move it back
            d = utils.movedim(d, readout, -1)
            d = _deform1d(d, phi)
            if jac is not None:
                # Jacobian (intensity) modulation
                d *= jac
            d = utils.movedim(d, -1, readout)
            io.savef(d, ofname, like=fname)

    # load and apply
    vel = io.loadf(options.dist_file, device=device)
    # readout axis is moved to the last dimension for all 1d operations
    vel = utils.movedim(vel, readout, -1)
    if options.file_pos:
        if options.diffeo:
            # diffeomorphic model: exponentiate the velocity field
            # (optionally also returning the 1d Jacobian for modulation)
            phi, *jac = spatial.exp1d_forward(vel, bound='dct2',
                                              jacobian=options.modulation)
            jac = jac[0] if jac else None
        else:
            # small-deformation model: displacement is the field itself
            phi = vel.clone()
            jac = None
            if options.modulation:
                # NOTE(review): `vel`/`phi` had the readout axis moved to -1
                # above, yet the derivative is taken along `dim=readout` —
                # looks like it should be `dim=-1`; confirm against
                # `get_readout` semantics.
                jac = spatial.diff1d(phi, dim=readout, bound='dct2', side='c')
                jac += 1
        # displacement -> sampling grid (in-place identity add)
        phi = spatial.add_identity_grid_(phi.unsqueeze(-1)).squeeze(-1)
        do_apply(options.file_pos, phi, jac)
    if options.file_neg:
        # same as above with the velocity negated (opposite polarity)
        if options.diffeo:
            phi, *jac = spatial.exp1d_forward(-vel, bound='dct2',
                                              jacobian=options.modulation)
            jac = jac[0] if jac else None
        else:
            phi = -vel
            jac = None
            if options.modulation:
                # NOTE(review): same `dim=readout` vs `dim=-1` concern as in
                # the positive-polarity branch above.
                jac = spatial.diff1d(phi, dim=readout, bound='dct2', side='c')
                jac += 1
        phi = spatial.add_identity_grid_(phi.unsqueeze(-1)).squeeze(-1)
        do_apply(options.file_neg, phi, jac)
def main_fit(options):
    """
    Estimate a displacement field from opposite polarity images.

    Runs a multi-level (coarse-to-fine) fit: at each level the inputs are
    downsampled, the similarity loss is built, and `topup_fit` refines the
    current velocity/displacement estimate. The final field is upsampled
    to the native grid and saved to `options.output`.
    """
    device = get_device(options.gpu)

    # map input files (positive / negative phase-encode polarity)
    f0 = io.map(options.pos_file)
    f1 = io.map(options.neg_file)
    dim = f0.affine.shape[-1] - 1

    # map mask
    fm = None
    if options.mask:
        fm = io.map(options.mask)

    # detect readout direction
    readout = get_readout(options.readout, f0.affine, f0.shape[-dim:])

    # detect penalty: options.penalty is a list of weights, optionally
    # terminated by a penalty-type name ('bending'/'membrane', abbreviable)
    penalty_type = 'bending'
    penalties = options.penalty
    if penalties and isinstance(penalties[-1], str):
        *penalties, penalty_type = penalties
    if not penalties:
        penalties = [1]
    if penalty_type[0] == 'b':
        penalty_type = 'bending'
    elif penalty_type[0] == 'm':
        penalty_type = 'membrane'
    else:
        raise ValueError('Unknown penalty type', penalty_type)

    # broadcast per-level settings to a common number of levels
    downs = options.downsample
    max_iter = options.max_iter
    tolerance = options.tolerance
    nb_levels = max(len(penalties), len(max_iter), len(tolerance), len(downs))
    penalties = py.make_list(penalties, nb_levels)
    tolerance = py.make_list(tolerance, nb_levels)
    max_iter = py.make_list(max_iter, nb_levels)
    downs = py.make_list(downs, nb_levels)

    # load full-resolution data on CPU; moved to `device` per level
    d00 = f0.fdata(device='cpu')
    d11 = f1.fdata(device='cpu')
    dmask = fm.fdata(device='cpu') if fm else None

    # fit
    vel = mask = None
    aff = last_aff = f0.affine
    last_dwn = None
    for penalty, n, tol, dwn in zip(penalties, max_iter, tolerance, downs):
        if dwn != last_dwn:
            # (re)build the pyramid level and resample the running estimate
            d0, aff = downsample(d00.to(device), f0.affine, dwn)
            d1, _ = downsample(d11.to(device), f1.affine, dwn)
            vx = spatial.voxel_size(aff)
            if vel is not None:
                vel = upsample_vel(vel, last_aff, aff, d0.shape[-dim:],
                                   readout)
            last_aff = aff
            if fm:
                # NOTE(review): mask is downsampled with f1.affine; verify
                # whether f0.affine (or the mask's own affine) was intended.
                mask, _ = downsample(dmask.to(device), f1.affine, dwn)
            last_dwn = dwn
        # scale the penalty by the resolution change so regularisation
        # strength is comparable across levels
        scl = py.prod(d00.shape) / py.prod(d0.shape)
        penalty = penalty * scl
        kernel = get_kernel(options.kernel, aff, d0.shape[-dim:], dwn)

        # prepare loss
        if options.loss == 'mse':
            prm0, _ = estimate_noise(d0)
            prm1, _ = estimate_noise(d1)
            # geometric mean of the two noise standard deviations
            sd = ((prm0['sd'].log() + prm1['sd'].log()) / 2).exp()
            # BUG FIX: debug print was unconditional; report only when
            # verbosity is requested.
            if options.verbose:
                print(sd.item())
            loss = MSE(lam=1/(sd*sd), dim=dim)
        elif options.loss == 'lncc':
            loss = LNCC(dim=dim, patch=kernel)
        elif options.loss == 'lgmm':
            if options.bins == 1:
                loss = LNCC(dim=dim, patch=kernel)
            else:
                loss = LGMMH(dim=dim, patch=kernel, bins=options.bins)
        elif options.loss == 'gmm':
            if options.bins == 1:
                loss = NCC(dim=dim)
            else:
                loss = GMMH(dim=dim, bins=options.bins)
        else:
            loss = NCC(dim=dim)

        # fit this level
        vel = topup_fit(d0, d1, loss=loss, dim=readout, vx=vx, ndim=dim,
                        model=('svf' if options.diffeo else 'smalldef'),
                        lam=penalty, penalty=penalty_type, vel=vel,
                        modulation=options.modulation, max_iter=n,
                        tolerance=tol, verbose=options.verbose, mask=mask)
    del d0, d1, d00, d11

    # upsample the estimate back to the native grid
    vel = upsample_vel(vel, aff, f0.affine, f0.shape[-dim:], readout)

    # save
    dir, base, ext = py.fileparts(options.pos_file)
    fname = options.output
    fname = fname.format(dir=dir or '.', sep=os.sep, base=base, ext=ext)
    io.savef(vel, fname, like=options.pos_file, dtype='float32')