Exemplo n.º 1
0
 def pull(q, vel):
     """Resample `source` through an affine-composed exponentiated velocity.

     `q` and `vel` are the only explicit inputs; `basis`, `target_aff`,
     `source_aff`, `source` and `pull_opt` come from the enclosing scope.

     Returns the output of `spatial.grid_pull(source, grid, **pull_opt)`.
     """
     # Dense field obtained by exponentiating the velocity.
     phi = spatial.exp(vel)
     # Matrix exponential of the parameters `q` expressed in `basis`.
     lin = core.linalg.expm(q, basis)
     # source_aff \ (lin @ target_aff): maps the field's coordinates into
     # the source lattice.
     vox2vox = spatial.affine_lmdiv(source_aff,
                                    spatial.affine_matmul(lin, target_aff))
     return spatial.grid_pull(source,
                              spatial.affine_matvec(vox2vox, phi),
                              **pull_opt)
Exemplo n.º 2
0
    def forward(self, velocity, **kwargs):
        """Exponentiate a velocity field.

        Parameters
        ----------
        velocity (tensor) : velocity field with shape (batch, *spatial, dim).
        **kwargs : all parameters of the module can be overridden at call time.
            The direction switches are accepted under either spelling:
            `fwd` / `forward` and `inv` / `inverse`.

        Returns
        -------
        forward (tensor, if `forward is True`) : forward displacement
            (if `displacement is True`) or transformation (if `displacement
            is False`) field, with shape (batch, *spatial, dim)
        inverse (tensor, if `inverse is True`) : inverse displacement
            (if `displacement is True`) or transformation (if `displacement
            is False`) field, with shape (batch, *spatial, dim)

        If both fields are requested, the list `[forward, inverse]` is
        returned; if only one is, that tensor is returned; if none, None.

        """
        # The default used to be `self.forward`, i.e. this very bound method,
        # which is always truthy. Use a real boolean attribute when the module
        # defines one, and otherwise fall back to True, which is what the
        # old default effectively was.
        default_fwd = getattr(self, 'fwd', True)
        fwd = kwargs.get('fwd', kwargs.get('forward', default_fwd))
        inv = kwargs.get('inverse', kwargs.get('inv', self.inv))
        opt = {
            'steps': kwargs.get('steps', self.steps),
            'interpolation': kwargs.get('interpolation', self.interpolation),
            'bound': kwargs.get('bound', self.bound),
            'displacement': kwargs.get('displacement', self.displacement),
            'inplace': False,  # never overwrite the caller's velocity
        }

        output = []
        if fwd:
            output.append(spatial.exp(velocity, inverse=False, **opt))
        if inv:
            output.append(spatial.exp(velocity, inverse=True, **opt))

        if not output:
            return None
        if len(output) == 1:
            return output[0]
        return output
Exemplo n.º 3
0
    def forward(self, velocity, fwd=None, inv=None):
        """Exponentiate a stationary velocity field.

        Parameters
        ----------
        velocity : (batch, *spatial, dim) tensor
            Stationary velocity field.
        fwd : bool, default=self.fwd
            Whether to compute the forward field.
        inv : bool, default=self.inv
            Whether to compute the inverse field.

        Returns
        -------
        forward : (batch, *spatial, dim) tensor, if `fwd is True`
            Forward displacement (if `displacement is True`) or
            transformation (if `displacement is False`) field.
        inverse : (batch, *spatial, dim) tensor, if `inv is True`
            Inverse displacement (if `displacement is True`) or
            transformation (if `displacement is False`) field.

        When both are requested they are returned as a list
        `[forward, inverse]`; when neither is, None is returned.

        """
        if fwd is None:
            fwd = self.fwd
        if inv is None:
            inv = self.inv
        exp_options = {
            'steps': self.steps,
            'interpolation': self.interpolation,
            'bound': self.bound,
            'displacement': self.displacement,
            'anagrad': self.anagrad,
        }

        # (inverse flag, requested?) pairs, in output order.
        fields = [spatial.exp(velocity, inverse=is_inverse, **exp_options)
                  for is_inverse, wanted in ((False, fwd), (True, inv))
                  if wanted]

        if not fields:
            return None
        if len(fields) == 1:
            return fields[0]
        return fields
