def _write_image(dat, fname, bids=False, mat=None, file=None, dtype='float32'):
    """Write data to nifti.

    Parameters
    ----------
    dat : tensor
        Image data to write.
    fname : str
        Output filename.
    bids : bool, default=False
        If True, insert a 'space-unires' entity before the last
        underscore-separated component of the file name
        (e.g. 'sub-01_T1w.nii' -> 'sub-01_space-unires_T1w.nii').
    mat : (4, 4) tensor, optional
        Orientation matrix written to the header. Defaults to identity.
    file : optional
        Passed to `savef` as `like=` (header template).
    dtype : str, default='float32'
        On-disk data type, forwarded to `savef`.
    """
    if mat is None:
        # build a fresh identity per call rather than sharing one tensor
        # created once at function-definition time
        mat = torch.eye(4)
    if bids:
        p, n = os.path.split(fname)
        s = n.split('_')
        fname = os.path.join(p, '_'.join(s[:-1] + ['space-unires'] + [s[-1]]))
    # `dtype` used to be accepted but silently ignored
    savef(dat, fname, like=file, affine=mat, dtype=dtype)
def _write_output(dat, mat, file=None, prefix='', odir=None, nam=None):
    """Write preprocessed output to disk.

    Parameters
    ----------
    dat : sequence of tensor
        Images to write, one output file per element.
    mat : sequence of (4, 4) tensor
        Orientation matrix of each image in `dat`.
    file : sequence, optional
        Objects used as header templates (`like=`) and whose `filename()`
        seeds the output name. Entries may be None, in which case the base
        name 'nitorch_file' is used. If `file` itself is None, every entry
        is treated as None.
    prefix : str, default=''
        Prefix passed to `file_mod` when building output names.
    odir : str, optional
        Output directory; created if it does not exist.
    nam : str, optional
        Passed to `file_mod` when building output names.

    Returns
    -------
    pth : str or list of str
        The written path if `dat` has a single element, else the list of
        written paths.
    """
    if odir is not None:
        os.makedirs(odir, exist_ok=True)
    if file is None:
        # honour the documented default instead of failing on `file[n]`
        file = [None] * len(dat)
    pth = []
    for n, dat1 in enumerate(dat):
        file1 = file[n]
        filename = file1.filename() if file1 is not None else 'nitorch_file'
        pth1 = file_mod(filename, odir=odir, prefix=prefix, nam=nam)
        savef(dat1, pth1, like=file1, affine=mat[n])
        pth.append(pth1)
    if len(dat) == 1:
        pth = pth[0]
    return pth
def do_apply(fnames, phi, jac):
    """Correct files with a given polarity

    Each file is loaded, its readout axis moved last, deformed along that
    axis by `phi` (optionally modulated by `jac`), moved back, and saved to
    the path built from the `options.output` template.
    """
    for fname in fnames:
        folder, base, ext = py.fileparts(fname)
        ofname = options.output.format(dir=folder or '.', sep=os.sep,
                                       base=base, ext=ext)
        if options.verbose:
            print(f'unwarp {fname} \n'
                  f' -> {ofname}')
        vol = io.map(fname).fdata(device=device)
        vol = _deform1d(utils.movedim(vol, readout, -1), phi)
        if jac is not None:
            vol *= jac
        io.savef(utils.movedim(vol, -1, readout), ofname, like=fname)
def _cli(args):
    """Command-line interface for `smooth` without exception handling

    The FWHM may end with a unit string; without one it is taken in mm and
    converted to voxels using each image's voxel size.
    """
    args = args or sys.argv[1:]
    options = parser(args)
    if options.help:
        print(help)
        return

    fwhm_in = options.fwhm
    if isinstance(fwhm_in[-1], str):
        *fwhm_in, unit = fwhm_in
    else:
        unit = 'mm'
    fwhm = make_list(fwhm_in, 3)

    outputs = make_list(options.output, len(options.files))
    options.output = outputs
    for fname, ofname in zip(options.files, outputs):
        mapped = io.map(fname)
        vx = voxel_size(mapped.affine).tolist()
        ndim = len(vx)
        if unit == 'mm':
            fwhm_vox = [w / v for w, v in zip(fwhm, vx)]
        else:
            fwhm_vox = fwhm[:ndim]

        vol = movedim_front2back(mapped.fdata(), ndim)
        vol = smooth(vol, type=options.method, fwhm=fwhm_vox,
                     basis=options.basis, bound=options.padding, dim=ndim)
        vol = movedim_back2front(vol, ndim)

        folder, base, ext = fileparts(fname)
        ofname = ofname.format(dir=folder or '.', base=base, ext=ext,
                               sep=os.path.sep)
        io.savef(vol, ofname, like=mapped)
