def build_from_target(target):
    """Compose all transformations, starting from the final orientation.

    A coordinate grid is built from ``target.affine`` / ``target.shape`` and
    then pushed through every entry of ``options.transformations`` in
    reverse order:

    * ``Linear`` transforms apply their affine matrix to the grid;
    * other (nonlinear) transforms map the grid into the displacement
      field's own space with the inverse of their affine, add the sampled
      displacement, and map back with the affine.

    Relies on ``options`` and ``backend`` from the enclosing scope.
    Returns the composed grid.
    """
    target_affine = target.affine.to(**backend)
    grid = spatial.affine_grid(target_affine, target.shape)

    for transform in reversed(options.transformations):
        affine = transform.affine.to(**backend)

        if isinstance(transform, Linear):
            grid = spatial.affine_matvec(affine, grid)
            continue

        if not transform.inv:
            displacement, order = transform.dat, transform.spline
        else:
            # Resample the stored field by the ratio of its voxel size to
            # the target's voxel size before inverting it (the resized
            # field comes with its own updated affine).
            ratio = spatial.voxel_size(affine) / spatial.voxel_size(target_affine)
            resized, affine = spatial.resize_grid(
                transform.dat[None], ratio, affine=affine,
                interpolation=transform.spline)
            displacement = spatial.grid_inv(resized[0], type='disp')
            order = 1

        # go to the displacement's space, displace, come back
        grid = spatial.affine_matvec(spatial.affine_inv(affine), grid)
        grid += helpers.pull_grid(displacement, grid, interpolation=order)
        grid = spatial.affine_matvec(affine, grid)

    return grid
def iexp(self, v=None, jacobian=False, add_identity=False,
         cache_result=False, recompute=True):
    """Exponentiate inverse transform.

    Parameters
    ----------
    v : optional
        Displacement to invert; defaults to ``self.dat.dat``.
    jacobian : bool
        Also return the Jacobian of the inverse displacement (computed
        before the identity is added).
    add_identity : bool
        Pass the result through ``self.add_identity``.
    cache_result : bool
        Store the inverse displacement in ``self._icache``.
    recompute : bool
        When False and ``self._icache`` is set, reuse the cached inverse
        instead of calling ``spatial.grid_inv``.

    Returns
    -------
    grid, or (grid, jac) when ``jacobian`` is True.
    """
    velocity = self.dat.dat if v is None else v
    reuse_cache = (not recompute) and (self._icache is not None)
    inverse = (self._icache if reuse_cache
               else spatial.grid_inv(velocity, type='disp', **self.penalty))

    if cache_result:
        self._icache = inverse

    jac = spatial.grid_jacobian(inverse, type='displacement') if jacobian else None

    if add_identity:
        inverse = self.add_identity(inverse)

    if jacobian:
        return inverse, jac
    return inverse
def exp2(self, v=None, jacobian=False, add_identity=False,
         cache_result=False, recompute=True):
    """Exponentiate both forward and inverse transforms.

    The forward displacement is the input itself (``v`` or
    ``self.dat.dat``); the inverse is obtained with ``spatial.grid_inv``
    or, when ``recompute`` is False and a cache exists, from
    ``self._icache``.

    Returns
    -------
    (grid, igrid), or (grid, igrid, jac, ijac) when ``jacobian`` is True.
    Jacobians are computed before ``self.add_identity`` is applied.
    """
    forward = self.dat.dat if v is None else v

    if not recompute and self._icache is not None:
        backward = self._icache
    else:
        backward = spatial.grid_inv(forward, type='disp', **self.penalty)

    if cache_result:
        self._icache = backward

    if jacobian:
        fwd_jac = spatial.grid_jacobian(forward, type='displacement')
        bwd_jac = spatial.grid_jacobian(backward, type='displacement')

    if add_identity:
        forward = self.add_identity(forward)
        backward = self.add_identity(backward)

    if not jacobian:
        return forward, backward
    return forward, backward, fwd_jac, bwd_jac