Exemplo n.º 4
0
def write_data(options):
    """Compose `options.transformations` and write the result to disk.

    If `options.target` is set, all transformations are composed on the
    target lattice. The result is turned into a displacement field by
    subtracting the identity: in voxels when `options.output_unit` starts
    with 'v', in mm otherwise. It is saved as '.nii.gz'. Without a target,
    a single transformation is expected and its `affine` is saved as '.lta'.

    Parameters
    ----------
    options : object
        Parsed command-line options (project-defined structure).

    Raises
    ------
    RuntimeError
        If several transformations are given but no target.
    """
    backend = dict(dtype=torch.float32, device=options.device)

    def load_field(trf):
        """Read the dense field `trf.file` into `trf.dat` on a 3D lattice."""
        f = io.volumes.map(trf.file)
        trf.affine = f.affine
        trf.shape = squeeze_to_nd(f.shape, 3, 1)
        trf.dat = f.fdata(**backend).reshape(trf.shape)
        trf.shape = trf.shape[:3]

    # Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, Velocity):
            load_field(trf)
            trf.dat = spatial.exp(trf.dat[None],
                                  displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            # Use order-1 interpolation for the exponentiated field.
            # `build_from_target` reads `trf.spline`, so setting only
            # `trf.order` (as before) had no effect there.
            trf.order = 1
            trf.spline = 1
        elif isinstance(trf, Displacement):
            load_field(trf)
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    # Resample the field to the target voxel size before
                    # numerically inverting it.
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=mat,
                                                    interpolation=trf.spline)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.spline
                # mm -> field voxels, add displacement, field voxels -> mm
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    if options.target:
        # If target is provided, we build a dense transformation field
        grid = build_from_target(options.target)
        oaffine = options.target.affine
        if options.output_unit[0] == 'v':
            grid = spatial.affine_matvec(spatial.affine_inv(oaffine), grid)
            grid = grid - spatial.identity_grid(grid.shape[:-1],
                                                **utils.backend(grid))
        else:
            grid = grid - spatial.affine_grid(
                oaffine.to(**utils.backend(grid)), grid.shape[:-1])
        io.volumes.savef(grid,
                         options.output.format(ext='.nii.gz'),
                         affine=oaffine)
    else:
        if len(options.transformations) > 1:
            raise RuntimeError('Something weird happened: '
                               'multiple transforms and no target')
        io.transforms.savef(options.transformations[0].affine,
                            options.output.format(ext='.lta'))
Exemplo n.º 5
0
def vexp(inp,
         type='displacement',
         unit='voxel',
         inverse=False,
         bound='dft',
         steps=8,
         device=None,
         output=None):
    """Exponentiate a stationary velocity field.

    Parameters
    ----------
    inp : str or tensor or (tensor, tensor)
        Either a path to a volume file, a tensor of volume data (a default
        orientation matrix is then used), or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix. Input tensors are never modified.
    type : {'displacement', 'transformation'}, default='displacement'
        Whether to return a displacement field (coord-to-shift) or a
        transformation field (coord-to-coord).
    unit : {'voxel', 'mm'}, default='voxel'
        Whether to return displacement/coordinates in voxel or in mm.
        If mm, the input orientation matrix is used to convert voxels to mm:
        coordinates are mapped through the full matrix, displacements
        through its linear part only (the translation cancels out).
    inverse : bool, default=False
        Whether to return the inverse field.
    bound : str, default='dft'
        Boundary conditions.
    steps : int, default=8
        Number of scaling and squaring steps.
    device : str, optional
        Device to use.
    output : str, optional
        Output filename(s).
        If the input is not a path, the unstacked data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.vexp{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the output tensor and orientation matrix are returned.

    """
    dirname = ''
    base = ''
    ext = ''
    fname = None

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(device=device), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.vexp{ext}'
        dirname, base, ext = py.fileparts(fname)
    elif torch.is_tensor(inp):
        inp = (inp.clone(), spatial.affine_default(shape=inp.shape[:3]))
    else:
        # (dat, affine) tuple: copy the data so that the in-place
        # exponentiation below does not overwrite the caller's tensor.
        inp = (inp[0].clone(), inp[1])
    dat, aff = inp
    dat = dat.to(device=device)
    aff = aff.to(device=device)

    # exponentiate
    is_disp = type.lower()[0] == 'd'
    dat = spatial.exp(dat[None],
                      inverse=inverse,
                      steps=steps,
                      bound=bound,
                      inplace=True,
                      displacement=is_disp)[0]
    if unit == 'mm':
        if is_disp:
            # A displacement is a difference of two coordinates, so only
            # the linear part of the voxel-to-mm matrix applies.
            lin = aff.clone()
            lin[:-1, -1] = 0
            dat = spatial.affine_matvec(lin, dat)
        else:
            dat = spatial.affine_matvec(aff, dat)

    if output:
        if is_file:
            output = output.format(dir=dirname or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff.cpu())
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff.cpu())

    if is_file:
        return output
    else:
        return dat, aff
