예제 #1
0
 def upsample(self, factor=2, **kwargs):
     """Return a new object holding an upsampled version of this field.

     The stored field and its orientation matrix are resized together by
     ``spatial.resize_grid`` (with ``type='displacement'``), and a new
     instance of the same class is built from the result.

     Parameters
     ----------
     factor : float or sequence[float], default=2
         Resizing factor, forwarded to ``spatial.resize_grid``.
     **kwargs
         Extra options for ``spatial.resize_grid``. When not given, they
         default to ``interpolation=1``, ``bound='dft'`` and ``anchor='c'``.

     Returns
     -------
     type(self)
         New instance built from the resized field and matrix, with the
         same ``dim`` as ``self``.
     """
     options = {**dict(interpolation=1, bound='dft', anchor='c'), **kwargs}
     new_dat, new_aff = spatial.resize_grid(
         self.dat, factor, type='displacement', affine=self.affine,
         **options)
     cls = type(self)
     return cls(new_dat, new_aff, dim=self.dim)
예제 #2
0
 def build_from_target(target):
     """Compose all transformations, starting from the final orientation.

     Starts from ``spatial.affine_grid(target.affine, target.shape)`` and
     walks ``options.transformations`` in reverse order, updating the
     coordinate grid with each transformation:

     * ``Linear`` transformations are applied as a matrix-vector product
       with their ``affine``.
     * Any other transformation is treated as a dense displacement: the
       coordinates are mapped through ``affine_inv(mat)``, the displacement
       sampled at those coordinates (``helpers.pull_grid``) is added, and
       the result is mapped back through ``mat``. If ``trf.inv`` is set,
       the displacement is first resized by the ratio of voxel sizes and
       inverted with ``spatial.grid_inv``.

     ``options`` and ``backend`` are read from the enclosing scope.

     Parameters
     ----------
     target
         Object exposing ``affine`` and ``shape`` that defines the output
         lattice.

     Returns
     -------
     grid : tensor
         Composed coordinate grid defined on the ``target`` lattice.
     """
     grid = spatial.affine_grid(target.affine.to(**backend), target.shape)
     # Transformations are applied last-to-first, since we start from the
     # final (target) orientation and pull coordinates back.
     for trf in reversed(options.transformations):
         if isinstance(trf, Linear):
             grid = spatial.affine_matvec(trf.affine.to(**backend), grid)
         else:
             mat = trf.affine.to(**backend)
             if trf.inv:
                 # Resize the displacement by (transform voxel size /
                 # target voxel size) before inverting it, presumably so
                 # that the inverse is computed at the target resolution.
                 vx0 = spatial.voxel_size(mat)
                 vx1 = spatial.voxel_size(target.affine.to(**backend))
                 factor = vx0 / vx1
                 # NOTE(review): no `type=` is passed, so resize_grid's
                 # default grid type is used although trf.dat is handled
                 # as a displacement below (grid_inv(..., type='disp')).
                 # Confirm this is intended.
                 disp, mat = spatial.resize_grid(trf.dat[None],
                                                 factor,
                                                 affine=mat,
                                                 interpolation=trf.spline)
                 disp = spatial.grid_inv(disp[0], type='disp')
                 # the inverted displacement is sampled with order-1
                 # interpolation rather than trf.spline
                 order = 1
             else:
                 disp = trf.dat
                 order = trf.spline
             # world -> transform voxel space, add sampled displacement,
             # then transform voxel space -> world
             imat = spatial.affine_inv(mat)
             grid = spatial.affine_matvec(imat, grid)
             grid += helpers.pull_grid(disp, grid, interpolation=order)
             grid = spatial.affine_matvec(mat, grid)
     return grid
예제 #3
0
 def upsample_(self, factor=2, **kwargs):
     """Upsample this displacement field in place.

     ``self.dat`` and ``self.affine`` are replaced by the output of
     ``spatial.resize_grid`` (called with ``type='disp'``).

     Parameters
     ----------
     factor : float or sequence[float], default=2
         Resizing factor, forwarded to ``spatial.resize_grid``.
     **kwargs
         Extra options for ``spatial.resize_grid``. When not given, they
         default to ``interpolation=1``, ``bound='dft'`` and ``anchor='c'``.

     Returns
     -------
     self
     """
     for key, value in (('interpolation', 1), ('bound', 'dft'),
                        ('anchor', 'c')):
         kwargs.setdefault(key, value)
     resized = spatial.resize_grid(self.dat, factor, type='disp',
                                   affine=self.affine, **kwargs)
     self.dat, self.affine = resized
     return self