def _main(options):
    """Fit ESTATICS to the multi-echo contrasts described by `options`.

    Builds an `ESTATICSOptions` object from the parsed options, maps every
    contrast (and its optional B0 fieldmap), runs `estatics`, and writes
    outputs next to the inputs, or into `options.odir` when it is set:

    * one `<base>_TE0` intercept image per contrast;
    * an `R2star` map, using the first contrast's extension;
    * when fieldmaps are returned, one `<base>_B0` fieldmap per contrast
      and one corrected image per echo;
    * when `options.register` is set, the registered echoes.

    Raises
    ------
    ValueError
        If a TE or echo-spacing unit is not recognised, or if a fieldmap
        given in Hz comes without a bandwidth.
    """
    if isinstance(options.gpu, str):
        device = torch.device(options.gpu)
    else:
        assert isinstance(options.gpu, int)
        device = torch.device(f'cuda:{options.gpu}')
    if not torch.cuda.is_available():
        device = 'cpu'

    # prepare options
    estatics_opt = ESTATICSOptions()
    estatics_opt.likelihood = options.likelihood
    estatics_opt.verbose = options.verbose >= 1
    estatics_opt.plot = options.verbose >= 2
    estatics_opt.recon.space = options.space
    if isinstance(options.space, str) and options.space != 'mean':
        # a contrast name: convert it to that contrast's index
        for c, contrast in enumerate(options.contrast):
            if contrast.name == options.space:
                estatics_opt.recon.space = c
                break
    estatics_opt.backend.device = device
    estatics_opt.optim.nb_levels = options.levels
    estatics_opt.optim.max_iter_rls = options.iter
    estatics_opt.optim.tolerance = options.tol
    estatics_opt.regularization.norm = options.regularization
    estatics_opt.regularization.factor = [*options.lam_intercept, options.lam_decay]
    estatics_opt.distortion.enable = options.meetup
    estatics_opt.distortion.bending = options.lam_meetup
    estatics_opt.preproc.register = options.register

    # prepare files
    contrasts = []
    distortion = []
    for c in options.contrast:

        # read meta-parameters
        meta = {}
        if c.te:
            te, unit = c.te, ''
            if isinstance(te[-1], str):
                *te, unit = te
            if unit:
                if unit == 'ms':
                    te = [t * 1e-3 for t in te]
                elif unit not in ('s', 'sec'):
                    raise ValueError(f'TE unit: {unit}')
            if c.echo_spacing:
                delta, *unit = c.echo_spacing
                unit = unit[0] if unit else ''
                if unit == 'ms':
                    delta = delta * 1e-3
                elif unit and unit not in ('s', 'sec'):
                    # no unit means seconds, as for TE
                    raise ValueError(f'echo spacing unit: {unit}')
                # total number of echoes across all files of this contrast
                ne = sum(io.map(f).unsqueeze(-1).shape[3] for f in c.echoes)
                te = [te[0] + e*delta for e in range(ne)]
            meta['te'] = te

        # map volumes
        contrasts.append(qio.GradientEchoMulti.from_fname(c.echoes, **meta))

        if c.readout:
            # readout axis = position of the matching letter in the layout
            # name, stored as a negative spatial index
            layout = spatial.affine_to_layout(contrasts[-1].affine)
            layout = spatial.volume_layout_to_name(layout)
            readout = None
            for j, l in enumerate(layout):
                if l.lower() in c.readout.lower():
                    readout = j - 3
            contrasts[-1].readout = readout

        if c.b0:
            bw = c.bandwidth
            b0, *unit = c.b0
            unit = unit[-1] if unit else 'vx'
            fb0 = io.map(b0)  # was `b0.map(b0)`: `b0` is a filename here
            b0 = fb0.fdata(device=device)
            b0 = spatial.reslice(b0, fb0.affine, contrasts[-1][0].affine,
                                 contrasts[-1][0].shape)
            if unit.lower() == 'hz':
                if not bw:
                    raise ValueError('Bandwidth required to convert fieldmap '
                                     'from Hz to voxel')
                b0 /= bw
            b0 = DenseDistortion(b0)
            distortion.append(b0)
        else:
            distortion.append(None)

    # run algorithm
    [te0, r2s, *b0] = estatics(contrasts, distortion, opt=estatics_opt)

    # write results

    # --- intercepts ---
    odir0 = options.odir
    for i, te1 in enumerate(te0):
        ifname = contrasts[i].echo(0).volume.fname
        odir, obase, oext = py.fileparts(ifname)
        odir = odir0 or odir
        obase = obase + '_TE0'
        ofname = os.path.join(odir, obase + oext)
        io.savef(te1.volume, ofname, affine=te1.affine, like=ifname, te=0,
                 dtype='float32')

    # --- decay ---
    ifname = contrasts[0].echo(0).volume.fname
    odir, obase, oext = py.fileparts(ifname)
    odir = odir0 or odir
    io.savef(r2s.volume, os.path.join(odir, 'R2star' + oext),
             affine=r2s.affine, dtype='float32')

    # --- fieldmap + undistorted ---
    if b0:
        b0 = b0[0]
        for i, b01 in enumerate(b0):
            ifname = contrasts[i].echo(0).volume.fname
            odir, obase, oext = py.fileparts(ifname)
            odir = odir0 or odir
            obase = obase + '_B0'
            ofname = os.path.join(odir, obase + oext)
            io.savef(b01.volume, ofname, affine=b01.affine, like=ifname,
                     te=0, dtype='float32')
        for c, b in zip(contrasts, b0):
            readout = c.readout
            grid_up, grid_down, jac_up, jac_down = b.exp2(
                add_identity=True, jacobian=True)
            for j, e in enumerate(c):
                # fall back to alternating polarity when `blip` is unset
                blip = e.blip or (2*(j % 2) - 1)
                grid_blip = grid_down if blip > 0 else grid_up  # inverse of
                jac_blip = jac_down if blip > 0 else jac_up     # forward model
                ifname = e.volume.fname
                odir, obase, oext = py.fileparts(ifname)
                odir = odir0 or odir
                # NOTE(review): suffix reads '_unwrapped' although this is an
                # unwarp; kept unchanged for output-name compatibility.
                obase = obase + '_unwrapped'
                ofname = os.path.join(odir, obase + oext)
                d = e.fdata(device=device)
                d, _ = pull1d(d, grid_blip, readout)
                d *= jac_blip
                io.savef(d, ofname, affine=e.affine, like=ifname)
                del d
            del grid_up, grid_down, jac_up, jac_down

    if options.register:
        for c in contrasts:
            for e in c:
                ifname = e.volume.fname
                odir, obase, oext = py.fileparts(ifname)
                odir = odir0 or odir
                obase = obase + '_registered'
                ofname = os.path.join(odir, obase + oext)
                io.save(e.volume, ofname, affine=e.affine)