Exemplo n.º 6
0
def smart_exp(vel, **kwargs):
    """Call `spatial.exp(vel, **kwargs)`, passing a `None` velocity through.

    Returns None when `vel` is None, otherwise whatever `spatial.exp`
    returns.
    """
    if vel is None:
        return None
    return spatial.exp(vel, **kwargs)
Exemplo n.º 7
0
def write_data(options):
    """Apply the estimated transformations to the images and write them.

    For each loss/image pair in `options.losses`:
    - the moving image is updated and/or resliced (like the fixed image)
      using the linear part `lin` and the forward nonlinear displacement;
    - the fixed image, if any, is updated and/or resliced (like the moving
      image) with `inv=True` and the inverse nonlinear displacement.

    The inverse nonlinear displacement is only computed when some fixed
    image is actually written. All computations run on the CPU.

    Parameters
    ----------
    options : object
        Parsed options (project-defined structure).
    """
    backend = dict(dtype=torch.float, device='cpu')

    need_inv = any(loss.fixed and (loss.fixed.resliced or loss.fixed.updated)
                   for loss in options.losses)

    # affine matrix (first Linear transformation only)
    lin = None
    for trf in options.transformations:
        if isinstance(trf, struct.Linear):
            q = trf.dat.to(**backend)
            B = trf.basis.to(**backend)
            lin = linalg.expm(q, B)
            if torch.is_tensor(trf.shift):
                # include shift; spatial dim inferred from the square
                # (dim+1, dim+1) homogeneous matrix instead of assuming 3
                shift = trf.shift.to(**backend)
                eye = torch.eye(lin.shape[-1] - 1, **backend)
                lin[:-1, -1] += torch.matmul(lin[:-1, :-1] - eye, shift)
            break

    # non-linear displacement field (first FFD or Diffeo only)
    d = None
    inv_disp = None
    d_aff = None
    for trf in options.transformations:
        if isinstance(trf, struct.FFD):
            d = trf.dat.to(**backend)
            d = ffd_exp(d, trf.shape, returns='disp')
            if need_inv:
                inv_disp = grid_inv(d)
            d_aff = trf.affine.to(**backend)
            break
        elif isinstance(trf, struct.Diffeo):
            d = trf.dat.to(**backend)
            if need_inv:
                inv_disp = spatial.exp(d[None], displacement=True,
                                       inverse=True)[0]
            d = spatial.exp(d[None], displacement=True)[0]
            d_aff = trf.affine.to(**backend)
            break

    # loop over image pairs
    for match in options.losses:

        moving = match.moving
        fixed = match.fixed
        prm = dict(interpolation=moving.interpolation,
                   bound=moving.bound,
                   extrapolate=moving.extrapolate,
                   device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=d, affine=d_aff)
        if moving.updated:
            update(moving, moving.updated, lin=lin, nonlin=nonlin, **prm)
        if moving.resliced:
            reslice(moving, moving.resliced, like=fixed, lin=lin,
                    nonlin=nonlin, **prm)
        if not fixed:
            continue
        prm = dict(interpolation=fixed.interpolation,
                   bound=fixed.bound,
                   extrapolate=fixed.extrapolate,
                   device='cpu',
                   verbose=options.verbose)
        nonlin = dict(disp=inv_disp, affine=d_aff)
        if fixed.updated:
            update(fixed, fixed.updated, inv=True, lin=lin, nonlin=nonlin,
                   **prm)
        if fixed.resliced:
            reslice(fixed, fixed.resliced, inv=True, like=moving, lin=lin,
                    nonlin=nonlin, **prm)
