Example #1
0
def write_transforms(options):
    """Write transformations (affine and nonlin) on disk.

    The last `struct.NonLinear` entry of `options.transformations` is
    written as a volume (`io.volumes.savef`); the last other entry is
    treated as affine, exponentiated from its Lie parameters and written
    as a matrix (`io.transforms.savef`).
    """
    affine = nonlin = None
    for trf in options.transformations:
        if isinstance(trf, struct.NonLinear):
            nonlin = trf
        else:
            affine = trf

    if affine:
        mat = linalg.expm(affine.dat, affine.basis)
        if torch.is_tensor(affine.shift):
            # fold the centre-of-rotation shift into the translation part
            opt = dict(dtype=mat.dtype, device=mat.device)
            shift = affine.shift.to(**opt)
            identity = torch.eye(3, **opt)
            mat[:-1, -1] += torch.matmul(mat[:-1, :-1] - identity, shift)
        io.transforms.savef(mat.cpu(), affine.output, type=2)

    if nonlin:
        field_affine, field_shape = nonlin.affine, nonlin.shape
        if isinstance(nonlin, struct.FFD):
            # rescale the orientation matrix by the ratio between `shape`
            # and the spatial shape of the parameter grid (dat.shape[:-1])
            grid_shape = nonlin.dat.shape[:-1]
            factor = [sz / gz for sz, gz in zip(field_shape, grid_shape)]
            field_affine, _ = spatial.affine_resize(field_affine,
                                                    field_shape, factor)
        io.volumes.savef(nonlin.dat.cpu(), nonlin.output,
                         affine=field_affine.cpu())
Example #2
0
def read_info(options):
    """Load affine transforms and space info of other volumes.

    Each filename in `options.files` (and `options.target`, when set) is
    replaced by a `struct.FileWithInfo` describing its path components,
    data type, shape, number of channels and orientation matrix. Each
    transformation gets its `affine` (and, for non-linear ones, `shape`)
    read from `trf.file`. If `options.voxel_size` is set, the target
    space is resized to that voxel size.
    """
    def read_file(fname):
        info = struct.FileWithInfo()
        info.fname = fname
        info.dir = os.path.dirname(fname) or '.'
        base, ext = os.path.splitext(os.path.basename(fname))
        if ext in ('.gz', '.bz2'):
            # compressed file: keep the double extension (e.g. '.nii.gz')
            base, inner_ext = os.path.splitext(base)
            ext = inner_ext + ext
        info.base, info.ext = base, ext
        mapped = io.volumes.map(fname)
        info.float = nitype(mapped.dtype).is_floating_point
        full_shape = squeeze_to_nd(mapped.shape, dim=3, channels=1)
        info.channels = full_shape[-1]
        info.shape = full_shape[:3]
        info.affine = mapped.affine.float()
        return info

    def read_affine(fname):
        return squeeze_to_nd(io.transforms.loadf(fname).float(), 0, 2)

    def read_field(fname):
        mapped = io.volumes.map(fname)
        return mapped.affine.float(), mapped.shape[:3]

    options.files = [read_file(fname) for fname in options.files]

    for trf in options.transformations:
        if isinstance(trf, struct.Linear):
            trf.affine = read_affine(trf.file)
        else:
            trf.affine, trf.shape = read_field(trf.file)

    if not options.target:
        return
    target = read_file(options.target)
    options.target = target
    if options.voxel_size:
        voxel_size = utils.make_vector(options.voxel_size, 3,
                                       dtype=target.affine.dtype)
        options.voxel_size = voxel_size
        factor = spatial.voxel_size(target.affine) / voxel_size
        target.affine, target.shape = spatial.affine_resize(
            target.affine, target.shape, factor=factor, anchor='f')
