def cat(dats, dim=None): if any(torch.is_tensor(dat) for dat in dats): dtype = utils.max_dtype(dats) device = utils.max_device(dats) dats = [torch.as_tensor(dat, dtype=dtype, device=device) for dat in dats] return torch.cat(dats, dim=dim) else: return np.concatenate(dats, axis=dim)
def inpaint(*inputs, missing='nan', output=None, device=None, verbose=1, max_iter_rls=10, max_iter_cg=32, tol_rls=1e-5, tol_cg=1e-5): """Inpaint missing values by minimizing Joint Total Variation. Parameters ---------- *inputs : str or tensor or (tensor, tensor) Either a path to a volume file or a tuple `(dat, affine)`, where the first element contains the volume data and the second contains the orientation matrix. missing : 'nan' or scalar or callable, default='nan' Mask of the missing data. If a scalar, all voxels with that value are considered missing. If a function, it should return the mask of missing values when applied to the multi-channel data. Else, non-finite values are assumed missing. output : [sequence of] str, optional Output filename(s). If the input is not a path, the unstacked data is not written on disk by default. If the input is a path, the default output filename is '{dir}/{base}.pool{ext}', where `dir`, `base` and `ext` are the directory, base name and extension of the input file, `i` is the coordinate (starting at 1) of the slice. verbose : int, default=1 device : torch.device, optional max_iter_rls : int, default=10 max_iter_cg : int, default=32 tol_rls : float, default=1e-5 tol_cg : float, default=1e-5 Returns ------- *output : str or (tensor, tensor) If the input is a path, the output path is returned. Else, the pooled data and orientation matrix are returned. """ # Preprocess dirs = [] bases = [] exts = [] fnames = [] nchannels = [] dat = [] aff = None for i, inp in enumerate(inputs): is_file = isinstance(inp, str) if is_file: fname = inp dir, base, ext = py.fileparts(fname) fnames.append(inp) dirs.append(dir) bases.append(base) exts.append(ext) f = io.volumes.map(fname) if aff is None: aff = f.affine f = ensure_4d(f) dat.append(f.fdata(device=device)) else: fnames.append(None) dirs.append('') bases.append(f'{i+1}') exts.append('.nii.gz') if isinstance(inp, (list, tuple)): if aff is None: dat1, aff = inp else: dat1, _ = inp else: dat1 = inp dat.append(torch.as_tensor(dat1, device=device)) del dat1 nchannels.append(dat[-1].shape[-1]) dat = utils.to(*dat, dtype=torch.float, device=utils.max_device(dat)) if not torch.is_tensor(dat): dat = torch.cat(dat, dim=-1) dat = utils.movedim(dat, -1, 0) # (channels, *spatial) # Set missing data if missing != 'nan': if not callable(missing): missingval = utils.make_vector(missing, dtype=dat.dtype, device=dat.device) missing = lambda x: utils.isin(x, missingval) dat[missing(dat)] = nan dat[~torch.isfinite(dat)] = nan # Do it if aff is not None: vx = spatial.voxel_size(aff) else: vx = 1 dat = do_inpaint(dat, voxel_size=vx, verbose=verbose, max_iter_rls=max_iter_rls, tol_rls=tol_rls, max_iter_cg=max_iter_cg, tol_cg=tol_cg) # Postprocess dat = utils.movedim(dat, 0, -1) dat = dat.split(nchannels, dim=-1) output = py.make_list(output, len(dat)) for i in range(len(dat)): if fnames[i] and not output[i]: output[i] = '{dir}{sep}{base}.inpaint{ext}' if output[i]: if fnames[i]: output[i] = output[i].format(dir=dirs[i] or '.', base=bases[i], ext=exts[i], sep=os.path.sep) io.volumes.save(dat[i], output[i], like=fnames[i], affine=aff) else: output[i] = output[i].format(sep=os.path.sep) io.volumes.save(dat[i], output[i], affine=aff) dat = [ output[i] if fnames[i] else (dat[i], aff) if aff is not None else dat[i] for i in range(len(dat)) ] if len(dat) == 1: dat = dat[0] return dat