def _warp_image(option, affine=None, nonlin=None, dim=None, device=None,
                odir=None):
    """Warp and save the moving and fixed images from a loss object

    For the moving image, `option.mov.output` receives a "minimal" reslice
    (moving grid, orientation updated by the affine when its position is
    'm' or 's') and `option.mov.resliced` a full reslice onto the fixed
    grid. The fixed image is handled symmetrically with
    `option.fix.output` / `option.fix.resliced`, warped with
    `backward=True`. Output templates are formatted with `dir`, `base`,
    `sep` and `ext`; `dir` is `odir`, else the input's folder, else '.'.

    Parameters
    ----------
    option : loss option object
        Provides `fix` and `mov` sub-options (files, world, affine,
        output, resliced, bound, extrapolate).
    affine : optional
        Affine transform object (uses `.position` and `.exp(...)`).
    nonlin : optional
        Non-linear transform, forwarded to `_warp_image1`.
    dim : int, optional
        Number of spatial dimensions; defaults to `fix.dim - 1`.
    device : optional
        Device on which image data are loaded.
    odir : str, optional
        Output directory overriding the input folder.
    """
    # nothing to write for this loss
    if not (option.mov.output or option.mov.resliced or
            option.fix.output or option.fix.resliced):
        return

    fix, fix_affine = _map_image(option.fix.files, dim=dim)
    mov, mov_affine = _map_image(option.mov.files, dim=dim)
    fix_affine = fix_affine.float()
    mov_affine = mov_affine.float()
    dim = dim or (fix.dim - 1)

    if option.fix.world:  # overwrite orientation matrix
        fix_affine = io.transforms.map(option.fix.world).fdata().squeeze()
    # left-divide the orientation by each user-supplied transform, in order
    for transform in (option.fix.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        fix_affine = spatial.affine_lmdiv(transform, fix_affine)

    if option.mov.world:  # overwrite orientation matrix
        mov_affine = io.transforms.map(option.mov.world).fdata().squeeze()
    for transform in (option.mov.affine or []):
        transform = io.transforms.map(transform).fdata().squeeze()
        mov_affine = spatial.affine_lmdiv(transform, mov_affine)

    # moving
    if option.mov.output or option.mov.resliced:
        ifname = option.mov.files[0]
        idir, base, ext = py.fileparts(ifname)
        odir_mov = odir or idir or '.'

        image = objects.Image(mov.fdata(rand=True, device=device), dim=dim,
                              affine=mov_affine, bound=option.mov.bound,
                              extrapolate=option.mov.extrapolate)

        if option.mov.output:
            # keep the moving grid; only the orientation matrix changes
            target_affine = mov_affine
            target_shape = image.shape
            if affine and affine.position[0].lower() in 'ms':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_lmdiv(aff, target_affine)

            fname = option.mov.output.format(dir=odir_mov, base=base,
                                             sep=os.path.sep, ext=ext)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image, target_affine, target_shape,
                                  affine=affine, nonlin=nonlin)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

        if option.mov.resliced:
            # resample onto the fixed image's grid; the [1:] drops the
            # leading dimension of `fix.shape` (presumably channels — TODO confirm)
            target_affine = fix_affine
            target_shape = fix.shape[1:]

            fname = option.mov.resliced.format(dir=odir_mov, base=base,
                                               sep=os.path.sep, ext=ext)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image, target_affine, target_shape,
                                  affine=affine, nonlin=nonlin, reslice=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

    # fixed
    if option.fix.output or option.fix.resliced:
        ifname = option.fix.files[0]
        idir, base, ext = py.fileparts(ifname)
        odir_fix = odir or idir or '.'

        image = objects.Image(fix.fdata(rand=True, device=device), dim=dim,
                              affine=fix_affine, bound=option.fix.bound,
                              extrapolate=option.fix.extrapolate)

        if option.fix.output:
            # keep the fixed grid; the affine is composed with matmul here
            # (vs lmdiv on the moving side) since the fixed image is warped
            # in the opposite direction
            target_affine = fix_affine
            target_shape = image.shape
            if affine and affine.position[0].lower() in 'fs':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_matmul(aff, target_affine)

            fname = option.fix.output.format(dir=odir_fix, base=base,
                                             sep=os.path.sep, ext=ext)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image, target_affine, target_shape,
                                  affine=affine, nonlin=nonlin, backward=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

        if option.fix.resliced:
            # resample onto the moving image's grid
            target_affine = mov_affine
            target_shape = mov.shape[1:]

            fname = option.fix.resliced.format(dir=odir_fix, base=base,
                                               sep=os.path.sep, ext=ext)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image, target_affine, target_shape,
                                  affine=affine, nonlin=nonlin,
                                  backward=True, reslice=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped
