Beispiel #1
0
 def build_from_target(target):
     """Compose all transformations, starting from the final orientation.

     The starting grid is ``spatial.affine_grid`` of the target's affine and
     shape. ``options.transformations`` is then walked from last to first:

     * ``Linear`` transforms: the grid is multiplied by their affine.
     * any other transform: the grid is mapped through the inverse of the
       transform's affine, the transform's displacement is sampled at those
       coordinates (``helpers.pull_grid``) and added, and the result is
       mapped back through the affine. When ``transform.inv`` is set, the
       displacement is first resized by the ratio of the transform's voxel
       size to the target's voxel size, inverted with ``spatial.grid_inv``
       and sampled with interpolation order 1; otherwise ``transform.dat``
       is sampled with order ``transform.spline``.

     Uses ``options`` and ``backend`` from the enclosing scope.
     Returns the composed grid.
     """
     def _displacement(transform, affine):
         # Return (displacement, interpolation order, affine) for a
         # non-linear transform. Inverted transforms are resized, which
         # also yields an updated affine.
         if not transform.inv:
             return transform.dat, transform.spline, affine
         source_vx = spatial.voxel_size(affine)
         target_vx = spatial.voxel_size(target.affine.to(**backend))
         resized, affine = spatial.resize_grid(transform.dat[None],
                                               source_vx / target_vx,
                                               affine=affine,
                                               interpolation=transform.spline)
         return spatial.grid_inv(resized[0], type='disp'), 1, affine

     grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
     for transform in reversed(options.transformations):
         if isinstance(transform, Linear):
             grid = spatial.affine_matvec(transform.affine.to(**backend), grid)
             continue
         disp, order, affine = _displacement(transform,
                                             transform.affine.to(**backend))
         # inverse affine -> add sampled displacement -> affine
         grid = spatial.affine_matvec(spatial.affine_inv(affine), grid)
         grid += helpers.pull_grid(disp, grid, interpolation=order)
         grid = spatial.affine_matvec(affine, grid)
     return grid
Beispiel #2
0
 def iexp(self, v=None, jacobian=False, add_identity=False,
          cache_result=False, recompute=True):
     """Return the inverse of a displacement field.

     Parameters
     ----------
     v : tensor, optional
         Displacement to invert. Defaults to ``self.dat.dat``.
     jacobian : bool
         Also return the Jacobian of the inverse displacement, computed
         before any identity is added.
     add_identity : bool
         Pass the inverse through ``self.add_identity`` before returning it.
     cache_result : bool
         Store the inverse displacement (without identity) in
         ``self._icache``.
     recompute : bool
         When False and ``self._icache`` is set, reuse the cached inverse
         instead of calling ``spatial.grid_inv`` (``v`` is then unused).

     Returns
     -------
     igrid, or (igrid, ijac) when ``jacobian`` is True.
     """
     if v is None:
         v = self.dat.dat
     if not recompute and self._icache is not None:
         igrid = self._icache
     else:
         igrid = spatial.grid_inv(v, type='disp', **self.penalty)
     if cache_result:
         self._icache = igrid
     ijac = (spatial.grid_jacobian(igrid, type='displacement')
             if jacobian else None)
     if add_identity:
         igrid = self.add_identity(igrid)
     if jacobian:
         return igrid, ijac
     return igrid
Beispiel #3
0
 def exp2(self, v=None, jacobian=False, add_identity=False,
          cache_result=False, recompute=True):
     """Return a displacement field together with its inverse.

     The forward field is ``v`` itself (``self.dat.dat`` when ``v`` is
     None); the inverse comes from ``spatial.grid_inv`` or, when
     ``recompute`` is False and ``self._icache`` is set, from that cache.
     With ``cache_result`` the inverse (without identity) is stored in
     ``self._icache``. With ``jacobian`` the Jacobians of both fields are
     computed before ``add_identity`` is applied to each of them.

     Returns
     -------
     (grid, igrid), or (grid, igrid, jac, ijac) when ``jacobian`` is True.
     """
     fwd = self.dat.dat if v is None else v
     if not recompute and self._icache is not None:
         inv = self._icache
     else:
         inv = spatial.grid_inv(fwd, type='disp', **self.penalty)
     if cache_result:
         self._icache = inv
     if jacobian:
         fwd_jac = spatial.grid_jacobian(fwd, type='displacement')
         inv_jac = spatial.grid_jacobian(inv, type='displacement')
     if add_identity:
         fwd, inv = self.add_identity(fwd), self.add_identity(inv)
     if not jacobian:
         return fwd, inv
     return fwd, inv, fwd_jac, inv_jac