예제 #4
0
    def forward(self, grid, affine=None, **overload):
        """Resize a grid using the options stored on this module.

        Parameters
        ----------
        grid : (batch, *spatial_in, ndim) tensor
            Input grid to resize.
        affine : (batch, ndim[+1], ndim+1), optional
            Orientation matrix of the input grid.
            If provided, the orientation matrix of the resized grid is
            returned as well.
        overload : dict
            Any of ``factor``, ``shape``, ``type``, ``anchor``,
            ``interpolation``, ``bound`` and ``extrapolate`` given here
            overrides the value stored on the module, for this call only.
            Other keys are ignored.

        Returns
        -------
        resized : (batch, *spatial_out, ndim) tensor
            Resized grid.
        affine : (batch, ndim[+1], ndim+1) tensor, optional
            Orientation matrix of the resized grid.

        """
        option_names = ('factor', 'shape', 'type', 'anchor',
                        'interpolation', 'bound', 'extrapolate')
        options = {name: overload.get(name, getattr(self, name))
                   for name in option_names}
        return spatial.resize_grid(grid, affine=affine, **options)
예제 #5
0
파일: spatial.py 프로젝트: balbasty/nitorch
    def forward(self, grid, affine=None, output_shape=None, **overload):
        """Resize a grid using the options stored on this module.

        Parameters
        ----------
        grid : (batch, *spatial_in, ndim) tensor
            Input grid to resize.
        affine : (batch, ndim[+1], ndim+1), optional
            Orientation matrix of the input grid.
            If provided, the orientation matrix of the resized grid is
            returned as well.
        output_shape : sequence[int], optional
            Output spatial shape. When given (and non-empty), it takes
            precedence over a ``shape`` passed through ``overload`` and
            over ``self.shape``.
        overload : dict
            Any of ``factor``, ``shape``, ``type``, ``anchor``,
            ``interpolation``, ``bound``, ``extrapolate`` and ``prefilter``
            given here overrides the value stored on the module, for this
            call only. Other keys are ignored.

        Returns
        -------
        resized : (batch, *spatial_out, ndim) tensor
            Resized grid.
        affine : (batch, ndim[+1], ndim+1) tensor, optional
            Orientation matrix of the resized grid.

        """
        def option(name):
            # call-time value if provided, otherwise the build-time one
            return overload.get(name, getattr(self, name))

        kwargs = {
            'factor': option('factor'),
            # an explicit `output_shape` argument wins over everything else
            'shape': output_shape or option('shape'),
            'type': option('type'),
            'anchor': option('anchor'),
            'interpolation': option('interpolation'),
            'bound': option('bound'),
            'extrapolate': option('extrapolate'),
            'prefilter': option('prefilter'),
        }
        return spatial.resize_grid(grid, affine=affine, **kwargs)
예제 #6
0
 def exp(prm):
     """Turn FFD parameters into a displacement and a transformation grid.

     ``prm`` is resized (as a displacement field) to ``target.shape[2:]``
     with spline order 3 and ``'dft'`` boundaries. The identity grid is
     then added to obtain the transformation grid. ``target`` and
     ``backend`` are read from the enclosing scope.

     Returns
     -------
     disp : tensor
         Resized displacement field.
     grid : tensor
         ``disp`` plus the identity grid.
     """
     out_shape = target.shape[2:]
     disp = spatial.resize_grid(prm, type='displacement', shape=out_shape,
                                interpolation=3, bound='dft')
     identity = spatial.identity_grid(out_shape, **backend)
     return disp, disp + identity
