def write_transforms(options):
    """Save the estimated transformations to disk.

    Scans ``options.transformations``; the last ``struct.NonLinear`` entry
    is taken as the nonlinear transform and the last other entry as the
    linear one.

    - Linear: the matrix ``linalg.expm(dat, basis)`` is written to
      ``.output`` with ``io.transforms.savef(..., type=2)``. If ``.shift``
      is a tensor, the translation is re-centred first.
    - Nonlinear: ``.dat`` is written to ``.output`` with
      ``io.volumes.savef``. For FFD transforms the stored affine is first
      rescaled from the control-point grid to ``.shape``.
    """
    lin_trf = None
    nl_trf = None
    for candidate in options.transformations:
        if isinstance(candidate, struct.NonLinear):
            nl_trf = candidate
        else:
            lin_trf = candidate

    if lin_trf:
        mat = linalg.expm(lin_trf.dat, lin_trf.basis)
        if torch.is_tensor(lin_trf.shift):
            # include shift: translation <- translation + (A - I) @ shift
            opt = dict(dtype=mat.dtype, device=mat.device)
            center = lin_trf.shift.to(**opt)
            ident = torch.eye(3, **opt)
            mat[:-1, -1] += torch.matmul(mat[:-1, :-1] - ident, center)
        io.transforms.savef(mat.cpu(), lin_trf.output, type=2)

    if nl_trf:
        nl_aff = nl_trf.affine
        nl_shape = nl_trf.shape
        if isinstance(nl_trf, struct.FFD):
            # ratio between the target grid and the control-point grid
            ratios = [s / g for s, g in zip(nl_shape, nl_trf.dat.shape[:-1])]
            nl_aff, _ = spatial.affine_resize(nl_aff, nl_shape, ratios)
        io.volumes.savef(nl_trf.dat.cpu(), nl_trf.output, affine=nl_aff.cpu())
def read_info(options):
    """Load affine transforms and space info of other volumes.

    Mutates ``options`` in place:

    - ``options.files`` becomes a list of ``struct.FileWithInfo`` carrying
      the path parts, dtype kind, shape, channel count and (float) affine.
    - each transform gets ``.affine`` (linear: the loaded matrix; other:
      the field's orientation matrix) and, for non-linear ones, ``.shape``.
    - ``options.target`` (if set) is replaced by its ``FileWithInfo``; if
      ``options.voxel_size`` is also set, the target grid is resized to it.
    """

    def read_file(fname):
        info = struct.FileWithInfo()
        info.fname = fname
        info.dir = os.path.dirname(fname) or '.'
        stem, ext = os.path.splitext(os.path.basename(fname))
        if ext in ('.gz', '.bz2'):
            # compressed file: keep the inner extension too (e.g. '.nii.gz')
            stem, inner_ext = os.path.splitext(stem)
            ext = inner_ext + ext
        info.base = stem
        info.ext = ext
        mapped = io.volumes.map(fname)
        info.float = nitype(mapped.dtype).is_floating_point
        full_shape = squeeze_to_nd(mapped.shape, dim=3, channels=1)
        info.channels = full_shape[-1]
        info.shape = full_shape[:3]
        info.affine = mapped.affine.float()
        return info

    def read_affine(fname):
        loaded = io.transforms.loadf(fname).float()
        return squeeze_to_nd(loaded, 0, 2)

    def read_field(fname):
        mapped = io.volumes.map(fname)
        return mapped.affine.float(), mapped.shape[:3]

    options.files = [read_file(path) for path in options.files]

    for trf in options.transformations:
        if not isinstance(trf, struct.Linear):
            trf.affine, trf.shape = read_field(trf.file)
        else:
            trf.affine = read_affine(trf.file)

    if not options.target:
        return
    options.target = read_file(options.target)
    if not options.voxel_size:
        return
    tgt = options.target
    options.voxel_size = utils.make_vector(
        options.voxel_size, 3, dtype=tgt.affine.dtype)
    ratio = spatial.voxel_size(tgt.affine) / options.voxel_size
    tgt.affine, tgt.shape = spatial.affine_resize(
        tgt.affine, tgt.shape, factor=ratio, anchor='f')