Exemplo n.º 8
0
    def forward():
        """Forward pass up to the loss.

        Builds the current transformation from `options.transformations`:
        the first `struct.Linear` gives an affine matrix `A`, and the first
        free `struct.FFD` / `struct.Diffeo` gives a displacement field `d`.
        Regularization terms of those transformations are added to the
        loss. Each moving image is then warped onto its fixed image at every
        pyramid level, and the similarity terms are averaged over levels.

        `options`, `backend` and the helpers used below come from the
        enclosing scope.

        Returns
        -------
        loss
            Accumulated loss (regularization + similarity terms).
        """

        loss = 0

        # affine matrix
        # NOTE(review): `trf.update()` is only called for transformations up
        # to and including the first Linear one, because of the `break`.
        A = None
        for trf in options.transformations:
            trf.update()
            if isinstance(trf, struct.Linear):
                q = trf.optdat.to(**backend)
                B = trf.basis.to(**backend)
                A = linalg.expm(q, B)
                if torch.is_tensor(trf.shift):
                    # include shift
                    shift = trf.shift.to(**backend)
                    eye = torch.eye(options.dim, **backend)
                    A = A.clone()  # needed because expm is a custom autograd.Function
                    A[:-1, -1] += torch.matmul(A[:-1, :-1] - eye, shift)
                # penalties on the linear parameters
                for loss1 in trf.losses:
                    loss += loss1.call(q)
                break

        # non-linear displacement field
        d = None
        d_aff = None
        for trf in options.transformations:
            if not trf.isfree():
                continue
            if isinstance(trf, struct.FFD):
                d = trf.dat.to(**backend)
                d = ffd_exp(d, trf.shape, returns='disp')
                for loss1 in trf.losses:
                    loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break
            elif isinstance(trf, struct.Diffeo):
                d = trf.dat.to(**backend)
                if not trf.smalldef:
                    # penalty on velocity fields
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d = spatial.exp(d[None], displacement=True)[0]
                if trf.smalldef:
                    # penalty on exponentiated transform
                    for loss1 in trf.losses:
                        loss += loss1.call(d)
                d_aff = trf.affine.to(**backend)
                break

        # loop over image pairs
        for match in options.losses:
            if not match.fixed:
                continue
            nb_levels = len(match.fixed.dat)
            prm = dict(interpolation=match.moving.interpolation,
                       bound=match.moving.bound,
                       extrapolate=match.moving.extrapolate)
            # loop over pyramid levels
            for moving, fixed in zip(match.moving.dat, match.fixed.dat):
                moving_dat, moving_aff = moving
                fixed_dat, fixed_aff = fixed

                moving_dat = moving_dat.to(**backend)
                moving_aff = moving_aff.to(**backend)
                fixed_dat = fixed_dat.to(**backend)
                fixed_aff = fixed_aff.to(**backend)

                # affine-corrected moving space
                if A is not None:
                    Ms = affine_matmul(A, moving_aff)
                else:
                    Ms = moving_aff

                if d is not None:
                    # fixed to param
                    Mt = affine_lmdiv(d_aff, fixed_aff)
                    if samespace(Mt, d.shape[:-1], fixed_dat.shape[1:]):
                        # fixed lattice matches the field lattice: use the
                        # field directly (via `smalldef`)
                        g = smalldef(d)
                    else:
                        # otherwise sample the field at the fixed voxels
                        g = affine_grid(Mt, fixed_dat.shape[1:])
                        g = g + pull_grid(d, g)
                    # param to moving
                    Ms = affine_lmdiv(Ms, d_aff)
                    g = affine_matvec(Ms, g)
                else:
                    # fixed to moving
                    Mt = fixed_aff
                    Ms = affine_lmdiv(Ms, Mt)
                    g = affine_grid(Ms, fixed_dat.shape[1:])

                # pull moving image; similarity averaged over pyramid levels
                warped_dat = pull(moving_dat, g, **prm)
                loss += match.call(warped_dat, fixed_dat) / float(nb_levels)

        return loss
