def exp(self, velocity, displacement=False):
    """Exponentiate a stationary velocity field into a deformation.

    The velocity is resized with ``self.resize``, integrated with
    ``self.velexp``, and the resulting field is resized back to the
    spatial shape of the input velocity.

    Parameters
    ----------
    velocity : (batch, *spatial, nb_dim) tensor
        Stationary velocity field.
    displacement : bool, default=False
        If True, return a displacement field (voxel to shift).
        Otherwise return a transformation field (voxel to voxel).

    Returns
    -------
    grid : (batch, *spatial, nb_dim) tensor
        Deformation grid (displacement or transformation).
    """
    spatial_shape = velocity.shape[1:-1]
    disp = self.resize(self.velexp(self.resize(velocity)),
                       output_shape=spatial_shape, factor=None)
    return disp if displacement else spatial.add_identity_grid_(disp)
def forward(self, source, target, source_seg=None, target_seg=None,
            *, _loss=None, _metric=None):
    """Register `source` onto `target` and warp it.

    Parameters
    ----------
    source, target : tensor
        Moving and fixed images. They are concatenated along dim 1
        before being fed to ``self.unet``.
    source_seg, target_seg : tensor, optional
        Segmentations. `source_seg` is warped with the same grid as
        `source`. Both are forwarded to ``self.compute``.
    _loss, _metric : optional
        Passed through to ``self.compute``.

    Returns
    -------
    deformed_source : tensor
    deformed_source_seg : tensor, only returned if `source_seg` is given
    vel : tensor
        Velocity field predicted by the network (after ``resize_vel``
        when that attribute exists).
    grid : tensor
        Transformation grid used to warp `source`.
    """
    vel = self.unet(torch.cat([source, target], dim=1))
    if hasattr(self, 'resize_vel'):
        vel = self.resize_vel(vel)
    # NOTE(review): the identity is added below, so `self.exp` is assumed
    # to return a displacement field here -- confirm.
    grid = self.exp(vel)
    if hasattr(self, 'resize_grid'):
        # Resize the exponentiated field (not `vel`) back to the input
        # resolution; resizing `vel` would discard the result of
        # `self.exp`.
        grid = self.resize_grid(grid, output_shape=source.shape[2:])
    grid = spatial.add_identity_grid_(grid)

    deformed_source = self.pull(source, grid)
    if source_seg is not None:
        deformed_source_seg = self.pull(source_seg, grid)
    else:
        deformed_source_seg = None

    # compute loss and metrics
    self.compute(_loss, _metric,
                 image=[deformed_source, target],
                 velocity=[vel],
                 segmentation=[deformed_source_seg, target_seg])

    if source_seg is None:
        return deformed_source, vel, grid
    return deformed_source, deformed_source_seg, vel, grid