def _main(options):
    """Run a registration from parsed command-line `options` and write results.

    Builds the pyramid, losses, affine and non-linear components, runs the
    optimisation, then writes the affine transform, the non-linear field
    and the warped images requested by the options.

    Raises
    ------
    ValueError
        If neither an affine nor a non-linear component is requested.
    """
    device = setup_device(*options.device)
    dim = 3

    # --- image pyramid ---------------------------------------------------
    pyramids = _prepare_pyramid_levels(options.loss, options.pyramid, dim)

    # --- losses ----------------------------------------------------------
    loss_list, image_dict = _build_losses(options, pyramids, device)
    # second-order optimisation only when every loss reports order >= 2
    can_use_2nd_order = all(loss.loss.order >= 2 for loss in loss_list)

    # --- transformation components ---------------------------------------
    affine, affine_optim = _build_affine(options, can_use_2nd_order)
    nonlin, nonlin_optim = _build_nonlin(options, can_use_2nd_order,
                                         affine, image_dict)
    if not (affine or nonlin):
        raise ValueError('At least one of @affine or @nonlin must be used.')

    # --- backend ---------------------------------------------------------
    if options.verbose > 1:
        import matplotlib
        matplotlib.use('TkAgg')
    # local losses may benefit from selecting the best conv
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # --- optimisation ----------------------------------------------------
    _do_register(loss_list, affine, nonlin, affine_optim, nonlin_optim,
                 options)

    # --- outputs ---------------------------------------------------------
    def output_dir():
        # explicit output folder, else the first fixed image's folder
        return (options.odir
                or py.fileparts(options.loss[0].fix.files[0])[0]
                or '.')

    if affine:
        affine = affine[-1]

    if affine and options.affine.output:
        fname = options.affine.output.format(dir=output_dir(),
                                             sep=os.path.sep,
                                             name=options.affine.name)
        print('Affine ->', fname)
        aff = affine.exp(cache_result=True, recompute=False)
        io.transforms.savef(aff.cpu(), fname, type=1)  # 1 = RAS_TO_RAS

    if nonlin and options.nonlin.output:
        fname = options.nonlin.output.format(dir=output_dir(),
                                             sep=os.path.sep,
                                             name=options.nonlin.name)
        io.savef(nonlin.dat.dat, fname, affine=nonlin.affine)
        print('Nonlin ->', fname)

    for loss in options.loss:
        _warp_image(loss, affine=affine, nonlin=nonlin, dim=dim,
                    device=device, odir=options.odir)