Example #3
0
    def resize(self):
        """Resample all maps (and `rls`, if present) to the current level.

        The working space is `affine0`/`shape0` resized by a factor
        ``1 / 2 ** (level - 1)``. Also updates `lam_scale` (ratio of the
        voxel volume at this level to the voxel volume of `affine0`),
        `maps.affine` and, when `rls` is set, `nll['rls']`.
        """
        factor = 1 / (2 ** (self.level - 1))
        affine, shape = spatial.affine_resize(self.affine0, self.shape0,
                                              factor)

        # ratio of voxel volumes between this level and the reference space
        vol0 = spatial.voxel_size(self.affine0).prod()
        self.lam_scale = spatial.voxel_size(affine).prod() / vol0

        # `qmap` rather than `map` so the builtin is not shadowed
        for qmap in self.maps:
            qmap.volume = spatial.resize(qmap.volume[None, None, ...],
                                         shape=shape)[0, 0]
            qmap.affine = affine
        self.maps.affine = affine

        if self.rls is not None:
            if self.rls.dim() == len(shape):
                # purely spatial volume: add batch and channel dimensions
                self.rls = spatial.resize(self.rls[None, None],
                                          shape=shape)[0, 0]
            else:
                # already has a leading channel dimension: add batch only
                self.rls = spatial.resize(self.rls[None], shape=shape)[0]
            self.nll['rls'] = self.rls.reciprocal().sum(dtype=torch.double)
Example #4
0
def _get_level(level, aff0, shape0):
    """Return the (affine, shape) of resolution `level`.

    Level 1 uses a resize factor of 1; each subsequent level halves it.
    """
    factor = 1 / (2 ** (level - 1))
    return spatial.affine_resize(aff0, shape0, factor)
Example #5
0
def _get_level(level, aff0, shape0):
    """Get shape and affine of a given resolution level.

    The reference space (`aff0`, `shape0`) is resized by a factor of
    ``1 / 2 ** (level - 1)``, so level 1 is the reference resolution.
    """
    downscale = 2 ** (level - 1)
    resized = spatial.affine_resize(aff0, shape0, 1 / downscale)
    return resized