예제 #7
0
    def forward(self,
                source,
                target,
                source_seg=None,
                target_seg=None,
                *,
                _loss=None,
                _metric=None):
        """Predict an affine transformation and apply it to ``source``.

        Parameters
        ----------
        source : tensor (batch, channel, *spatial)
            Source/moving image
        target : tensor (batch, channel, *spatial)
            Target/fixed image
        source_seg : tensor, optional
            Source/moving segmentation. If its spatial shape differs from
            that of ``source``, the sampling grid is resized to match it.
        target_seg : tensor, optional
            Target/fixed segmentation, used for shape checks and passed to
            ``self.compute``.

        Other Parameters
        ----------------
        _loss : dict, optional
            Forwarded to ``self.compute``.
        _metric : dict, optional
            Forwarded to ``self.compute``.

        Returns
        -------
        deformed_source : tensor
            ``source`` resampled with the predicted transformation.
        deformed_source_seg : tensor, optional
            Only returned when ``source_seg`` is provided.
        affine_prm : tensor (batch, nb_prm)
            Parameters predicted by ``self.cnn``, one row per batch element.
        """
        # validate inputs
        check.dim(self.dim, source, target, source_seg, target_seg)
        check.shape(target, source, dims=[0], broadcast_ok=True)
        check.shape(target, source, dims=range(2, self.dim + 2))
        check.shape(target_seg, source_seg, dims=[0], broadcast_ok=True)
        check.shape(target_seg, source_seg, dims=range(2, self.dim + 2))

        # predict one parameter vector per batch element from the
        # channel-concatenated (source, target) pair
        pair = torch.cat((source, target), dim=1)
        raw_prm = self.cnn(pair)
        affine_prm = raw_prm.reshape(raw_prm.shape[:2])

        # map each parameter vector through self.exp, then build a
        # sampling grid on the target lattice
        matrices = torch.stack([self.exp(row) for row in affine_prm], dim=0)
        grid = self.grid(matrices, shape=target.shape[2:])
        deformed_source = self.pull(source, grid)

        deformed_source_seg = None
        if source_seg is not None:
            seg_shape = source_seg.shape[2:]
            if seg_shape != source.shape[2:]:
                grid = spatial.resize_grid(grid, shape=seg_shape)
            deformed_source_seg = self.pull(source_seg, grid)

        # accumulate losses and metrics
        self.compute(_loss,
                     _metric,
                     image=[deformed_source, target],
                     affine=[affine_prm],
                     segmentation=[deformed_source_seg, target_seg])

        if source_seg is not None:
            return deformed_source, deformed_source_seg, affine_prm
        return deformed_source, affine_prm
예제 #8
0
def ffd_exp(prm, shape, order=3, bound='dft', returns='disp'):
    """Transform FFD parameters into a displacement or transformation grid.

    The parameters are resized (as a displacement field) to ``shape`` with
    ``resize_grid``. The identity grid is added only when a transformation
    grid is requested.

    Parameters
    ----------
    prm : (..., *spatial, dim) tensor
        FFD parameters. Any number of leading batch dimensions is accepted;
        they are flattened for the resize and restored afterwards.
    shape : sequence[int]
        Output (exponentiated) spatial shape.
    order : int, default=3
        Spline order
    bound : str, default='dft'
        Boundary condition
    returns : {'disp', 'grid', 'disp+grid'}, default='disp'
        What to return:
        - 'disp' -> displacement grid
        - 'grid' -> transformation grid
        - 'disp+grid' -> both, as a ``(disp, grid)`` tuple

    Returns
    -------
    disp : (..., *shape, dim) tensor, optional
        Displacement grid
    grid : (..., *shape, dim) tensor, optional
        Transformation grid

    Raises
    ------
    ValueError
        If ``returns`` mentions neither 'disp' nor 'grid'.

    """
    # Validate before doing any work: previously an unknown value silently
    # produced `None` after the full resize had been computed.
    want_disp = 'disp' in returns
    want_grid = 'grid' in returns
    if not (want_disp or want_grid):
        raise ValueError("`returns` must contain 'disp' and/or 'grid', "
                         f"got {returns!r}")

    dim = prm.shape[-1]
    batch = prm.shape[:-(dim + 1)]
    # flatten all leading dimensions into a single batch dimension
    prm = prm.reshape([-1, *prm.shape[-(dim + 1):]])
    disp = resize_grid(prm,
                       type='displacement',
                       shape=shape,
                       interpolation=order,
                       bound=bound)
    disp = disp.reshape(batch + disp.shape[1:])
    if not want_grid:
        return disp
    # the identity grid is only needed for the transformation grid
    grid = disp + identity_grid(shape, dtype=prm.dtype, device=prm.device)
    if want_disp:
        return disp, grid
    return grid
예제 #9
0
파일: struct.py 프로젝트: balbasty/nitorch
 def free(self):
     """Free the next batch/ladder of parameters.

     Does nothing when ``self.freeable()`` is false. On the first call,
     ``self.dat`` is wrapped into a trainable ``torch.nn.Parameter`` that
     is stored as ``self.optdat``. On later calls, the last entry of
     ``self.pyramid`` is popped, and the detached field is resized together
     with ``self.affine``. Every dimension of ``self.dat`` except the last
     is multiplied by ``2 ** (popped_level - new_last_level)``. The result
     is then wrapped into a fresh trainable parameter.
     """
     if not self.freeable():
         return
     print('Free nonlin')
     if hasattr(self, 'optdat'):
         *self.pyramid, popped_level = self.pyramid
         self.dat = self.dat.detach()
         scale = 2 ** (popped_level - self.pyramid[-1])
         target_shape = [size * scale for size in self.dat.shape[:-1]]
         resized_dat, resized_aff = spatial.resize_grid(
             self.dat[None],
             shape=target_shape,
             type='displacement',
             affine=self.affine[None])
         self.dat, self.affine = resized_dat[0], resized_aff[0]
     # both branches end by exposing the current field as a new
     # trainable parameter
     self.optdat = torch.nn.Parameter(self.dat, requires_grad=True)
     self.dat = self.optdat