def denoise_mri(*dat_x, affine_x=None, lam_scl=opt.lam_scl,
                lr=opt.learning_rate, max_iter=opt.max_iter,
                tolerance=opt.tolerance, verbose=opt.verbose,
                device=opt.device, do_write=opt.do_write,
                dir_out=opt.dir_out):
    """Denoises a multi-channel MR image by solving:

    dat_y_hat = 0.5*sum_c(tau_c*sum_i((dat_x_ci - dat_y_ci)^2)) + jtv(dat_y_1, ..., dat_y_C; lam)

    using PyTorch's auto-diff.

    If input is given as paths to files, outputs prefixed 'den_' are
    written based on options 'dir_out' and 'do_write'.

    Reference:
    Brudfors, Mikael, et al. "MRI super-resolution using multi-channel
    total variation." Annual Conference on Medical Image Understanding
    and Analysis. Springer, Cham, 2018.

    Parameters
    ----------
    dat_x : (nchannels, dmx, dmy, dmz) tensor or sequence[str]
        Input noisy image data. Accepted forms: one channel-first tensor;
        several per-channel (dmx, dmy, dmz) tensors, stacked along a new
        first dimension; or one or more file paths (given as separate
        arguments or as a single list/tuple).
    affine_x : (4, 4) tensor, optional
        Input images' affine matrix. If not given, assumes identity.
    lam_scl : float, default=10.0
        Scaling of regularisation values
    lr : float, default=1e1
        Optimiser learning rate
    max_iter : int, default=10000
        Maximum number of fitting iterations
    tolerance : float, default=1e-8
        Convergence threshold (when to stop iterating)
    verbose : bool, default=True
        Print to terminal?
    device : torch.device, default='cuda'
        Torch device used when reading files
    do_write : bool, default=True
        If input is given as paths to files, output is written to disk,
        prefixed 'den_' to 'dir_out'
    dir_out : str, optional
        Directory where to write output, default is same as input.

    Returns
    ----------
    dat_y_hat : (nchannels, dmx, dmy, dmz) tensor
        Denoised image data

    Raises
    ------
    ValueError
        If no input is given.
    """
    # a single list/tuple argument is treated like its unpacked items
    if len(dat_x) == 1 and isinstance(dat_x[0], (list, tuple)):
        dat_x = tuple(dat_x[0])
    if not dat_x:
        raise ValueError('No input image given.')

    # read data from disk
    if all(isinstance(x, str) for x in dat_x):
        dat_x, affine_x, nii = _get_image_data(dat_x, device=device)
    else:
        do_write = False  # input is tensor, do not write to disk
        # `*dat_x` always packs the positional arguments into a tuple, so
        # the tensor(s) must be taken out of it before use
        if len(dat_x) == 1:
            dat_x = dat_x[0]
        else:
            dat_x = torch.stack(dat_x)

    # backend
    device = dat_x.device
    dtype = dat_x.dtype
    nchannels = dat_x.shape[0]

    # estimate hyper-parameters
    tau = torch.zeros(nchannels, device=device, dtype=dtype)
    lam = torch.zeros(nchannels, device=device, dtype=dtype)
    for i in range(nchannels):
        prm0, prm1 = estimate_noise(dat_x[i, ...], show_fit=False)
        sd_bg = prm0['sd']
        mean_fg = prm1['mean']
        tau[i] = 1 / sd_bg.float()**2
        # modulates with number of channels (as in JTV reg)
        lam[i] = math.sqrt(1 / nchannels) / mean_fg.float()

    # affine matrices
    if affine_x is None:
        affine_x = torch.eye(4, device=device, dtype=dtype)

    # voxel size
    vx = voxel_size(affine_x)

    # scale regularisation
    lam = lam_scl * lam

    # initial estimate of reconstruction
    dat_y_hat = torch.zeros_like(dat_x)
    dat_y_hat = torch.nn.Parameter(dat_y_hat, requires_grad=True)

    # prepare optimiser and scheduler
    optim = torch.optim.Adam([dat_y_hat], lr=lr)
    scheduler = ReduceLROnPlateau(optim)

    # optimisation loop
    loss_vals = torch.zeros(max_iter + 1, dtype=torch.float64)
    cnt_conv = 0
    for n_iter in range(1, max_iter + 1):
        # PyTorch accumulates gradients across backward passes
        optim.zero_grad()
        loss_val = _loss_ssqd_jtv(dat_x, dat_y_hat, tau, lam, vx=vx)
        loss_val.backward()
        loss_vals[n_iter] = loss_val.item()
        optim.step()

        gain = get_gain(loss_vals[:n_iter + 1], monotonicity='decreasing')
        if verbose:
            with torch.no_grad():
                if n_iter % 25 == 0:
                    print('n_iter={:4d}, loss={:12.6f}, gain={:0.10}, lr={:g}'.
                          format(n_iter, loss_val.item(), gain,
                                 optim.param_groups[0]['lr']), end='\n')

        # stop after 5 consecutive small-gain iterations (past the 10th)
        if n_iter > 10 and gain.abs() < tolerance:
            cnt_conv += 1
            if cnt_conv == 5:
                break
        else:
            cnt_conv = 0

        # incorporate scheduler
        if scheduler is not None and n_iter % 10 == 0:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(loss_val)
            else:
                scheduler.step()

    if do_write:
        # write output to disk
        if dir_out is not None:
            os.makedirs(dir_out, exist_ok=True)
        for i in range(dat_y_hat.shape[0]):
            fname = file_replace(nii.fname, prefix='den_', dir=dir_out,
                                 suffix='_' + str(i))
            # detach: the parameter still requires grad
            savef(dat_y_hat[i, ...].detach(), fname, like=nii)

    return dat_y_hat