Example #6
0
def load_transforms(s):
    """Initialize transforms.

    Computes the mean space of all (non-excluded) losses, scales each
    regularizer's factor by its transformation's factor, and initializes
    the parameters (`dat`) of every transformation, either from file
    (when `trf.init` is a path) or with zeros.

    Parameters
    ----------
    s : struct
        Registration options. Reads `device`, `losses`, `transformations`,
        `pad` and `pad_unit`; the transformations are modified in place.

    Raises
    ------
    ValueError
        If an initial field does not have `trf.dim` channels, or has too
        many channel dimensions.
    """
    device = torch.device(s.device)

    def reshape3d(dat, channels=None, dim=3):
        """Reshape as (*spatial) or (C, *spatial) or (*spatial, C).

        Singleton dimensions beyond the first `dim` ones are squeezed.
        `channels` should be in ('first', 'last', None).
        """
        while len(dat.shape) > dim:
            if dat.shape[-1] == 1:
                dat = dat[..., 0]
            elif dat.shape[dim] == 1:
                # drop the first non-spatial dimension, for any `dim`
                dat = dat[(slice(None),) * dim + (0,)]
            else:
                break
        if len(dat.shape) > dim + bool(channels):
            raise ValueError('Too many channel dimensions')
        if channels and len(dat.shape) == dim:
            dat = dat[..., None]
        if channels == 'first':
            dat = utils.movedim(dat, -1, 0)
        return dat

    # compute mean space
    #   it is used to define the space of the nonlinear transform, but
    #   also to shift the center of rotation of the linear transform.
    all_affines = []
    all_shapes = []
    all_affines_fixed = []
    all_shapes_fixed = []
    for loss in s.losses:
        if isinstance(loss, struct.NoLoss):
            continue
        if getattr(loss, 'exclude', False):
            continue
        all_shapes_fixed.append(loss.fixed.shape)
        all_affines_fixed.append(loss.fixed.affine)
        all_shapes.append(loss.fixed.shape)
        all_affines.append(loss.fixed.affine)
        all_shapes.append(loss.moving.shape)
        all_affines.append(loss.moving.affine)
    affine0, shape0 = mean_space(all_affines, all_shapes,
                                 pad=s.pad, pad_unit=s.pad_unit)
    affinef, shapef = mean_space(all_affines_fixed, all_shapes_fixed,
                                 pad=s.pad, pad_unit=s.pad_unit)
    backend = dict(dtype=affine0.dtype, device=affine0.device)

    for trf in s.transformations:
        # scale the regularization factors by the transformation's factor
        for reg in trf.losses:
            if isinstance(reg.factor, (list, tuple)):
                reg.factor = [f * trf.factor for f in reg.factor]
            else:
                reg.factor = reg.factor * trf.factor
        if isinstance(trf, struct.Linear):
            # Affine
            if isinstance(trf.init, str):
                trf.dat = io.transforms.loadf(trf.init, dtype=torch.float32,
                                              device=device)
            else:
                trf.dat = torch.zeros(trf.nb_prm(3), dtype=torch.float32,
                                      device=device)
            if trf.shift:
                # center of rotation: image of the center of the fixed
                # mean space through its affine
                shift = torch.as_tensor(shapef, **backend) * 0.5
                trf.shift = -spatial.affine_matvec(affinef, shift)
            else:
                trf.shift = 0.
        else:
            # the nonlinear parameters are defined at the resolution of
            # the largest pyramid level
            trf.pyramid = sorted(trf.pyramid)
            max_level = max(trf.pyramid)
            factor = 2**(max_level-1)
            affine, shape = affine_resize(affine0, shape0, 1/factor)

            # FFD/Diffeo
            if isinstance(trf.init, str):
                f = io.volumes.map(trf.init)
                trf.dat = reshape3d(f.loadf(dtype=torch.float32, device=device),
                                    'last')
                # reshape3d(..., 'last') returns (*spatial, C): the number of
                # components is the size of the last dimension, not len(dat)
                if trf.dat.shape[-1] != trf.dim:
                    raise ValueError(f'Field should have {trf.dim} channels')
                # NOTE(review): `trf.shape`/`trf.affine` are not set in this
                # function before this point; confirm they are initialized
                # upstream (or whether `f.affine`/`trf.dat.shape` was meant).
                factor = [int(sz // gz)
                          for gz, sz in zip(trf.shape[:-1], shape)]
                trf.affine, trf.shape = affine_resize(trf.affine,
                                                      trf.shape[:-1], factor)
            else:
                trf.dat = torch.zeros([*shape, trf.dim], dtype=torch.float32,
                                      device=device)
                trf.affine = affine
                trf.shape = shape
Example #7
0
def write_data(options):
    """Reslice each input file through the chain of transformations.

    Steps (see the numbered comments below):

    1) velocity fields are integrated into displacement fields, and
       displacements expressed in mm are converted to voxel units;
    2) if the first transform is linear, it is folded into the affine of
       every input file and removed from the chain;
    3) a sampling grid of world coordinates is built by composing all
       remaining transforms, starting from the output space
       (`options.target`, or else each input's own space, optionally
       resized to `options.voxel_size`);
    4) each input is resampled on that grid and written to disk.

    Parameters
    ----------
    options : struct
        Parsed options. Reads `transformations`, `files`, `target`,
        `output`, `voxel_size`, `interpolation`, `bound`, `extrapolate`,
        `prefilter` and `device`. Rebinds `options.transformations`
        (step 2) and mutates the transforms and each file's `affine`.
    """

    backend = dict(dtype=torch.float32, device=options.device)

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            # load the field reshaped by squeeze_to_nd(shape, 3, 1)
            # (presumably 3 spatial dims + 1 channel dim), then integrate it
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            trf.dat = spatial.exp(trf.dat[None],
                                  displacement=True,
                                  inverse=trf.inv)[0]
            # inversion was already applied by `spatial.exp`, so the
            # resulting displacement is used as is (interpolation order 1)
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                # NOTE(review): this solves with the full orientation matrix;
                # a displacement (difference of positions) is usually
                # converted with the linear part only -- confirm how
                # `affine_lmdiv` treats a (..., 3, 1) right-hand side.
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    if (options.transformations
            and isinstance(options.transformations[0], struct.Linear)):
        trf = options.transformations[0]
        for file in options.files:
            mat = file.affine.to(**backend)
            aff = trf.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(aff, mat)
        options.transformations = options.transformations[1:]

    def build_from_target(affine, shape):
        """Compose all transformations, starting from the final orientation"""
        # NOTE: reads `options.transformations` at call time, i.e. after
        # step 2 possibly removed the leading linear transform.
        grid = spatial.affine_grid(affine.to(**backend), shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    # resample the displacement by the ratio of voxel sizes
                    # (field / output), then invert it
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=mat,
                                                    interpolation=trf.order)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.order
                # map the grid into the field's voxel space, add the sampled
                # displacement, and map it back with the field's affine
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If target is provided, we can build most of the transform once
    #    and just multiply it with a input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target.affine, options.target.shape)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt_pull = dict(interpolation=options.interpolation,
                    bound=options.bound,
                    extrapolate=options.extrapolate)
    opt_coeff = dict(interpolation=options.interpolation,
                     bound=options.bound,
                     dim=3,
                     inplace=True)
    output = py.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing:   {file.fname}\n' f'          -> {ofname}')
        dat = io.volumes.loadf(file.fname,
                               rand=options.interpolation > 0,
                               **backend)
        # (*spatial, C) -> (C, *spatial)
        dat = dat.reshape([*file.shape, file.channels])
        dat = utils.movedim(dat, -1, 0)

        if not options.target:
            # no target: output in the input's own space, optionally
            # resized to the requested voxel size
            oaffine = file.affine
            oshape = file.shape
            if options.voxel_size:
                ovx = utils.make_vector(options.voxel_size,
                                        3,
                                        dtype=oaffine.dtype)
                factor = spatial.voxel_size(oaffine) / ovx
                oaffine, oshape = spatial.affine_resize(oaffine,
                                                        oshape,
                                                        factor=factor,
                                                        anchor='f')
            grid = build_from_target(oaffine, oshape)
        # bring the world-space grid into this input's voxel space
        mat = file.affine.to(**backend)
        imat = spatial.affine_inv(mat)
        if options.prefilter:
            dat = spatial.spline_coeff_nd(dat, **opt_coeff)
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt_pull)
        dat = utils.movedim(dat, 0, -1)

        io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)