def resize(self):
    """Resample all maps (and the optional `rls` weights) to the current level.

    The working grid is ``self.affine0``/``self.shape0`` resized by a factor
    ``1 / 2**(self.level - 1)``.

    - ``self.lam_scale`` is set to the ratio of voxel volumes (new / original).
    - Every map's ``volume`` and ``affine`` are updated, as is
      ``self.maps.affine``.
    - If ``self.rls`` is not None, it is resized too and ``self.nll['rls']``
      is recomputed as the double-precision sum of its reciprocal.
    """
    affine, shape = spatial.affine_resize(self.affine0, self.shape0,
                                          1 / (2 ** (self.level - 1)))

    # ratio of voxel volumes between the new level and the original grid
    scl0 = spatial.voxel_size(self.affine0).prod()
    scl = spatial.voxel_size(affine).prod() / scl0
    self.lam_scale = scl

    for prm_map in self.maps:  # avoid shadowing the builtin `map`
        prm_map.volume = spatial.resize(prm_map.volume[None, None, ...],
                                        shape=shape)[0, 0]
        prm_map.affine = affine
    self.maps.affine = affine

    if self.rls is not None:
        if self.rls.dim() == len(shape):
            # single spatial volume: add batch + channel dims
            # (was `hape=shape`, a misspelt keyword)
            self.rls = spatial.resize(self.rls[None, None], shape=shape)[0, 0]
        else:
            # already has a leading channel dim: add batch dim only
            self.rls = spatial.resize(self.rls[None], shape=shape)[0]
        self.nll['rls'] = self.rls.reciprocal().sum(dtype=torch.double)
def _get_level(level, aff0, shape0):
    """Return ``(affine, shape)`` of pyramid level `level`.

    The grid ``aff0``/``shape0`` is resized by ``1 / 2**(level - 1)``, so
    level 1 corresponds to a factor of 1.

    NOTE(review): an identical ``_get_level`` (with docstring) is defined
    right after this one and shadows it; one of the two can likely be
    dropped.
    """
    scale = 2 ** (level - 1)
    return spatial.affine_resize(aff0, shape0, 1 / scale)
def _get_level(level, aff0, shape0):
    """Get shape and affine of a given resolution level.

    Parameters
    ----------
    level : int
        Pyramid level; 1 maps to a resize factor of 1.
    aff0 : tensor
        Orientation matrix of the level-1 grid.
    shape0 : sequence of int
        Spatial shape of the level-1 grid.

    Returns
    -------
    Whatever ``spatial.affine_resize`` returns for the factor
    ``1 / 2**(level - 1)`` — used elsewhere as ``(affine, shape)``.
    """
    downsampling = 1 / (2 ** (level - 1))
    resized = spatial.affine_resize(aff0, shape0, downsampling)
    return resized