def reslice(moving, fname, like, inv=False, lin=None, nonlin=None,
            interpolation=1, bound='dct2', extrapolate=False, device=None,
            verbose=True):
    """Apply the linear and non-linear components of the transform and
    reslice the image to the target space.

    Notes
    -----
    .. The shape and general orientation of the moving image is kept untouched.
    .. The linear transform is composed with the original orientation matrix.
    .. The non-linear component is "wrapped" in the input space, where it
       is applied.
    .. This function writes new files; inputs are not modified in place.

    Parameters
    ----------
    moving : ImageFile
        Image to resample (may span several files).
    fname : list of str
        One output filename per file of `moving`.
    like : ImageFile
        Image defining the target space (grid and orientation).
    inv : bool, default=False
        True when warping the fixed image to the moving space; then
        `moving` is a `FixedImageFile` and `like` a `MovingImageFile`.
        Otherwise `moving` is a `MovingImageFile` and `like` a
        `FixedImageFile`.
    lin : (4, 4) tensor, optional
        Linear (affine) transformation.
    nonlin : dict, optional
        Non-linear displacement field with keys
        'disp' : (..., 3) tensor, displacement in voxels, and
        'affine' : (4, 4) tensor, orientation matrix of the field.
    interpolation : int, default=1
        Interpolation order (forced to 0 when `moving.type == 'labels'`).
    bound : str, default='dct2'
    extrapolate : bool, default=False
    device : optional, default=None
    verbose : bool, default=True
        Print each input/output pair.
    """
    nonlin = nonlin or dict(disp=None, affine=None)
    prm = dict(interpolation=interpolation, bound=bound, extrapolate=extrapolate)

    moving_affine = moving.affine.to(device)
    fixed_affine = like.affine.to(device)

    # `src`: voxel-to-world of the fixed side (affine-corrected when inv)
    # `dst`: voxel-to-world of the moving side (affine-corrected otherwise)
    if inv:
        src = fixed_affine if lin is None else affine_matmul(lin, fixed_affine)
        dst = moving_affine
    else:
        src = fixed_affine
        dst = moving_affine if lin is None else affine_matmul(lin, moving_affine)

    if nonlin['disp'] is None:
        grid = affine_grid(affine_lmdiv(dst, src), like.shape)
    else:
        nlin_affine = nonlin['affine'].to(device)
        # fixed voxels -> field voxels
        fix2nlin = affine_lmdiv(nlin_affine, src)
        if samespace(fix2nlin, nonlin['disp'].shape[:-1], like.shape):
            grid = smalldef(nonlin['disp'].to(device))
        else:
            grid = affine_grid(fix2nlin, like.shape)
            grid += pull_grid(nonlin['disp'].to(device), grid)
        # field voxels -> moving voxels
        grid = affine_matvec(affine_lmdiv(dst, nlin_affine), grid)

    if moving.type == 'labels':
        prm['interpolation'] = 0

    for in_file, ofname in zip(moving.files, fname):
        if verbose:
            print(f'Resliced: {in_file.fname}\n'
                  f' -> {ofname}')
        vol = io.volumes.loadf(in_file.fname, rand=True, device=device)
        vol = vol.reshape([*in_file.shape, in_file.channels])
        vol = utils.movedim(vol, -1, 0)
        vol = pull(vol, grid, **prm)
        vol = utils.movedim(vol, 0, -1)
        io.savef(vol.cpu(), ofname, like=in_file.fname,
                 affine=like.affine.cpu())
def main(options):
    """Run EPIC distortion correction and save the corrected echoes.

    Output naming:
    * one output path: a single file holding all echoes, or one file per
      echo if the template contains '{n}'; a template containing '{base}'
      is expanded to one output per input file;
    * as many outputs as inputs: each input's echoes go to its output;
    * as many outputs as echoes: one file per echo.

    Raises
    ------
    ValueError
        If the number of output paths matches none of the cases above.
    """
    # find readout direction
    first = io.map(options.echoes[0])
    readout = get_readout(options.direction, first.affine, first.shape,
                          options.verbose)

    reversed_echoes = options.reversed or options.synth

    # do EPIC
    fit = epic(options.echoes,
               reverse_echoes=reversed_echoes,
               fieldmap=options.fieldmap,
               extrapolate=options.extrapolate,
               bandwidth=options.bandwidth,
               polarity=options.polarity,
               readout=readout,
               slicewise=options.slicewise,
               lam=options.penalty,
               max_iter=options.maxiter,
               tol=options.tolerance,
               verbose=options.verbose,
               device=get_device(options.gpu))

    # save volumes
    inputs, outputs = options.echoes, options.output
    if len(outputs) != len(inputs):
        if len(outputs) == 1:
            if '{base}' in outputs[0]:
                outputs = [outputs[0]] * len(inputs)
        elif len(outputs) != len(fit):
            raise ValueError(f'There should be either one output file, '
                             f'or as many output files as input files, '
                             f'or as many output files as echoes. Got '
                             f'{len(outputs)} output files, {len(inputs)} '
                             f'input files, and {len(fit)} echoes.')

    if len(outputs) == 1:
        folder, base, ext = py.fileparts(inputs[0])
        template = outputs[0]
        if '{n}' not in template:
            out = template.format(dir=folder, sep=os.sep, base=base, ext=ext)
            io.savef(torch.movedim(fit, 0, -1), out, like=inputs[0])
        else:
            for n, echo in enumerate(fit):
                out = template.format(dir=folder, sep=os.sep, base=base,
                                      ext=ext, n=n)
                io.savef(echo, out, like=inputs[0])
    elif len(outputs) == len(inputs):
        for i, (inp, out) in enumerate(zip(inputs, outputs)):
            folder, base, ext = py.fileparts(inp)
            out = out.format(dir=folder, sep=os.sep, base=base, ext=ext, n=i)
            # number of echoes stored in this file (1 for a 3D volume)
            nb_echoes = [*io.map(inp).shape, 1][3]
            io.savef(fit[:nb_echoes].movedim(0, -1), out, like=inp)
            fit = fit[nb_echoes:]
    else:
        assert len(outputs) == len(fit)
        folder, base, ext = py.fileparts(inputs[0])
        for n, (echo, out) in enumerate(zip(fit, outputs)):
            out = out.format(dir=folder, sep=os.sep, base=base, ext=ext, n=n)
            io.savef(echo, out, like=inputs[0])