def derivatives_distortion(contrast, distortion, intercept, decay, opt,
                           do_grad=True):
    """Compute the gradient and Hessian of the distortion field.

    Parameters
    ----------
    contrast : (nb_echo, *obs_shape) GradientEchoMulti
        A single echo series (with the same weighting)
    distortion : ParameterizedDeformation
        A model of distortions caused by B0 inhomogeneities.
    intercept : (*recon_shape) ParameterMap
        Log-intercept of the contrast
    decay : (*recon_shape) ParameterMap
        Exponential decay
    opt : Options
    do_grad : bool, default=True
        If False, only the criterion is computed and returned.

    Returns
    -------
    crit : () tensor
        Criterion accumulated over echoes (from nll_chi / nll_gauss).
    grad : (*obs_shape, 3) tensor, if `do_grad`
        Gradient with respect to the displacement, in observed space.
        If ``contrast.readout`` is set, only the readout component is
        returned, with shape (*obs_shape).
    hess : (*obs_shape, 6) tensor, if `do_grad`
        Hessian, channels ordered [00, 11, 22, 01, 02, 12].
        If ``contrast.readout`` is set, only the readout diagonal
        element is returned, with shape (*obs_shape).
    """
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    obs_shape = contrast.volume.shape[1:]
    recon_shape = intercept.volume.shape
    aff = core.linalg.lmdiv(intercept.affine, contrast.affine)
    aff = aff.to(**backend)
    lam = 1 / contrast.noise
    df = contrast.dof
    chi = opt.likelihood[0].lower() == 'c'

    # pull parameter maps to observed space
    grid = smart_grid(aff, obs_shape, recon_shape)
    inter = smart_pull(intercept.fdata(**backend), grid)
    slope = smart_pull(decay.fdata(**backend), grid)
    readout = contrast.readout

    # With 'pre' scaling, each echo exponentiates its own scaled field
    # inside the loop; otherwise both polarities are exponentiated once.
    if opt.distortion.te_scaling != 'pre':
        grid_up, grid_down = distortion.exp2(
            add_identity=not opt.distortion.te_scaling)
    else:
        grid_up = grid_down = None

    crit = 0
    grad = torch.zeros(obs_shape + (3,), **backend) if do_grad else None
    hess = torch.zeros(obs_shape + (6,), **backend) if do_grad else None

    te0 = 0
    for e, echo in enumerate(contrast):

        te = echo.te
        te0 = te0 or te
        # without an explicit polarity, alternate -1 (even) / +1 (odd)
        blip = echo.blip or (2 * (e % 2) - 1)
        grid_blip = grid_up if blip > 0 else grid_down
        vscl = te / te0
        if opt.distortion.te_scaling == 'pre':
            vexp = distortion.iexp if blip < 0 else distortion.exp
            grid_blip = vexp(add_identity=True, alpha=vscl)
        elif opt.distortion.te_scaling:
            grid_blip = spatial.add_identity_grid_(vscl * grid_blip)

        # compute residuals
        dat = echo.fdata(**backend, rand=True, cache=False)  # observed
        fit = recon_fit(inter, slope, te)                     # fitted
        if do_grad and isinstance(distortion, DenseDeformation):
            # D(fit) o phi
            gfit = smart_grad(fit, grid_blip, bound='dft', extrapolate=True)
        fit = smart_pull(fit, grid_blip, bound='dft', extrapolate=True)
        msk = get_mask_missing(dat, fit)    # mask of missing values
        if do_grad and isinstance(distortion, SVFDeformation):
            # D(fit o phi)
            gfit = spatial.diff(fit, bound='dft', dim=[-3, -2, -1])
            gfit.masked_fill_(msk.unsqueeze(-1), 0)
        # NOTE(review): in the DenseDeformation case `gfit` is not masked,
        # so the Hessian term below is nonzero at missing voxels -- confirm
        # whether this is intended.
        dat.masked_fill_(msk, 0)
        fit.masked_fill_(msk, 0)
        msk = msk.bitwise_not_()

        if chi:
            crit1, res = nll_chi(dat, fit, msk, lam, df)
        else:
            crit1, res = nll_gauss(dat, fit, msk, lam)
        del dat, fit, msk
        crit += crit1

        if do_grad:
            g1 = res.unsqueeze(-1).mul(gfit)
            h1 = torch.zeros_like(hess)
            if readout is None:
                h1[..., :3] = gfit.square()
                h1[..., 3] = gfit[..., 0] * gfit[..., 1]
                h1[..., 4] = gfit[..., 0] * gfit[..., 2]
                h1[..., 5] = gfit[..., 1] * gfit[..., 2]
            else:
                h1[..., readout] = gfit[..., readout].square()

            # propagate backward
            if isinstance(distortion, SVFDeformation):
                vel = distortion.volume
                if opt.distortion.te_scaling == 'pre':
                    vel = ((-vscl) * vel) if blip < 0 else (vscl * vel)
                elif blip < 0:
                    vel = -vel
                g1, h1 = spatial.exp_backward(vel, g1, h1,
                                              steps=distortion.steps)

            alpha_g = alpha_h = lam
            alpha_g = alpha_g * blip
            if opt.distortion.te_scaling == 'pre':
                alpha_g = alpha_g * vscl
                alpha_h = alpha_h * (vscl * vscl)
            grad.add_(g1, alpha=alpha_g)
            hess.add_(h1, alpha=alpha_h)

    if not do_grad:
        return crit

    if readout is None:
        # grad/hess are channels-last: slice the last dimension, not the
        # first spatial one.
        mask_nan_(grad)
        mask_nan_(hess[..., :3], 1e-3)  # diagonal
        mask_nan_(hess[..., 3:])        # off-diagonal
    else:
        grad = grad[..., readout]
        hess = hess[..., readout]
        mask_nan_(grad)
        mask_nan_(hess)

    return crit, grad, hess