def write_data(options):
    """Write updated and/or resliced images for every loss term.

    The first ``struct.Linear`` entry of ``options.transformations`` gives
    the linear matrix and the first ``struct.FFD`` / ``struct.Diffeo``
    entry gives the nonlinear displacement field. For each entry of
    ``options.losses``, the moving image is written with the forward
    transform and, when it exists and is requested, the fixed image is
    written with the inverse transform (``inv=True``).
    All tensors are built on CPU with dtype ``torch.float``.
    """
    # Validate the requested device string; all work below is done on CPU.
    torch.device(options.device)
    backend = dict(dtype=torch.float, device='cpu')

    # The inverse displacement is only needed if a fixed image is written.
    need_inv = any(
        loss.fixed and (loss.fixed.resliced or loss.fixed.updated)
        for loss in options.losses)

    # --- affine matrix (first linear transform only) ---
    lin = None
    for trf in options.transformations:
        if isinstance(trf, struct.Linear):
            q = trf.dat.to(**backend)
            basis = trf.basis.to(**backend)
            lin = linalg.expm(q, basis)
            if torch.is_tensor(trf.shift):
                # include shift: translation += (linear_part - I) @ shift
                shift = trf.shift.to(**backend)
                eye = torch.eye(lin.shape[-1] - 1, **backend)
                lin[:-1, -1] += torch.matmul(lin[:-1, :-1] - eye, shift)
            break

    # --- non-linear displacement field (first nonlinear transform only) ---
    disp = None
    inv_disp = None
    disp_aff = None
    for trf in options.transformations:
        if isinstance(trf, struct.FFD):
            disp = ffd_exp(trf.dat.to(**backend), trf.shape, returns='disp')
            if need_inv:
                # `disp` is a displacement (returns='disp'), so it must be
                # inverted as one.
                inv_disp = spatial.grid_inv(disp, type='disp')
            disp_aff = trf.affine.to(**backend)
            break
        if isinstance(trf, struct.Diffeo):
            vel = trf.dat.to(**backend)
            if need_inv:
                inv_disp = spatial.exp(vel[None], displacement=True,
                                       inverse=True)[0]
            disp = spatial.exp(vel[None], displacement=True)[0]
            disp_aff = trf.affine.to(**backend)
            break

    # --- loop over image pairs ---
    for match in options.losses:
        moving = match.moving
        fixed = match.fixed

        prm = dict(interpolation=moving.interpolation, bound=moving.bound,
                   extrapolate=moving.extrapolate, device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=disp, affine=disp_aff)
        if moving.updated:
            update(moving, moving.updated, lin=lin, nonlin=nonlin, **prm)
        if moving.resliced:
            reslice(moving, moving.resliced, like=fixed, lin=lin,
                    nonlin=nonlin, **prm)

        if not fixed:
            continue

        prm = dict(interpolation=fixed.interpolation, bound=fixed.bound,
                   extrapolate=fixed.extrapolate, device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=inv_disp, affine=disp_aff)
        if fixed.updated:
            update(fixed, fixed.updated, inv=True, lin=lin,
                   nonlin=nonlin, **prm)
        if fixed.resliced:
            reslice(fixed, fixed.resliced, inv=True, like=moving, lin=lin,
                    nonlin=nonlin, **prm)