Exemplo n.º 9
0
def write_data(options):
    """Reslice every input file through the composed transformations.

    Steps:
    1. Load velocity fields (and exponentiate them) and displacement fields.
    2. If the first transformation is linear, fold it into each input
       file's orientation matrix and drop it from the chain.
    3. If `options.target` is set, build the sampling grid once.
    4. For each file: load, pull through the grid, and save to the name
       obtained by formatting `options.output` with the file's `dir`,
       `base` and `ext`.

    Parameters
    ----------
    options : object
        Parsed command-line options (project-defined structure).
    """
    backend = dict(dtype=torch.float32, device=options.device)

    def load_field(trf):
        """Read the dense field `trf.file` into `trf.dat` on a 3D lattice."""
        f = io.volumes.map(trf.file)
        trf.affine = f.affine
        trf.shape = squeeze_to_nd(f.shape, 3, 1)
        trf.dat = f.fdata(**backend).reshape(trf.shape)
        trf.shape = trf.shape[:3]

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            load_field(trf)
            trf.dat = spatial.exp(trf.dat[None],
                                  displacement=True,
                                  inverse=trf.inv)[0]
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            load_field(trf)

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    if (options.transformations
            and isinstance(options.transformations[0], struct.Linear)):
        trf = options.transformations[0]
        for file in options.files:
            mat = file.affine.to(**backend)
            aff = trf.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(aff, mat)
        options.transformations = options.transformations[1:]

    def build_from_target(target):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    # Resample the field to the target voxel size before
                    # numerically inverting it.
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(target.affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=mat,
                                                    interpolation=trf.order)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.order
                # mm -> field voxels, add displacement, field voxels -> mm
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If target is provided, we can build most of the transform once
    #    and just multiply it with a input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt = dict(interpolation=options.interpolation,
               bound=options.bound,
               extrapolate=options.extrapolate)
    # Random dithering of integer data only makes sense when interpolating;
    # with nearest-neighbour (order 0) it would corrupt e.g. label values.
    dither = options.interpolation != 0
    output = utils.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing:   {file.fname}\n' f'          -> {ofname}')
        dat = io.volumes.loadf(file.fname, rand=dither, **backend)
        dat = dat.reshape([*file.shape, file.channels])
        dat = utils.movedim(dat, -1, 0)

        if not options.target:
            grid = build_from_target(file)
            oaffine = file.affine
        mat = file.affine.to(**backend)
        imat = spatial.affine_inv(mat)
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt)
        dat = utils.movedim(dat, 0, -1)

        io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)