Beispiel #4
0
def write_data(options):
    """Write the moving and fixed images of every loss term.

    The first ``struct.Linear`` transform in ``options.transformations``
    provides the linear part: ``linalg.expm`` of its parameters in its basis,
    with ``trf.shift`` folded into the translation when it is a tensor.
    The first ``struct.FFD`` or ``struct.Diffeo`` transform provides the
    non-linear part as a displacement field. Its inverse displacement is
    computed only when some fixed image has to be written.

    For each loss term, the moving image is passed to ``update``/``reslice``
    with the forward displacement. The fixed image (when present) is passed
    with ``inv=True`` and the inverse displacement. All transform
    computations run on the CPU in ``torch.float``.
    """
    backend = dict(dtype=torch.float, device='cpu')

    # The inverse displacement is only needed if a fixed image is written.
    need_inv = any(loss.fixed and (loss.fixed.resliced or loss.fixed.updated)
                   for loss in options.losses)

    # affine matrix
    lin = None
    for trf in options.transformations:
        if isinstance(trf, struct.Linear):
            q = trf.dat.to(**backend)
            B = trf.basis.to(**backend)
            lin = linalg.expm(q, B)
            if torch.is_tensor(trf.shift):
                # include shift
                shift = trf.shift.to(**backend)
                eye = torch.eye(3, **backend)
                lin[:-1, -1] += torch.matmul(lin[:-1, :-1] - eye, shift)
            break

    # non-linear displacement field (and its inverse, if needed)
    disp = None
    inv_disp = None
    disp_affine = None
    for trf in options.transformations:
        if isinstance(trf, struct.FFD):
            disp = ffd_exp(trf.dat.to(**backend), trf.shape, returns='disp')
            if need_inv:
                # `disp` is a displacement, not a transformation grid:
                # invert it as such so the result is also a displacement
                # (matching the Diffeo branch below).
                inv_disp = spatial.grid_inv(disp, type='disp')
            disp_affine = trf.affine.to(**backend)
            break
        if isinstance(trf, struct.Diffeo):
            vel = trf.dat.to(**backend)
            if need_inv:
                inv_disp = spatial.exp(vel[None], displacement=True,
                                       inverse=True)[0]
            disp = spatial.exp(vel[None], displacement=True)[0]
            disp_affine = trf.affine.to(**backend)
            break

    # loop over image pairs
    for match in options.losses:

        moving = match.moving
        fixed = match.fixed
        prm = dict(interpolation=moving.interpolation,
                   bound=moving.bound,
                   extrapolate=moving.extrapolate,
                   device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=disp, affine=disp_affine)
        if moving.updated:
            update(moving, moving.updated, lin=lin, nonlin=nonlin, **prm)
        if moving.resliced:
            reslice(moving, moving.resliced, like=fixed, lin=lin,
                    nonlin=nonlin, **prm)
        if not fixed:
            continue
        prm = dict(interpolation=fixed.interpolation,
                   bound=fixed.bound,
                   extrapolate=fixed.extrapolate,
                   device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=inv_disp, affine=disp_affine)
        if fixed.updated:
            update(fixed, fixed.updated, inv=True, lin=lin, nonlin=nonlin,
                   **prm)
        if fixed.resliced:
            reslice(fixed, fixed.resliced, inv=True, like=moving, lin=lin,
                    nonlin=nonlin, **prm)