def load_transforms(s):
    """Initialize transforms.

    Computes a mean space over all (non-excluded) loss images, and one
    over the fixed images only. Then, for each transform in
    ``s.transformations``:

    - multiplies its regularisation factors by ``trf.factor``;
    - linear: loads ``trf.dat`` from ``trf.init`` (if a path) or zeros.
      ``trf.shift`` becomes minus the world coordinate of the fixed mean
      space's centre, or 0.
    - otherwise (FFD/Diffeo): builds the coarsest pyramid grid from the
      mean space, then loads ``trf.dat`` from ``trf.init`` (if a path) or
      allocates a zero field of shape ``(*shape, trf.dim)``.

    Raises
    ------
    ValueError
        If a loaded field has extra non-singleton dimensions or a channel
        count different from ``trf.dim``.
    """
    device = torch.device(s.device)

    def reshape3d(dat, channels=None, dim=3):
        """Reshape as (*spatial) or (C, *spatial) or (*spatial, C).

        `channels` should be in ('first', 'last', None).
        """
        while len(dat.shape) > dim:
            if dat.shape[-1] == 1:
                dat = dat[..., 0]
                continue
            elif dat.shape[dim] == 1:
                # drop the first non-spatial axis (index `dim`); previously
                # hard-coded as `dat[:, :, :, 0, ...]`, i.e. only valid for dim=3
                dat = dat[(slice(None),) * dim + (0, Ellipsis)]
                continue
            else:
                break
        if len(dat.shape) > dim + bool(channels):
            raise ValueError('Too many channel dimensions')
        if channels and len(dat.shape) == dim:
            dat = dat[..., None]
        if channels == 'first':
            dat = utils.movedim(dat, -1, 0)
        return dat

    # compute mean space
    # it is used to define the space of the nonlinear transform, but
    # also to shift the center of rotation of the linear transform.
    all_affines = []
    all_shapes = []
    all_affines_fixed = []
    all_shapes_fixed = []
    for loss in s.losses:
        if isinstance(loss, struct.NoLoss):
            continue
        if getattr(loss, 'exclude', False):
            continue
        all_shapes_fixed.append(loss.fixed.shape)
        all_affines_fixed.append(loss.fixed.affine)
        all_shapes.append(loss.fixed.shape)
        all_affines.append(loss.fixed.affine)
        all_shapes.append(loss.moving.shape)
        all_affines.append(loss.moving.affine)
    affine0, shape0 = mean_space(all_affines, all_shapes,
                                 pad=s.pad, pad_unit=s.pad_unit)
    affinef, shapef = mean_space(all_affines_fixed, all_shapes_fixed,
                                 pad=s.pad, pad_unit=s.pad_unit)
    backend = dict(dtype=affine0.dtype, device=affine0.device)

    for trf in s.transformations:
        for reg in trf.losses:
            if isinstance(reg.factor, (list, tuple)):
                reg.factor = [f * trf.factor for f in reg.factor]
            else:
                reg.factor = reg.factor * trf.factor
        if isinstance(trf, struct.Linear):
            # Affine
            if isinstance(trf.init, str):
                trf.dat = io.transforms.loadf(trf.init, dtype=torch.float32,
                                              device=device)
            else:
                trf.dat = torch.zeros(trf.nb_prm(3), dtype=torch.float32,
                                      device=device)
            if trf.shift:
                # centre of the fixed mean space (voxels) -> world, negated
                shift = torch.as_tensor(shapef, **backend) * 0.5
                trf.shift = -spatial.affine_matvec(affinef, shift)
            else:
                trf.shift = 0.
        else:
            affine, shape = (affine0, shape0)
            trf.pyramid = list(sorted(trf.pyramid))
            max_level = max(trf.pyramid)
            factor = 2**(max_level-1)
            affine, shape = affine_resize(affine, shape, 1/factor)

            # FFD/Diffeo
            if isinstance(trf.init, str):
                f = io.volumes.map(trf.init)
                trf.dat = reshape3d(f.loadf(dtype=torch.float32, device=device),
                                    'last')
                # `reshape3d(..., 'last')` yields (*spatial, C): the channel
                # count is the LAST dimension. `len(trf.dat)` returned the
                # first spatial size instead.
                if trf.dat.shape[-1] != trf.dim:
                    raise ValueError('Field should have 3 channels')
                # NOTE(review): `trf.shape`/`trf.affine` are not set from the
                # loaded file in this function; confirm they are populated
                # beforehand (perhaps `trf.dat.shape`/`f.affine` were meant).
                factor = [int(s//g) for g, s in zip(trf.shape[:-1], shape)]
                trf.affine, trf.shape = affine_resize(trf.affine,
                                                      trf.shape[:-1], factor)
            else:
                trf.dat = torch.zeros([*shape, trf.dim], dtype=torch.float32,
                                      device=device)
                trf.affine = affine
                trf.shape = shape
def write_data(options):
    """Reslice every input file through the chain of transforms and save it.

    Steps:
    1. Velocity fields are loaded and exponentiated into displacement fields.
       Displacement fields are loaded; mm displacements are converted to
       voxels.
    2. If the first transform is linear, it is folded into each input file's
       orientation matrix and removed from the chain.
    3. If ``options.target`` is set, the sampling grid is built once in
       target space.
    4. Each input is loaded, optionally prefiltered, resampled on the grid
       and written to its formatted output name.

    Mutates ``options.transformations`` (and the transforms in it) and the
    ``affine`` of each entry in ``options.files``.
    """

    backend = dict(dtype=torch.float32, device=options.device)

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            trf.dat = spatial.exp(trf.dat[None], displacement=True,
                                  inverse=trf.inv)[0]
            # the inversion (if any) is baked into the exponentiated field
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                # NOTE(review): the full (homogeneous) affine is used on a
                # (..., 3, 1) vector; confirm `affine_lmdiv` applies only the
                # linear part here, as a displacement should not be translated.
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    if (options.transformations and
            isinstance(options.transformations[0], struct.Linear)):
        trf = options.transformations[0]
        for file in options.files:
            mat = file.affine.to(**backend)
            aff = trf.affine.to(**backend)
            # presumably inv(aff) @ mat -- TODO confirm affine_lmdiv semantics
            file.affine = spatial.affine_lmdiv(aff, mat)
        options.transformations = options.transformations[1:]

    def build_from_target(affine, shape):
        """Compose all transformations, starting from the final orientation"""
        # world coordinates of every voxel of the output grid
        grid = spatial.affine_grid(affine.to(**backend), shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    # resample the field to the output voxel size before
                    # inverting it
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None], factor,
                                                    affine=mat,
                                                    interpolation=trf.order)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.order
                # world -> field voxels, add sampled displacement, -> world
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If target is provided, we can build most of the transform once
    #    and just multiply it with a input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target.affine, options.target.shape)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt_pull = dict(interpolation=options.interpolation,
                    bound=options.bound,
                    extrapolate=options.extrapolate)
    opt_coeff = dict(interpolation=options.interpolation,
                     bound=options.bound,
                     dim=3,
                     inplace=True)
    output = py.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing: {file.fname}\n'
              f' -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=options.interpolation > 0,
                               **backend)
        # (*spatial, C) -> (C, *spatial)
        dat = dat.reshape([*file.shape, file.channels])
        dat = utils.movedim(dat, -1, 0)

        if not options.target:
            # no common target: output grid is the input's own grid,
            # optionally resized to `options.voxel_size`
            oaffine = file.affine
            oshape = file.shape
            if options.voxel_size:
                ovx = utils.make_vector(options.voxel_size, 3,
                                        dtype=oaffine.dtype)
                factor = spatial.voxel_size(oaffine) / ovx
                oaffine, oshape = spatial.affine_resize(oaffine, oshape,
                                                        factor=factor,
                                                        anchor='f')
            grid = build_from_target(oaffine, oshape)

        # world -> input voxel coordinates, then sample
        mat = file.affine.to(**backend)
        imat = spatial.affine_inv(mat)
        if options.prefilter:
            dat = spatial.spline_coeff_nd(dat, **opt_coeff)
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt_pull)
        dat = utils.movedim(dat, 0, -1)

        io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)