Exemplo n.º 10
0
def write_data(options):
    """Reslice every input file through the composed transformations.

    Steps:
    1. Load velocity fields and displacement fields. Velocities are shot
       with `spatial.shoot` when a JSON parameter file is attached
       (`trf.json`), otherwise exponentiated with `spatial.exp`.
       Displacements are converted from mm to voxels when
       `trf.unit == 'mm'`.
    2. If the first transformation is linear, fold it into each input
       file's orientation matrix and drop it from the chain.
    3. If `options.target` is set, build the sampling grid once.
    4. For each file: load, optionally prefilter, pull through the grid,
       and save to the name obtained by formatting `options.output` with
       the file's `dir`, `base` and `ext`. Label images
       (`options.interpolation == 'l'`) are loaded as integers and saved
       with `io.volumes.save`.

    Parameters
    ----------
    options : object
        Parsed command-line options (project-defined structure).
    """
    backend = dict(dtype=torch.float32, device=options.device)

    # 1) Pre-exponentiate velocities
    for trf in options.transformations:
        if isinstance(trf, struct.Velocity):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.json:
                # Previously these parameters were loaded but never passed
                # to `shoot`, which therefore ran with its defaults.
                # NOTE(review): assumes the JSON keys are keyword arguments
                # of `spatial.shoot` — TODO confirm.
                with open(trf.json) as fp:
                    prm = json.load(fp)
                prm['voxel_size'] = spatial.voxel_size(trf.affine)
                prm['displacement'] = True
                prm['return_inverse'] = trf.inv
                trf.dat = spatial.shoot(trf.dat[None], **prm)
                if trf.inv:
                    # keep only the inverse field
                    trf.dat = trf.dat[-1]
            else:
                trf.dat = spatial.exp(trf.dat[None],
                                      displacement=True,
                                      inverse=trf.inv)
            trf.dat = trf.dat[0]  # drop batch dimension
            trf.inv = False
            trf.order = 1
        elif isinstance(trf, struct.Displacement):
            f = io.volumes.map(trf.file)
            trf.affine = f.affine
            trf.shape = squeeze_to_nd(f.shape, 3, 1)
            trf.dat = f.fdata(**backend).reshape(trf.shape)
            trf.shape = trf.shape[:3]
            if trf.unit == 'mm':
                # convert mm displacement to vox displacement
                trf.dat = spatial.affine_lmdiv(trf.affine, trf.dat[..., None])
                trf.dat = trf.dat[..., 0]
                trf.unit = 'vox'

    # 2) If the first transform is linear, compose it with the input
    #    orientation matrix
    if (options.transformations
            and isinstance(options.transformations[0], struct.Linear)):
        trf = options.transformations[0]
        for file in options.files:
            mat = file.affine.to(**backend)
            aff = trf.affine.to(**backend)
            file.affine = spatial.affine_lmdiv(aff, mat)
        options.transformations = options.transformations[1:]

    def build_from_target(affine, shape):
        """Compose all transformations, starting from the final orientation"""
        grid = spatial.affine_grid(affine.to(**backend), shape)
        for trf in reversed(options.transformations):
            if isinstance(trf, struct.Linear):
                grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
            else:
                mat = trf.affine.to(**backend)
                if trf.inv:
                    # Resample the field to the output voxel size before
                    # numerically inverting it.
                    vx0 = spatial.voxel_size(mat)
                    vx1 = spatial.voxel_size(affine.to(**backend))
                    factor = vx0 / vx1
                    disp, mat = spatial.resize_grid(trf.dat[None],
                                                    factor,
                                                    affine=mat,
                                                    interpolation=trf.order)
                    disp = spatial.grid_inv(disp[0], type='disp')
                    order = 1
                else:
                    disp = trf.dat
                    order = trf.order
                # mm -> field voxels, add displacement, field voxels -> mm
                imat = spatial.affine_inv(mat)
                grid = spatial.affine_matvec(imat, grid)
                grid += helpers.pull_grid(disp, grid, interpolation=order)
                grid = spatial.affine_matvec(mat, grid)
        return grid

    # 3) If target is provided, we can build most of the transform once
    #    and just multiply it with a input-wise affine matrix.
    if options.target:
        grid = build_from_target(options.target.affine, options.target.shape)
        oaffine = options.target.affine

    # 4) Loop across input files
    opt_pull0 = dict(interpolation=options.interpolation,
                     bound=options.bound,
                     extrapolate=options.extrapolate)
    opt_coeff = dict(interpolation=options.interpolation,
                     bound=options.bound,
                     dim=3,
                     inplace=True)
    # Loop-invariant: label mode and the matching load/pull options.
    is_label = (isinstance(options.interpolation, str)
                and options.interpolation == 'l')
    if is_label:
        backend_int = dict(dtype=torch.long, device=backend['device'])
        opt_pull = dict(opt_pull0)
        opt_pull['interpolation'] = 1
    else:
        opt_pull = opt_pull0
    output = py.make_list(options.output, len(options.files))
    for file, ofname in zip(options.files, output):
        ofname = ofname.format(dir=file.dir, base=file.base, ext=file.ext)
        print(f'Reslicing:   {file.fname}\n' f'          -> {ofname}')
        if is_label:
            dat = io.volumes.load(file.fname, **backend_int)
        else:
            dat = io.volumes.loadf(file.fname,
                                   rand=options.interpolation > 0,
                                   **backend)
        dat = dat.reshape([*file.shape, file.channels])
        dat = utils.movedim(dat, -1, 0)

        if not options.target:
            oaffine = file.affine
            oshape = file.shape
            if options.voxel_size:
                ovx = utils.make_vector(options.voxel_size,
                                        3,
                                        dtype=oaffine.dtype)
                factor = spatial.voxel_size(oaffine) / ovx
                oaffine, oshape = spatial.affine_resize(oaffine,
                                                        oshape,
                                                        factor=factor,
                                                        anchor='f')
            grid = build_from_target(oaffine, oshape)
        mat = file.affine.to(**backend)
        imat = spatial.affine_inv(mat)
        if options.prefilter and not is_label:
            dat = spatial.spline_coeff_nd(dat, **opt_coeff)
        dat = helpers.pull(dat, spatial.affine_matvec(imat, grid), **opt_pull)
        dat = utils.movedim(dat, 0, -1)

        if is_label:
            io.volumes.save(dat, ofname, like=file.fname, affine=oaffine)
        else:
            io.volumes.savef(dat, ofname, like=file.fname, affine=oaffine)