Example #8
0
def write_outputs(z, prm, options):
    """Write segmentation outputs to disk.

    Depending on `options`, writes probabilities (`z`), hard labels
    (`z.argmax(0)`), bias fields, bias-corrected images and displacement
    fields in up to three spaces:

    * native space (the space of the first input);
    * MNI space (the template space), reached by affine reslicing;
    * warped space, reached through the inverse of the non-linear warp.

    Filename templates (e.g. '{dir}{sep}{base}.prob.nat{ext}') are
    formatted with `get_format_dict`. With several inputs, bias fields and
    bias-corrected images are written as one file per input.

    Parameters
    ----------
    z : tensor
        Channel-first probabilities (labels are taken as `z.argmax(0)`).
    prm : dict
        Fitted parameters; reads 'bias', 'warp' and 'affine'.
    options : struct
        Selects which outputs to write and their filename templates.
    """

    # prepare filenames
    ref_native = options.input[0]
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    format_dict = get_format_dict(ref_native, options.output)

    def log(tag, fname):
        if options.verbose > 0:
            print(f'{tag:<12} ->', fname)

    def save_one(dat, template, tag, save=io.savef, **kwargs):
        """Format `template` for the first input and write `dat` there."""
        fname = template.format(**format_dict)
        log(tag, fname)
        save(dat, fname, like=ref_native, **kwargs)

    def save_channels(dat, template, tag, **kwargs):
        """Write a channel-first `dat`: one multi-channel file if there is a
        single input, else one file per (input, channel) pair."""
        if len(options.input) == 1:
            save_one(torch.movedim(dat, 0, -1), template, tag, **kwargs)
            return
        for c, (dat1, ref1) in enumerate(zip(dat, options.input)):
            # format the untouched template for every input so that each
            # channel gets its own filename
            fname = template.format(**get_format_dict(ref1, options.output))
            log(f'{tag}.{c+1}', fname)
            io.savef(dat1, fname, like=ref1, **kwargs)

    # load the input images (only needed for bias-corrected outputs)
    backend = utils.backend(z)
    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp
            or options.all_nat or options.all_mni or options.all_wrp):
        dat, _, _ = get_data(options.input, options.mask, None, 3, **backend)

    # --- native space -------------------------------------------------

    if options.prob_nat or options.all_nat:
        save_one(torch.movedim(z, 0, -1),
                 options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}',
                 'prob.nat', dtype='float32')

    if options.labels_nat or options.all_nat:
        save_one(z.argmax(0),
                 options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}',
                 'labels.nat', save=io.save, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        save_channels(prm['bias'],
                      options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}',
                      'bias.nat', dtype='float32')

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        save_channels(nobias,
                      options.nobias_nat
                      or '{dir}{sep}{base}.nobias.nat{ext}',
                      'nobias.nat')
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        save_one(prm['warp'],
                 options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}',
                 'warp.nat', dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(mni_affine,
                                                      mni_shape,
                                                      scl,
                                                      anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        if options.prob_mni or options.all_mni:
            save_one(torch.movedim(z_mni, 0, -1),
                     options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}',
                     'prob.mni', affine=mni_affine, dtype='float32')
        if options.labels_mni or options.all_mni:
            save_one(z_mni.argmax(0),
                     options.labels_mni
                     or '{dir}{sep}{base}.labels.mni{ext}',
                     'labels.mni', save=io.save, affine=mni_affine,
                     dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni
                         or options.all_mni):
        bias = spatial.reslice(prm['bias'],
                               dat_affine,
                               mni_affine,
                               mni_shape,
                               interpolation=3,
                               prefilter=False,
                               bound='dct2')

        if options.bias_mni or options.all_mni:
            save_channels(bias,
                          options.bias_mni
                          or '{dir}{sep}{base}.bias.mni{ext}',
                          'bias.mni', affine=mni_affine, dtype='float32')

        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            save_channels(nobias,
                          options.nobias_mni
                          or '{dir}{sep}{base}.nobias.mni{ext}',
                          'nobias.mni', affine=mni_affine)
            del nobias

        del bias

    need_iwarp = (options.warp_mni or options.prob_wrp or options.labels_wrp
                  or options.bias_wrp or options.nobias_wrp or options.all_mni
                  or options.all_wrp)
    need_iwarp = need_iwarp and options.warp
    if not need_iwarp:
        return

    # inverse warp, resliced to MNI space and expressed in MNI voxels
    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp,
                            dat_affine,
                            mni_affine,
                            mni_shape,
                            interpolation=2,
                            bound='dft',
                            extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    # (options.warp is guaranteed to be set past the early return above)
    if options.warp_mni or options.all_mni:
        save_one(iwarp,
                 options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}',
                 'warp.mni', affine=mni_affine, dtype='float32')

    # --- Warped space -------------------------------------------------
    # turn the displacement into a sampling grid in native voxels
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_wrp = spatial.grid_pull(z, iwarp)
        if options.prob_wrp or options.all_wrp:
            save_one(torch.movedim(z_wrp, 0, -1),
                     options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}',
                     'prob.wrp', affine=mni_affine, dtype='float32')
        if options.labels_wrp or options.all_wrp:
            save_one(z_wrp.argmax(0),
                     options.labels_wrp
                     or '{dir}{sep}{base}.labels.wrp{ext}',
                     'labels.wrp', save=io.save, affine=mni_affine,
                     dtype='int16')
        del z_wrp

    if options.bias and (options.bias_wrp or options.nobias_wrp
                         or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'],
                                 iwarp,
                                 interpolation=3,
                                 prefilter=False,
                                 bound='dct2')
        if options.bias_wrp or options.all_wrp:
            save_channels(bias,
                          options.bias_wrp
                          or '{dir}{sep}{base}.bias.wrp{ext}',
                          'bias.wrp', affine=mni_affine, dtype='float32')

        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            save_channels(nobias,
                          options.nobias_wrp
                          or '{dir}{sep}{base}.nobias.wrp{ext}',
                          'nobias.wrp', affine=mni_affine)
            del nobias

        del bias