Beispiel #5
0
def _report_output(options, label, fname):
    """Print ``<label> -> <fname>`` when verbose output is enabled."""
    if options.verbose > 0:
        print(f'{label:<12} ->', fname)


def _savef_per_input(dat, template, label, format_dict, options, **kwargs):
    """Save a channel-first tensor, one file per input when there are several.

    With a single input, ``dat`` is written as one volume with channels moved
    to the last dimension, to ``template`` formatted with ``format_dict``.
    With several inputs, channel ``c`` of ``dat`` is written to ``template``
    formatted with input ``c``'s own format dict, using that input as the
    ``like=`` reference. Extra keyword arguments (e.g. ``affine``, ``dtype``)
    are forwarded to ``io.savef``.
    """
    if len(options.input) == 1:
        fname = template.format(**format_dict)
        _report_output(options, label, fname)
        io.savef(torch.movedim(dat, 0, -1), fname,
                 like=options.input[0], **kwargs)
        return
    for c, (dat1, ref1) in enumerate(zip(dat, options.input)):
        # Always format the *template*: reusing an already formatted name
        # would send every channel to the first input's filename.
        fname = template.format(**get_format_dict(ref1, options.output))
        _report_output(options, f'{label}.{c + 1}', fname)
        io.savef(dat1, fname, like=ref1, **kwargs)