def write_outputs(z, prm, options):
    """Write segmentation outputs in native, MNI and warped space.

    Parameters
    ----------
    z : tensor
        Class responsibilities with classes along dim 0 (labels are obtained
        with ``z.argmax(0)``).
    prm : dict
        Fitted parameters. Reads ``'bias'`` (when ``options.bias``),
        ``'warp'`` (when ``options.warp``) and ``'affine'`` (for MNI/warped
        outputs).
    options : object
        Output flags (``*_nat``, ``*_mni``, ``*_wrp``, ``all_*``), inputs and
        formatting options. A flag holding a string is used as the filename
        template; otherwise a default template is used.

    When several input channels are given, bias / nobias outputs are written
    as one file per channel, each named from its own input file.
    """
    # prepare filenames
    ref_native = options.input[0]
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    format_dict = get_format_dict(ref_native, options.output)

    # move channels to back
    backend = utils.backend(z)
    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp
            or options.all_nat or options.all_mni or options.all_wrp):
        dat, _, _ = get_data(options.input, options.mask, None, 3, **backend)

    def _log(tag, fname):
        if options.verbose > 0:
            print(f'{tag} ->', fname)

    def _save_per_input(vol, template, tag, **kwargs):
        # Save a channels-first volume as one file (single input) or one file
        # per input channel. The template is formatted per channel into a
        # NEW name so that it is not overwritten by the first channel's name.
        if len(options.input) == 1:
            fname = template.format(**format_dict)
            _log(tag, fname)
            io.savef(torch.movedim(vol, 0, -1), fname, like=ref_native,
                     **kwargs)
        else:
            for c, (vol1, ref1) in enumerate(zip(vol, options.input)):
                fname1 = template.format(**get_format_dict(ref1, options.output))
                _log(f'{tag}.{c+1}', fname1)
                io.savef(vol1, fname1, like=ref1, **kwargs)

    # --- native space -------------------------------------------------

    if options.prob_nat or options.all_nat:
        fname = options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}'
        fname = fname.format(**format_dict)
        _log('prob.nat', fname)
        io.savef(torch.movedim(z, 0, -1), fname, like=ref_native,
                 dtype='float32')

    if options.labels_nat or options.all_nat:
        fname = options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}'
        fname = fname.format(**format_dict)
        _log('labels.nat', fname)
        io.save(z.argmax(0), fname, like=ref_native, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        _save_per_input(prm['bias'],
                        options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}',
                        'bias.nat', dtype='float32')

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        # per-channel loop previously iterated `bias` instead of `nobias`
        _save_per_input(nobias,
                        options.nobias_nat or '{dir}{sep}{base}.nobias.nat{ext}',
                        'nobias.nat')
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        warp = prm['warp']
        fname = options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}'
        fname = fname.format(**format_dict)
        _log('warp.nat', fname)
        io.savef(warp, fname, like=ref_native, dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(
            mni_affine, mni_shape, scl, anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        # `all_mni` previously computed z_mni but never saved it
        if options.prob_mni or options.all_mni:
            fname = options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}'
            fname = fname.format(**format_dict)
            _log('prob.mni', fname)
            io.savef(torch.movedim(z_mni, 0, -1), fname, like=ref_native,
                     affine=mni_affine, dtype='float32')
        if options.labels_mni or options.all_mni:
            fname = options.labels_mni or '{dir}{sep}{base}.labels.mni{ext}'
            fname = fname.format(**format_dict)
            _log('labels.mni', fname)
            io.save(z_mni.argmax(0), fname, like=ref_native,
                    affine=mni_affine, dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni
                         or options.all_mni):
        bias = spatial.reslice(prm['bias'], dat_affine, mni_affine, mni_shape,
                               interpolation=3, prefilter=False, bound='dct2')

        if options.bias_mni or options.all_mni:
            _save_per_input(bias,
                            options.bias_mni or '{dir}{sep}{base}.bias.mni{ext}',
                            'bias.mni', affine=mni_affine, dtype='float32')

        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            # previously named from `options.bias_mni`
            _save_per_input(nobias,
                            options.nobias_mni
                            or '{dir}{sep}{base}.nobias.mni{ext}',
                            'nobias.mni', affine=mni_affine)
            del nobias

        del bias

    need_iwarp = (options.warp_mni or options.prob_wrp or options.labels_wrp or
                  options.bias_wrp or options.nobias_wrp or options.all_mni or
                  options.all_wrp)
    need_iwarp = need_iwarp and options.warp
    if not need_iwarp:
        return

    # inverse warp, resampled to MNI space and expressed in MNI voxels
    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp, dat_affine, mni_affine, mni_shape,
                            interpolation=2, bound='dft', extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if (options.warp_mni or options.all_mni) and options.warp:
        fname = options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}'
        fname = fname.format(**format_dict)
        _log('warp.mni', fname)
        io.savef(iwarp, fname, like=ref_native, affine=mni_affine,
                 dtype='float32')

    # --- Warped space -------------------------------------------------
    # displacement -> transformation, then MNI voxels -> native voxels
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_wrp = spatial.grid_pull(z, iwarp)
        # previously tested / named from the *_mni options
        if options.prob_wrp or options.all_wrp:
            fname = options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}'
            fname = fname.format(**format_dict)
            _log('prob.wrp', fname)
            io.savef(torch.movedim(z_wrp, 0, -1), fname, like=ref_native,
                     affine=mni_affine, dtype='float32')
        if options.labels_wrp or options.all_wrp:
            fname = options.labels_wrp or '{dir}{sep}{base}.labels.wrp{ext}'
            fname = fname.format(**format_dict)
            _log('labels.wrp', fname)
            io.save(z_wrp.argmax(0), fname, like=ref_native,
                    affine=mni_affine, dtype='int16')
        del z_wrp

    if options.bias and (options.bias_wrp or options.nobias_wrp
                         or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'], iwarp,
                                 interpolation=3, prefilter=False, bound='dct2')
        if options.bias_wrp or options.all_wrp:
            _save_per_input(bias,
                            options.bias_wrp or '{dir}{sep}{base}.bias.wrp{ext}',
                            'bias.wrp', affine=mni_affine, dtype='float32')

        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            _save_per_input(nobias,
                            options.nobias_wrp
                            or '{dir}{sep}{base}.nobias.wrp{ext}',
                            'nobias.wrp', affine=mni_affine)
            del nobias

        del bias