예제 #10
0
파일: main.py 프로젝트: balbasty/nitorch
def write_data(files, options):
    """Resize every input file and write the result to disk.

    For each file, the data are loaded as float32 on the requested device.
    A per-axis resizing factor is then derived from exactly one source:
    ``voxel_size`` and ``factor`` are mutually exclusive, and ``shape`` is
    used only when neither is given. The volume is resized either as an
    image (``spatial.resize``) or, when ``options.grid`` is set, as a grid
    (``spatial.resize_grid`` with ``type=options.grid``). The output path
    is built from ``options.output`` by formatting the ``{dir}``,
    ``{base}``, ``{ext}`` and ``{sep}`` placeholders.

    Raises
    ------
    ValueError
        If both ``voxel_size`` and ``factor`` are given, or if none of
        ``voxel_size``/``factor``/``shape`` is given.
    """
    device = (options.gpu if isinstance(options.gpu, str)
              else f'cuda:{options.gpu}')
    backend = dict(dtype=torch.float32, device=device)
    ofiles = py.make_list(options.output, len(files))

    for file, ofile in zip(files, ofiles):

        ofile = ofile.format(dir=file.dir,
                             base=file.base,
                             ext=file.ext,
                             sep=os.sep)
        print(f'Resizing:   {file.fname}\n' f'         -> {ofile}')

        dat = io.loadf(file.fname, **backend)
        dat = dat.reshape([*file.shape, file.channels])

        # compute resizing factor (one value per spatial axis)
        input_vx = spatial.voxel_size(file.affine)
        if options.voxel_size:
            if options.factor:
                raise ValueError('Cannot use both factor and voxel size')
            factor = input_vx / utils.make_vector(options.voxel_size, 3)
        elif options.factor:
            factor = utils.make_vector(options.factor, 3)
        elif options.shape:
            input_shape = utils.make_vector(dat.shape[:-1],
                                            3,
                                            dtype=torch.float32)
            output_shape = utils.make_vector(options.shape,
                                             3,
                                             dtype=torch.float32)
            factor = output_shape / input_shape
        else:
            raise ValueError('Need at least one of factor/voxel_size/shape')
        factor = factor.tolist()

        # an explicit output shape is forwarded to the resize call as well
        if options.shape:
            output_shape = py.ensure_list(options.shape, 3)
        else:
            output_shape = None

        # Perform resize
        opt = dict(
            anchor=options.anchor,
            bound=options.bound,
            interpolation=options.interpolation,
            prefilter=options.prefilter,
        )
        if options.grid:
            # resize_grid returns a (grid, affine) pair: unpack it first,
            # then drop the batch dimension of the grid only. (Indexing the
            # pair with [0] would try to unpack the batched grid tensor.)
            dat, affine = spatial.resize_grid(dat[None],
                                              factor,
                                              output_shape,
                                              type=options.grid,
                                              affine=file.affine,
                                              **opt)
            dat = dat[0]
        else:
            # spatial.resize expects channels first
            dat = utils.movedim(dat, -1, 0)
            dat, affine = spatial.resize(dat[None],
                                         factor,
                                         output_shape,
                                         affine=file.affine,
                                         **opt)
            dat = utils.movedim(dat[0], 0, -1)

        # Write output file
        io.volumes.savef(dat, ofile, like=file.fname, affine=affine)