def write_outputs(z, prm, options):
    """Write segmentation outputs in native, MNI and warped spaces.

    Which files are written is controlled by flags on ``options``
    (``*_nat``, ``*_mni``, ``*_wrp`` and the ``all_nat`` / ``all_mni`` /
    ``all_wrp`` shortcuts). A flag holding a string is used as the filename
    template; otherwise a default template is filled from the first input's
    filename fields (``get_format_dict``).

    Parameters
    ----------
    z : tensor
        Class-first tensor: ``z.argmax(0)`` is written as the label map and
        ``movedim(z, 0, -1)`` as the probability map.
    prm : dict
        Fitted parameters. Reads ``'bias'`` (channel-first, when
        ``options.bias``), ``'warp'`` (displacement field, when
        ``options.warp``) and ``'affine'``.
    options : object
        Command-line options.
    """
    # --- prepare filenames ---
    ref_native = options.input[0]
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    format_dict = get_format_dict(ref_native, options.output)
    verbose = options.verbose

    backend = utils.backend(z)
    dat = None
    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp or
            options.all_nat or options.all_mni or options.all_wrp):
        # input data: only needed for the bias-corrected ("nobias") outputs
        dat, _, _ = get_data(options.input, options.mask, None, 3, **backend)

    # --- native space -------------------------------------------------
    if options.prob_nat or options.all_nat:
        _save_output(io.savef, torch.movedim(z, 0, -1),
                     options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}',
                     'prob.nat', format_dict, verbose,
                     like=ref_native, dtype='float32')
    if options.labels_nat or options.all_nat:
        _save_output(io.save, z.argmax(0),
                     options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}',
                     'labels.nat', format_dict, verbose,
                     like=ref_native, dtype='int16')
    if (options.bias_nat or options.all_nat) and options.bias:
        _save_per_channel(prm['bias'],
                          options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}',
                          'bias.nat', options, format_dict, dtype='float32')
    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        _save_per_channel(nobias,
                          options.nobias_nat or '{dir}{sep}{base}.nobias.nat{ext}',
                          'nobias.nat', options, format_dict)
        del nobias
    if (options.warp_nat or options.all_nat) and options.warp:
        _save_output(io.savef, prm['warp'],
                     options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}',
                     'warp.nat', format_dict, verbose,
                     like=ref_native, dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(
            mni_affine, mni_shape, scl, anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        if options.prob_mni or options.all_mni:
            _save_output(io.savef, torch.movedim(z_mni, 0, -1),
                         options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}',
                         'prob.mni', format_dict, verbose, like=ref_native,
                         affine=mni_affine, dtype='float32')
        if options.labels_mni or options.all_mni:
            _save_output(io.save, z_mni.argmax(0),
                         options.labels_mni or '{dir}{sep}{base}.labels.mni{ext}',
                         'labels.mni', format_dict, verbose, like=ref_native,
                         affine=mni_affine, dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni or options.all_mni):
        bias = spatial.reslice(prm['bias'], dat_affine, mni_affine, mni_shape,
                               interpolation=3, prefilter=False, bound='dct2')
        if options.bias_mni or options.all_mni:
            _save_per_channel(bias,
                              options.bias_mni or '{dir}{sep}{base}.bias.mni{ext}',
                              'bias.mni', options, format_dict,
                              affine=mni_affine, dtype='float32')
        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            _save_per_channel(nobias,
                              options.nobias_mni or '{dir}{sep}{base}.nobias.mni{ext}',
                              'nobias.mni', options, format_dict,
                              affine=mni_affine)
            del nobias
        del bias

    need_iwarp = options.warp and (
        options.warp_mni or options.prob_wrp or options.labels_wrp or
        options.bias_wrp or options.nobias_wrp or options.all_mni or
        options.all_wrp)
    if not need_iwarp:
        return

    # Invert the native displacement, resample it on the MNI grid, and
    # express its vectors in MNI voxel units.
    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp, dat_affine, mni_affine, mni_shape,
                            interpolation=2, bound='dft', extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if options.warp_mni or options.all_mni:
        _save_output(io.savef, iwarp,
                     options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}',
                     'warp.mni', format_dict, verbose, like=ref_native,
                     affine=mni_affine, dtype='float32')

    # --- Warped space -------------------------------------------------
    # Turn the displacement into a sampling grid in native voxel space
    # (add_identity_grid_ works in place, after warp.mni has been written).
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_wrp = spatial.grid_pull(z, iwarp)
        if options.prob_wrp or options.all_wrp:
            _save_output(io.savef, torch.movedim(z_wrp, 0, -1),
                         options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}',
                         'prob.wrp', format_dict, verbose, like=ref_native,
                         affine=mni_affine, dtype='float32')
        if options.labels_wrp or options.all_wrp:
            _save_output(io.save, z_wrp.argmax(0),
                         options.labels_wrp or '{dir}{sep}{base}.labels.wrp{ext}',
                         'labels.wrp', format_dict, verbose, like=ref_native,
                         affine=mni_affine, dtype='int16')
        del z_wrp

    if options.bias and (options.bias_wrp or options.nobias_wrp or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'], iwarp, interpolation=3,
                                 prefilter=False, bound='dct2')
        if options.bias_wrp or options.all_wrp:
            _save_per_channel(bias,
                              options.bias_wrp or '{dir}{sep}{base}.bias.wrp{ext}',
                              'bias.wrp', options, format_dict,
                              affine=mni_affine, dtype='float32')
        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            _save_per_channel(nobias,
                              options.nobias_wrp or '{dir}{sep}{base}.nobias.wrp{ext}',
                              'nobias.wrp', options, format_dict,
                              affine=mni_affine)
            del nobias
        del bias


def _save_output(save, dat, template, label, format_dict, verbose, **kwargs):
    """Fill `template` from `format_dict`, optionally log it, and write `dat`.

    `save` is the writer to call (``io.savef`` or ``io.save``); extra
    keyword arguments are forwarded to it unchanged.
    """
    fname = template.format(**format_dict)
    if verbose > 0:
        print(f'{label} ->', fname)
    save(dat, fname, **kwargs)


def _save_per_channel(dat, template, label, options, format_dict, **kwargs):
    """Write a channel-first tensor to one file, or one file per input.

    With a single input, `dat` is written once with its channel dimension
    moved last. With several inputs, channel ``c`` is written using input
    ``c`` as reference and for its filename fields. The template is kept
    intact and re-filled for each channel, so channels get distinct names.
    """
    if len(options.input) == 1:
        _save_output(io.savef, torch.movedim(dat, 0, -1), template, label,
                     format_dict, options.verbose,
                     like=options.input[0], **kwargs)
        return
    for c, (dat1, ref1) in enumerate(zip(dat, options.input)):
        format_dict1 = get_format_dict(ref1, options.output)
        _save_output(io.savef, dat1, template, f'{label}.{c+1}',
                     format_dict1, options.verbose, like=ref1, **kwargs)