def derivatives_parameters(contrast, distortion, intercept, decay, opt,
                           do_grad=True):
    """Compute the gradient and Hessian of the parameter maps with
    respect to one contrast.

    Parameters
    ----------
    contrast : (nb_echo, *obs_shape) GradientEchoMulti
        A single echo series (with the same weighting)
    distortion : ParameterizedDeformation
        A model of distortions caused by B0 inhomogeneities.
        May be falsy, in which case no distortion is applied.
    intercept : (*recon_shape) ParameterMap
        Log-intercept of the contrast
    decay : (*recon_shape) ParameterMap
        Exponential decay
    opt : Options
    do_grad : bool, default=True
        If False, only the criterion is computed and returned.

    Returns
    -------
    crit : () tensor
        Criterion accumulated over echoes (from nll_chi / nll_gauss).
    grad : (2, *recon_shape) tensor, if `do_grad`
        Gradient with respect to:
            [0] intercept
            [1] decay
    hess : (3, *recon_shape) tensor, if `do_grad`
        Hessian with respect to:
            [0] intercept ** 2
            [1] decay ** 2
            [2] intercept * decay
    """
    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    obs_shape = contrast.volume.shape[1:]
    recon_shape = intercept.volume.shape
    aff = core.linalg.lmdiv(intercept.affine, contrast.affine)
    aff = aff.to(**backend)
    lam = 1 / contrast.noise
    df = contrast.dof
    chi = opt.likelihood[0].lower() == 'c'

    # pull parameter maps to observed space
    grid = smart_grid(aff, obs_shape, recon_shape)
    inter = smart_pull(intercept.fdata(**backend), grid)
    slope = smart_pull(decay.fdata(**backend), grid)

    # With 'pre' scaling, each echo exponentiates its own scaled field
    # inside the loop; otherwise both polarities are exponentiated once.
    if distortion and opt.distortion.te_scaling != 'pre':
        grid_up, grid_down = distortion.exp2(
            add_identity=not opt.distortion.te_scaling)
    else:
        grid_up = grid_down = None

    crit = 0
    grad = torch.zeros((2,) + obs_shape, **backend) if do_grad else None
    hess = torch.zeros((3,) + obs_shape, **backend) if do_grad else None

    te0 = 0
    for e, echo in enumerate(contrast):

        te = echo.te
        te0 = te0 or te
        # without an explicit polarity, alternate -1 (even) / +1 (odd)
        blip = echo.blip or (2 * (e % 2) - 1)
        grid_blip = grid_up if blip > 0 else grid_down
        if distortion:
            vscl = te / te0
            if opt.distortion.te_scaling == 'pre':
                vexp = distortion.iexp if blip < 0 else distortion.exp
                grid_blip = vexp(add_identity=True, alpha=vscl)
            elif opt.distortion.te_scaling:
                # exp2 returned displacements (add_identity=False) for any
                # truthy non-'pre' scaling, so scale and add the identity
                # here -- same condition as in derivatives_distortion.
                grid_blip = spatial.add_identity_grid_(vscl * grid_blip)

        # compute residuals
        dat = echo.fdata(**backend, rand=True, cache=False)
        fit = recon_fit(inter, slope, te)
        pull_fit = smart_pull(fit, grid_blip, bound='dft', extrapolate=True)
        msk = get_mask_missing(dat, pull_fit)
        dat.masked_fill_(msk, 0)
        pull_fit.masked_fill_(msk, 0)
        msk = msk.bitwise_not_()

        if chi:
            crit1, res = nll_chi(dat, pull_fit, msk, lam, df)
        else:
            crit1, res = nll_gauss(dat, pull_fit, msk, lam)
        del dat, pull_fit
        crit += crit1

        if do_grad:
            msk = msk.to(fit.dtype)
            if grid_blip is not None:
                # bring residuals and mask back to the undistorted space
                res0 = res
                res = smart_push(res0, grid_blip, bound='dft',
                                 extrapolate=True)
                abs_res = smart_push(res0.abs_(), grid_blip, bound='dft',
                                     extrapolate=True)
                abs_res.mul_(fit)
                msk = smart_push(msk, grid_blip, bound='dft',
                                 extrapolate=True)
                del res0

            # ----------------------------------------------------------
            # compute gradient and (majorised) Hessian in observed space
            #
            #   grad[inter]       =           lam * fit * res
            #   grad[decay]       =     -te * lam * fit * res
            #   hess[inter**2]    =           lam * fit * (fit + abs(res))
            #   hess[decay**2]    = (te*te) * lam * fit * (fit + abs(res))
            #   hess[inter*decay] =     -te * lam * fit * fit
            #
            # Kept inline in the loop (rather than in a helper) to save
            # a few allocations via in-place operations.
            # ----------------------------------------------------------

            res.mul_(fit)
            grad[0].add_(res, alpha=lam)
            grad[1].add_(res, alpha=-te * lam)
            if grid_blip is None:
                abs_res = res.abs_()
            fit2 = fit.mul_(fit).mul_(msk)
            del msk
            hess[2].add_(fit2, alpha=-te * lam)
            fit2.add_(abs_res)
            hess[0].add_(fit2, alpha=lam)
            hess[1].add_(fit2, alpha=lam * (te * te))

            del res, fit, abs_res, fit2

    if not do_grad:
        return crit

    mask_nan_(grad)
    mask_nan_(hess[:-1], 1e-3)  # diagonal
    mask_nan_(hess[-1])         # off-diagonal

    # push gradient and Hessian to recon space
    grad = smart_push(grad, grid, recon_shape)
    hess = smart_push(hess, grid, recon_shape)
    return crit, grad, hess