def write_outputs(z, prm, options):
    """Write segmentation outputs in native, MNI and warped spaces.

    Parameters
    ----------
    z : tensor
        Channel-first class probabilities in native space (classes along
        dim 0, as used by ``movedim(z, 0, -1)`` and ``z.argmax(0)``).
    prm : dict
        Fitted parameters. Reads ``'bias'`` (channel-first, when
        ``options.bias``), ``'warp'`` (when ``options.warp``) and
        ``'affine'``.
    options : object
        Output templates (``*_nat``, ``*_mni``, ``*_wrp``), ``all_*``
        switches, ``input``, ``mask``, ``output``, ``tpm``, ``mni_vx``,
        ``bias``, ``warp`` and ``verbose``.

    Filename templates are formatted with ``get_format_dict`` (defaults use
    the ``dir``, ``sep``, ``base`` and ``ext`` fields). Nothing is written
    in MNI or warped space when ``options.tpm is False``.
    """
    # prepare filenames
    ref_native = options.input[0]
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    format_dict = get_format_dict(ref_native, options.output)

    backend = utils.backend(z)

    # The raw input data is only used by bias-corrected ("nobias") outputs,
    # all of which also require a bias field.
    if options.bias and (options.nobias_nat or options.nobias_mni
                         or options.nobias_wrp or options.all_nat
                         or options.all_mni or options.all_wrp):
        dat, _, _ = get_data(options.input, options.mask, None, 3,
                             **backend)

    # --- native space -------------------------------------------------

    if options.prob_nat or options.all_nat:
        fname = options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}'
        fname = fname.format(**format_dict)
        _report_output(options, 'prob.nat', fname)
        io.savef(torch.movedim(z, 0, -1),
                 fname,
                 like=ref_native,
                 dtype='float32')

    if options.labels_nat or options.all_nat:
        fname = options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}'
        fname = fname.format(**format_dict)
        _report_output(options, 'labels.nat', fname)
        io.save(z.argmax(0), fname, like=ref_native, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        _savef_per_input(prm['bias'],
                         options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}',
                         'bias.nat', format_dict, options, dtype='float32')

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        _savef_per_input(nobias,
                         options.nobias_nat
                         or '{dir}{sep}{base}.nobias.nat{ext}',
                         'nobias.nat', format_dict, options)
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        fname = options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}'
        fname = fname.format(**format_dict)
        _report_output(options, 'warp.nat', fname)
        io.savef(prm['warp'], fname, like=ref_native, dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    # compose the fitted affine with the native image's affine
    dat_affine = prm['affine'].to(**backend) @ dat_affine
    if options.mni_vx:
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(mni_affine,
                                                      mni_shape,
                                                      scl,
                                                      anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        if options.prob_mni or options.all_mni:
            fname = options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}'
            fname = fname.format(**format_dict)
            _report_output(options, 'prob.mni', fname)
            io.savef(torch.movedim(z_mni, 0, -1),
                     fname,
                     like=ref_native,
                     affine=mni_affine,
                     dtype='float32')
        if options.labels_mni or options.all_mni:
            fname = options.labels_mni or '{dir}{sep}{base}.labels.mni{ext}'
            fname = fname.format(**format_dict)
            _report_output(options, 'labels.mni', fname)
            io.save(z_mni.argmax(0),
                    fname,
                    like=ref_native,
                    affine=mni_affine,
                    dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni
                         or options.all_mni):
        bias = spatial.reslice(prm['bias'],
                               dat_affine,
                               mni_affine,
                               mni_shape,
                               interpolation=3,
                               prefilter=False,
                               bound='dct2')
        if options.bias_mni or options.all_mni:
            _savef_per_input(bias,
                             options.bias_mni
                             or '{dir}{sep}{base}.bias.mni{ext}',
                             'bias.mni', format_dict, options,
                             affine=mni_affine, dtype='float32')
        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            _savef_per_input(nobias,
                             options.nobias_mni
                             or '{dir}{sep}{base}.nobias.mni{ext}',
                             'nobias.mni', format_dict, options,
                             affine=mni_affine)
            del nobias
        del bias

    need_iwarp = options.warp and (
        options.warp_mni or options.prob_wrp or options.labels_wrp
        or options.bias_wrp or options.nobias_wrp or options.all_mni
        or options.all_wrp)
    if not need_iwarp:
        return

    # Invert the native-space displacement and resample it on the MNI grid.
    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp,
                            dat_affine,
                            mni_affine,
                            mni_shape,
                            interpolation=2,
                            bound='dft',
                            extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    # apply the linear part of (mni_affine^-1 @ dat_affine) to the vectors
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if options.warp_mni or options.all_mni:
        fname = options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}'
        fname = fname.format(**format_dict)
        _report_output(options, 'warp.mni', fname)
        io.savef(iwarp,
                 fname,
                 like=ref_native,
                 affine=mni_affine,
                 dtype='float32')

    # --- Warped space -------------------------------------------------
    # Build a sampling grid (identity + displacement, mapped through
    # dat_affine^-1 @ mni_affine) used to pull native data onto MNI space.
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_wrp = spatial.grid_pull(z, iwarp)
        if options.prob_wrp or options.all_wrp:
            fname = options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}'
            fname = fname.format(**format_dict)
            _report_output(options, 'prob.wrp', fname)
            io.savef(torch.movedim(z_wrp, 0, -1),
                     fname,
                     like=ref_native,
                     affine=mni_affine,
                     dtype='float32')
        if options.labels_wrp or options.all_wrp:
            fname = options.labels_wrp or '{dir}{sep}{base}.labels.wrp{ext}'
            fname = fname.format(**format_dict)
            _report_output(options, 'labels.wrp', fname)
            io.save(z_wrp.argmax(0),
                    fname,
                    like=ref_native,
                    affine=mni_affine,
                    dtype='int16')
        del z_wrp

    if options.bias and (options.bias_wrp or options.nobias_wrp
                         or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'],
                                 iwarp,
                                 interpolation=3,
                                 prefilter=False,
                                 bound='dct2')
        if options.bias_wrp or options.all_wrp:
            _savef_per_input(bias,
                             options.bias_wrp
                             or '{dir}{sep}{base}.bias.wrp{ext}',
                             'bias.wrp', format_dict, options,
                             affine=mni_affine, dtype='float32')
        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            _savef_per_input(nobias,
                             options.nobias_wrp
                             or '{dir}{sep}{base}.nobias.wrp{ext}',
                             'nobias.wrp', format_dict, options,
                             affine=mni_affine)
            del nobias
        del bias