예제 #11
0
    def forward(self,
                source,
                target,
                source_seg=None,
                target_seg=None,
                *,
                _loss=None,
                _metric=None):
        """Predict an affine + velocity deformation and apply it to ``source``.

        Parameters
        ----------
        source : tensor (batch, channel, *spatial)
            Source/moving image
        target : tensor (batch, channel, *spatial)
            Target/fixed image
        source_seg : tensor, optional
            Source/moving segmentation. If its spatial shape differs from
            that of ``source``, the deformation grid is resized to match it.
        target_seg : tensor, optional
            Target/fixed segmentation, used for shape checks and passed to
            ``self.compute``.

        Other Parameters
        ----------------
        _loss : dict, optional
            If provided, all registered losses are computed and appended.
        _metric : dict, optional
            If provided, all registered metrics are computed and appended.

        Returns
        -------
        deformed_source : tensor (batch, channel, *spatial)
            Deformed source image
        deformed_source_seg : tensor, optional
            Deformed source segmentation (only when ``source_seg`` is given).
        velocity : tensor (batch, *spatial, len(spatial))
            Velocity field (channel-last).
        affine_prm : tensor (batch, nb_prm)
            Affine parameters, one row per batch element.

        """
        # validate inputs
        check.dim(self.dim, source, target, source_seg, target_seg)
        check.shape(target, source, dims=[0], broadcast_ok=True)
        check.shape(target, source, dims=range(2, self.dim + 2))
        check.shape(target_seg, source_seg, dims=[0], broadcast_ok=True)
        check.shape(target_seg, source_seg, dims=range(2, self.dim + 2))

        pair = torch.cat((source, target), dim=1)

        # affine parameters: flattened to (batch, nb_prm)
        raw_prm = self.cnn(pair)
        affine_prm = raw_prm.reshape(raw_prm.shape[:2])

        # velocity field, moved to channel-last layout
        velocity = channel2last(self.unet(pair))

        # combine both into a sampling grid and deform the source
        grid = self.exp(velocity, affine_prm)
        deformed_source = self.pull(source, grid)

        deformed_source_seg = None
        if source_seg is not None:
            seg_shape = source_seg.shape[2:]
            if seg_shape != source.shape[2:]:
                grid = spatial.resize_grid(grid, shape=seg_shape)
            deformed_source_seg = self.pull(source_seg, grid)

        # accumulate losses and metrics
        self.compute(_loss,
                     _metric,
                     image=[deformed_source, target],
                     velocity=[velocity],
                     segmentation=[deformed_source_seg, target_seg],
                     affine=[affine_prm])

        if deformed_source_seg is not None:
            return deformed_source, deformed_source_seg, velocity, affine_prm
        return deformed_source, velocity, affine_prm
예제 #12
0
    def forward(self, source, target, source_seg=None, target_seg=None,
                *, _loss=None, _metric=None):
        """Predict a velocity field and use it to deform ``source``.

        Parameters
        ----------
        source : tensor (batch, channel, *spatial)
            Source/moving image
        target : tensor (batch, channel, *spatial)
            Target/fixed image
        source_seg : tensor (batch, classes, *spatial), optional
            Source/moving segmentation. If its spatial shape differs from
            that of ``source``, the deformation grid is resized to match it.
        target_seg : tensor (batch, classes, *spatial), optional
            Target/fixed segmentation

        Other Parameters
        ----------------
        _loss : dict, optional
            If provided, all registered losses are computed and appended.
        _metric : dict, optional
            If provided, all registered metrics are computed and appended.

        Returns
        -------
        deformed_source : tensor (batch, channel, *spatial)
            Deformed source image
        deformed_source_seg : tensor (batch, classes, *spatial), optional
            Deformed source segmentation (only when ``source_seg`` is given).
        velocity : tensor (batch, *spatial, len(spatial))
            Velocity field

        """
        # sanity checks -- segmentations are rank-checked too, as in the
        # other registration models (they are already shape-checked below)
        check.dim(self.dim, source, target, source_seg, target_seg)
        check.shape(target, source, dims=[0], broadcast_ok=True)
        check.shape(target, source, dims=range(2, self.dim+2))
        check.shape(target_seg, source_seg, dims=[0], broadcast_ok=True)
        check.shape(target_seg, source_seg, dims=range(2, self.dim+2))

        # chain operations
        source_and_target = torch.cat((source, target), dim=1)
        velocity = self.unet(source_and_target)
        velocity = core.utils.channel2last(velocity)
        grid = self.exp(velocity)
        deformed_source = self.pull(source, grid)

        if source_seg is not None:
            # grid values and lattice must match the segmentation resolution
            if source_seg.shape[2:] != source.shape[2:]:
                grid = spatial.resize_grid(grid, shape=source_seg.shape[2:])
            deformed_source_seg = self.pull(source_seg, grid)
        else:
            deformed_source_seg = None

        # compute loss and metrics
        self.compute(_loss, _metric,
                     image=[deformed_source, target],
                     velocity=[velocity],
                     segmentation=[deformed_source_seg, target_seg])

        if source_seg is None:
            return deformed_source, velocity
        else:
            return deformed_source, deformed_source_seg, velocity