def add_identity_(self, disp):
    """Turn a 1-D displacement field into a 1-D transformation field.

    The field is treated as a displacement along the spatial axis
    ``self.displacement_dim``. That axis is moved last and given a
    trailing singleton component dimension. The identity is then added
    with ``spatial.add_identity_grid_``, and the layout is restored.
    """
    axis = self.displacement_dim
    field = utils.movedim(disp, axis, -1)
    field = field.unsqueeze(-1)
    field = spatial.add_identity_grid_(field)
    field = field.squeeze(-1)
    return utils.movedim(field, -1, axis)
def main_apply(options):
    """Unwarp distorted images using a pre-computed 1d displacement field.

    Files in ``options.file_pos`` are resampled with the field loaded
    from ``options.dist_file``. Files in ``options.file_neg`` are
    resampled with its negation. When ``options.modulation`` is set,
    intensities are also multiplied by the Jacobian of the transform.
    """
    device = get_device(options.gpu)

    # detect readout direction from the first input file
    if options.file_pos:
        f0 = io.map(options.file_pos[0])
    else:
        f0 = io.map(options.file_neg[0])
    dim = f0.affine.shape[-1] - 1
    readout = get_readout(options.readout, f0.affine, f0.shape[-dim:])

    def make_transform(field):
        """Build (phi, jac) from a readout-last 1d field; jac may be None."""
        if options.diffeo:
            phi, *jac = spatial.exp1d_forward(field, bound='dct2',
                                              jacobian=options.modulation)
            jac = jac[0] if jac else None
        else:
            # copy: the identity is added in place below and must not
            # modify the caller's field
            phi = field.clone()
            jac = None
            if options.modulation:
                # the readout axis has been moved last, so differentiate
                # along -1 (not along the original `readout` index)
                jac = spatial.diff1d(phi, dim=-1, bound='dct2', side='c')
                jac += 1
        phi = spatial.add_identity_grid_(phi.unsqueeze(-1)).squeeze(-1)
        return phi, jac

    def do_apply(fnames, phi, jac):
        """Correct files with a given polarity"""
        for fname in fnames:
            dirname, base, ext = py.fileparts(fname)
            ofname = options.output.format(dir=dirname or '.', sep=os.sep,
                                           base=base, ext=ext)
            if options.verbose:
                print(f'unwarp {fname} \n'
                      f' -> {ofname}')

            d = io.map(fname).fdata(device=device)
            d = utils.movedim(d, readout, -1)
            d = _deform1d(d, phi)
            if jac is not None:
                d *= jac
            d = utils.movedim(d, -1, readout)

            io.savef(d, ofname, like=fname)

    # load the field and move the readout axis last
    vel = io.loadf(options.dist_file, device=device)
    vel = utils.movedim(vel, readout, -1)

    if options.file_pos:
        do_apply(options.file_pos, *make_transform(vel))
    if options.file_neg:
        do_apply(options.file_neg, *make_transform(-vel))