def main_fit(options):
    """
    Estimate a displacement field from opposite polarity images

    The field is fitted coarse-to-fine: at each level both images (and the
    optional mask) are downsampled, the previous velocity is upsampled to
    the new grid, and `topup_fit` is run with that level's penalty,
    iteration budget and tolerance. The final velocity is upsampled to the
    full-resolution grid and written to `options.output`.

    Raises
    ------
    ValueError
        If the penalty type is neither bending nor membrane.
    """
    device = get_device(options.gpu)

    # map input files
    f0 = io.map(options.pos_file)
    f1 = io.map(options.neg_file)
    dim = f0.affine.shape[-1] - 1

    # map mask
    fm = None
    if options.mask:
        fm = io.map(options.mask)

    # detect readout direction
    readout = get_readout(options.readout, f0.affine, f0.shape[-dim:])

    # detect penalty (values may end with a penalty-type string)
    penalty_type = 'bending'
    penalties = options.penalty
    if penalties and isinstance(penalties[-1], str):
        *penalties, penalty_type = penalties
    if not penalties:
        penalties = [1]
    # slice instead of index so an empty type reaches the ValueError
    if penalty_type[:1] == 'b':
        penalty_type = 'bending'
    elif penalty_type[:1] == 'm':
        penalty_type = 'membrane'
    else:
        raise ValueError('Unknown penalty type', penalty_type)

    # broadcast per-level parameters to the same number of levels
    downs = options.downsample
    max_iter = options.max_iter
    tolerance = options.tolerance
    nb_levels = max(len(penalties), len(max_iter), len(tolerance), len(downs))
    penalties = py.make_list(penalties, nb_levels)
    tolerance = py.make_list(tolerance, nb_levels)
    max_iter = py.make_list(max_iter, nb_levels)
    downs = py.make_list(downs, nb_levels)

    # load
    d00 = f0.fdata(device='cpu')
    d11 = f1.fdata(device='cpu')
    dmask = fm.fdata(device='cpu') if fm is not None else None

    # fit
    vel = mask = None
    aff = last_aff = f0.affine
    last_dwn = None
    for penalty, n, tol, dwn in zip(penalties, max_iter, tolerance, downs):
        if dwn != last_dwn:
            d0, aff = downsample(d00.to(device), f0.affine, dwn)
            d1, _ = downsample(d11.to(device), f1.affine, dwn)
            vx = spatial.voxel_size(aff)
            if vel is not None:
                vel = upsample_vel(vel, last_aff, aff, d0.shape[-dim:], readout)
            last_aff = aff
            if fm is not None:
                mask, _ = downsample(dmask.to(device), f1.affine, dwn)
        last_dwn = dwn
        # rescale the penalty by the downsampling ratio in voxel count
        scl = py.prod(d00.shape) / py.prod(d0.shape)
        penalty = penalty * scl

        kernel = get_kernel(options.kernel, aff, d0.shape[-dim:], dwn)

        # prepare loss
        if options.loss == 'mse':
            prm0, _ = estimate_noise(d0)
            prm1, _ = estimate_noise(d1)
            # geometric mean of the two noise standard deviations
            sd = ((prm0['sd'].log() + prm1['sd'].log())/2).exp()
            if options.verbose:
                print(sd.item())
            loss = MSE(lam=1/(sd*sd), dim=dim)
        elif options.loss == 'lncc':
            loss = LNCC(dim=dim, patch=kernel)
        elif options.loss == 'lgmm':
            if options.bins == 1:
                loss = LNCC(dim=dim, patch=kernel)
            else:
                loss = LGMMH(dim=dim, patch=kernel, bins=options.bins)
        elif options.loss == 'gmm':
            if options.bins == 1:
                loss = NCC(dim=dim)
            else:
                loss = GMMH(dim=dim, bins=options.bins)
        else:
            loss = NCC(dim=dim)

        # fit
        vel = topup_fit(d0, d1, loss=loss, dim=readout, vx=vx, ndim=dim,
                        model=('svf' if options.diffeo else 'smalldef'),
                        lam=penalty, penalty=penalty_type, vel=vel,
                        modulation=options.modulation, max_iter=n,
                        tolerance=tol, verbose=options.verbose, mask=mask)

    del d0, d1, d00, d11

    # upsample
    vel = upsample_vel(vel, aff, f0.affine, f0.shape[-dim:], readout)

    # save
    folder, base, ext = py.fileparts(options.pos_file)
    fname = options.output
    fname = fname.format(dir=folder or '.', sep=os.sep, base=base, ext=ext)
    io.savef(vel, fname, like=options.pos_file, dtype='float32')
def savef(self, fname, *args, **kwargs):
    """Save to disk

    Writes `self.volume` to `fname` through `io.savef`; any extra
    positional or keyword arguments are forwarded unchanged.
    """
    volume = self.volume
    io.savef(volume, fname, *args, **kwargs)
def _save_channels(vol, fname, label, options, format_dict, **kwargs):
    """Save a channel-first volume, splitting it per input file if needed.

    With a single input file, `vol` is written as one file (channels moved
    last) named from `format_dict`. With several input files, channel `c`
    goes to its own file, named from the c-th input's format dictionary and
    using that input as header template. Extra keyword arguments (e.g.
    `affine`, `dtype`) are forwarded to `io.savef`.
    """
    if len(options.input) == 1:
        fname1 = fname.format(**format_dict)
        if options.verbose > 0:
            print(f'{label} ->', fname1)
        io.savef(torch.movedim(vol, 0, -1), fname1, like=options.input[0],
                 **kwargs)
        return
    for c, (vol1, ref1) in enumerate(zip(vol, options.input)):
        # format into a fresh name: overwriting the template would make
        # every channel reuse the first channel's file name
        fname1 = fname.format(**get_format_dict(ref1, options.output))
        if options.verbose > 0:
            print(f'{label}.{c+1} ->', fname1)
        io.savef(vol1, fname1, like=ref1, **kwargs)