def write_outputs(z, prm, options):
    """Write segmentation outputs in native, MNI and warped space.

    Parameters
    ----------
    z : tensor
        Class probabilities with classes along the first dimension. That
        dimension is moved last when saving probabilities, and reduced
        with ``argmax(0)`` for label maps.
    prm : dict
        Fitted parameters:
        - ``prm['bias']`` is read when ``options.bias`` is set.
        - ``prm['warp']`` is read when ``options.warp`` is set.
        - ``prm['affine']`` is read when MNI/warped outputs are possible.
    options
        Parsed options. They select the outputs to write (``*_nat``,
        ``*_mni``, ``*_wrp``, ``all_*``) and give the filename templates,
        inputs, template (``tpm``) and verbosity.

    Filename templates are formatted with ``get_format_dict``. With
    several inputs, channel-wise maps (bias, nobias) are written as one
    file per input, each formatted from that input's name.
    """
    ref_native = options.input[0]
    format_dict = get_format_dict(ref_native, options.output)
    backend = utils.backend(z)

    def save_one(vol, template, label, save=io.savef, **kwargs):
        # Format the template, report the file name, and write `vol`
        # using the first input as reference.
        fname = template.format(**format_dict)
        if options.verbose > 0:
            print(f'{label} ->', fname)
        save(vol, fname, like=ref_native, **kwargs)

    def save_channels(vol, template, label, **kwargs):
        # Write a channels-first map: a single file (channels last) when
        # there is one input, otherwise one file per input channel.
        if len(options.input) == 1:
            save_one(torch.movedim(vol, 0, -1), template, label, **kwargs)
            return
        for c, (vol1, ref1) in enumerate(zip(vol, options.input)):
            # Format the template afresh for each input; reusing an
            # already-formatted name would give every channel the same
            # output file.
            fname = template.format(**get_format_dict(ref1, options.output))
            if options.verbose > 0:
                print(f'{label}.{c+1} ->', fname)
            io.savef(vol1, fname, like=ref1, **kwargs)

    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp
            or options.all_nat or options.all_mni or options.all_wrp):
        dat, _, _ = get_data(options.input, options.mask, None, 3, **backend)

    # --- native space -------------------------------------------------

    if options.prob_nat or options.all_nat:
        save_one(torch.movedim(z, 0, -1),
                 options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}',
                 'prob.nat', dtype='float32')

    if options.labels_nat or options.all_nat:
        save_one(z.argmax(0),
                 options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}',
                 'labels.nat', save=io.save, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        save_channels(prm['bias'],
                      options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}',
                      'bias.nat', dtype='float32')

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        save_channels(nobias,
                      options.nobias_nat or '{dir}{sep}{base}.nobias.nat{ext}',
                      'nobias.nat')
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        save_one(prm['warp'],
                 options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}',
                 'warp.nat', dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(
            mni_affine, mni_shape, scl, anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        if options.prob_mni or options.all_mni:
            save_one(torch.movedim(z_mni, 0, -1),
                     options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}',
                     'prob.mni', affine=mni_affine, dtype='float32')
        if options.labels_mni or options.all_mni:
            save_one(z_mni.argmax(0),
                     options.labels_mni or '{dir}{sep}{base}.labels.mni{ext}',
                     'labels.mni', save=io.save, affine=mni_affine,
                     dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni
                         or options.all_mni):
        bias = spatial.reslice(prm['bias'], dat_affine, mni_affine, mni_shape,
                               interpolation=3, prefilter=False, bound='dct2')
        if options.bias_mni or options.all_mni:
            save_channels(bias,
                          options.bias_mni or '{dir}{sep}{base}.bias.mni{ext}',
                          'bias.mni', affine=mni_affine, dtype='float32')
        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            save_channels(nobias,
                          options.nobias_mni
                          or '{dir}{sep}{base}.nobias.mni{ext}',
                          'nobias.mni', affine=mni_affine)
            del nobias
        del bias

    need_iwarp = (options.warp_mni or options.prob_wrp or options.labels_wrp
                  or options.bias_wrp or options.nobias_wrp
                  or options.all_mni or options.all_wrp)
    need_iwarp = need_iwarp and options.warp
    if not need_iwarp:
        return

    # inverse warp as a displacement sampled on the MNI grid, expressed in
    # MNI voxel units
    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp, dat_affine, mni_affine, mni_shape,
                            interpolation=2, bound='dft', extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if options.warp_mni or options.all_mni:
        save_one(iwarp,
                 options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}',
                 'warp.mni', affine=mni_affine, dtype='float32')

    # --- Warped space -------------------------------------------------
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_wrp = spatial.grid_pull(z, iwarp)
        if options.prob_wrp or options.all_wrp:
            save_one(torch.movedim(z_wrp, 0, -1),
                     options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}',
                     'prob.wrp', affine=mni_affine, dtype='float32')
        if options.labels_wrp or options.all_wrp:
            save_one(z_wrp.argmax(0),
                     options.labels_wrp or '{dir}{sep}{base}.labels.wrp{ext}',
                     'labels.wrp', save=io.save, affine=mni_affine,
                     dtype='int16')
        del z_wrp

    if options.bias and (options.bias_wrp or options.nobias_wrp
                         or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'], iwarp, interpolation=3,
                                 prefilter=False, bound='dct2')
        if options.bias_wrp or options.all_wrp:
            save_channels(bias,
                          options.bias_wrp or '{dir}{sep}{base}.bias.wrp{ext}',
                          'bias.wrp', affine=mni_affine, dtype='float32')
        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            save_channels(nobias,
                          options.nobias_wrp
                          or '{dir}{sep}{base}.nobias.wrp{ext}',
                          'nobias.wrp', affine=mni_affine)
            del nobias
        del bias