def write_outputs(z, prm, options):
    """Write segmentation outputs in native, MNI and warped spaces.

    Parameters
    ----------
    z : tensor
        Channel-first class probabilities (class axis 0, reduced by
        `argmax(0)` for label maps).
    prm : dict
        Fitted parameters; uses 'bias', 'warp' and 'affine'.
    options : object
        Parsed options selecting which outputs to write and their
        filename templates (formatted with `dir`, `sep`, `base`, `ext`).
    """
    # prepare filenames
    ref_native = options.input[0]
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    format_dict = get_format_dict(ref_native, options.output)

    # input data are only needed for bias-corrected outputs
    backend = utils.backend(z)
    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp or
            options.all_nat or options.all_mni or options.all_wrp):
        dat = get_data(options.input, options.mask, None, 3, **backend)[0]

    # --- native space -------------------------------------------------

    if options.prob_nat or options.all_nat:
        fname = options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('prob.nat ->', fname)
        io.savef(torch.movedim(z, 0, -1), fname, like=ref_native,
                 dtype='float32')

    if options.labels_nat or options.all_nat:
        fname = options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('labels.nat ->', fname)
        io.save(z.argmax(0), fname, like=ref_native, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        fname = options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}'
        _save_channels(prm['bias'], fname, 'bias.nat', options, format_dict,
                       dtype='float32')

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        fname = options.nobias_nat or '{dir}{sep}{base}.nobias.nat{ext}'
        _save_channels(nobias, fname, 'nobias.nat', options, format_dict)
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        warp = prm['warp']
        fname = options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('warp.nat ->', fname)
        io.savef(warp, fname, like=ref_native, dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(
            mni_affine, mni_shape, scl, anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        if options.prob_mni or options.all_mni:
            fname = options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('prob.mni ->', fname)
            io.savef(torch.movedim(z_mni, 0, -1), fname, like=ref_native,
                     affine=mni_affine, dtype='float32')
        if options.labels_mni or options.all_mni:
            fname = options.labels_mni or '{dir}{sep}{base}.labels.mni{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('labels.mni ->', fname)
            io.save(z_mni.argmax(0), fname, like=ref_native,
                    affine=mni_affine, dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni or
                         options.all_mni):
        bias = spatial.reslice(prm['bias'], dat_affine, mni_affine, mni_shape,
                               interpolation=3, prefilter=False, bound='dct2')

        if options.bias_mni or options.all_mni:
            fname = options.bias_mni or '{dir}{sep}{base}.bias.mni{ext}'
            _save_channels(bias, fname, 'bias.mni', options, format_dict,
                           affine=mni_affine, dtype='float32')

        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            fname = options.nobias_mni or '{dir}{sep}{base}.nobias.mni{ext}'
            _save_channels(nobias, fname, 'nobias.mni', options, format_dict,
                           affine=mni_affine)
            del nobias

        del bias

    need_iwarp = (options.warp_mni or options.prob_wrp or options.labels_wrp or
                  options.bias_wrp or options.nobias_wrp or options.all_mni or
                  options.all_wrp)
    need_iwarp = need_iwarp and options.warp
    if not need_iwarp:
        return

    # invert the displacement, resample it to the MNI grid and express it
    # in MNI voxel units
    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp, dat_affine, mni_affine, mni_shape,
                            interpolation=2, bound='dft', extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if (options.warp_mni or options.all_mni) and options.warp:
        fname = options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('warp.mni ->', fname)
        io.savef(iwarp, fname, like=ref_native, affine=mni_affine,
                 dtype='float32')

    # --- Warped space -------------------------------------------------
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_wrp = spatial.grid_pull(z, iwarp)
        if options.prob_wrp or options.all_wrp:
            fname = options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('prob.wrp ->', fname)
            io.savef(torch.movedim(z_wrp, 0, -1), fname, like=ref_native,
                     affine=mni_affine, dtype='float32')
        if options.labels_wrp or options.all_wrp:
            fname = options.labels_wrp or '{dir}{sep}{base}.labels.wrp{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('labels.wrp ->', fname)
            io.save(z_wrp.argmax(0), fname, like=ref_native,
                    affine=mni_affine, dtype='int16')
        del z_wrp

    if options.bias and (options.bias_wrp or options.nobias_wrp or
                         options.all_wrp):
        bias = spatial.grid_pull(prm['bias'], iwarp, interpolation=3,
                                 prefilter=False, bound='dct2')

        if options.bias_wrp or options.all_wrp:
            fname = options.bias_wrp or '{dir}{sep}{base}.bias.wrp{ext}'
            _save_channels(bias, fname, 'bias.wrp', options, format_dict,
                           affine=mni_affine, dtype='float32')

        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            fname = options.nobias_wrp or '{dir}{sep}{base}.nobias.wrp{ext}'
            _save_channels(nobias, fname, 'nobias.wrp', options, format_dict,
                           affine=mni_affine)
            del nobias

        del bias