Example 1
    def forward(self, x, noise=None, return_resolution=False):

        if noise is not None:
            noise = noise.expand(x.shape)

        dim = x.dim() - 2
        backend = utils.backend(x)
        resolution_exp = utils.make_vector(self.resolution_exp, x.shape[1],
                                           **backend)
        resolution_scale = utils.make_vector(self.resolution_scale, x.shape[1],
                                             **backend)

        all_resolutions = []
        out = torch.empty_like(x)
        for b in range(len(x)):
            for c in range(x.shape[1]):
                resolution = self.resolution(resolution_exp[c],
                                             resolution_scale[c]).sample()
                resolution = resolution.clamp_min(1)
                fwhm = [resolution] * dim
                y = smooth(x[b, c], fwhm=fwhm, dim=dim, padding='same', bound='dct2')
                if noise is not None:
                    y += noise[b, c]
                factor = [1/resolution] * dim
                y = y[None, None]  # need batch and channel for resize
                y = resize(y, factor=factor, anchor='f')
                factor = [resolution] * dim
                all_resolutions.append(factor)
                y = resize(y, factor=factor, shape=x.shape[2:], anchor='f')
                out[b, c] = y[0, 0]

        all_resolutions = utils.as_tensor(all_resolutions, **backend)
        return (out, all_resolutions) if return_resolution else out
Example 2
def _resize(maps, rls, aff, shape):
    for map in maps:
        map.volume = spatial.resize(map.volume[None, None, ...],
                                    shape=shape)[0, 0]
        map.affine = aff
    maps.affine = aff
    if rls is not None:
        if rls.dim() == len(shape):
            rls = spatial.resize(rls[None, None], shape=shape)[0, 0]
        else:
            rls = spatial.resize(rls[None], shape=shape)[0]
    return maps, rls
Example 3
def _resize(maps, rls, aff, shape):
    """Resize (prolong) current maps to a target resolution"""
    maps.volume = spatial.resize(maps.volume[None], shape=shape)[0]
    maps.affine = aff
    # for map in maps:
    #     map.volume = spatial.resize(map.volume[None, None, ...],
    #                                 shape=shape)[0, 0]
    #     map.affine = aff
    # maps.affine = aff
    if rls is not None:
        if rls.dim() == len(shape):
            rls = spatial.resize(rls[None, None], shape=shape)[0, 0]
        else:
            rls = spatial.resize(rls[None], shape=shape)[0]
    return maps, rls
Example 4
    def forward(self, image, affine=None, **overload):
        """

        Parameters
        ----------
        image : (batch, channel, *spatial_in) tensor
            Input image to resize
        affine : (batch, ndim[+1], ndim+1) tensor, optional
            Orientation matrix of the input image.
            If provided, the orientation matrix of the resized image is
            returned as well.
        overload : dict
            All parameters defined at build time can be overridden
            at call time.

        Returns
        -------
        resized : (batch, channel, ...) tensor
            Resized image.
        affine : (batch, ndim[+1], ndim+1) tensor, optional
            Orientation matrix

        """
        kwargs = {
            'factor': overload.get('factor', self.factor),
            'shape': overload.get('shape', self.shape),
            'anchor': overload.get('anchor', self.anchor),
            'interpolation': overload.get('interpolation', self.interpolation),
            'bound': overload.get('bound', self.bound),
            'extrapolate': overload.get('extrapolate', self.extrapolate),
        }
        return spatial.resize(image, affine=affine, **kwargs)
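
In other words, anything fixed at build time (factor, shape, anchor, interpolation, bound, extrapolate) can be overridden per call through **overload. A minimal usage sketch; the class name Resize and its constructor arguments are assumptions (the snippet only shows the forward pass), so treat the import and instantiation as illustrative:

import torch
from nitorch import nn as nnt  # assumed location of the module shown above

resize = nnt.Resize(factor=2, interpolation=1, bound='dct2')  # assumed constructor
img = torch.randn(1, 1, 32, 32, 32)
up = resize(img)                # uses the factor stored at build time
half = resize(img, factor=0.5)  # **overload replaces it for this call only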
Example 5
    def forward(self, image, affine=None, output_shape=None):
        """

        Parameters
        ----------
        image : (batch, channel, *spatial_in) tensor
            Input image to resize
        affine : (batch, ndim[+1], ndim+1) tensor, optional
            Orientation matrix of the input image.
            If provided, the orientation matrix of the resized image is
            returned as well.
        output_shape : sequence[int], optional
            Target spatial shape.

        Returns
        -------
        resized : (batch, channel, ...) tensor
            Resized image.
        affine : (batch, ndim[+1], ndim+1) tensor, optional
            Orientation matrix

        """
        outshape = self.shape(image, output_shape=output_shape)
        kwargs = {
            'shape': outshape[2:],
            'factor': self.factor,
            'anchor': self.anchor,
            'interpolation': self.interpolation,
            'bound': self.bound,
            'extrapolate': self.extrapolate,
            'prefilter': self.prefilter,
        }
        return spatial.resize(image, affine=affine, **kwargs)
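
Both of these forward passes bottom out in the same functional call, so when the module machinery is not needed, spatial.resize can be called directly with an explicit target shape. A minimal sketch (the sizes below are arbitrary):

import torch
from nitorch import spatial

img = torch.randn(1, 1, 96, 96, 96)
out = spatial.resize(img, shape=(128, 128, 128),
                     interpolation=1, bound='dct2', prefilter=True)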
Example 6
    def resize(self):
        affine, shape = spatial.affine_resize(self.affine0, self.shape0,
                                              1 / (2**(self.level - 1)))

        scl0 = spatial.voxel_size(self.affine0).prod()
        scl = spatial.voxel_size(affine).prod() / scl0
        self.lam_scale = scl

        for map in self.maps:
            map.volume = spatial.resize(map.volume[None, None, ...],
                                        shape=shape)[0, 0]
            map.affine = affine
        self.maps.affine = affine
        if self.rls is not None:
            if self.rls.dim() == len(shape):
                self.rls = spatial.resize(self.rls[None, None],
                                          shape=shape)[0, 0]
            else:
                self.rls = spatial.resize(self.rls[None], shape=shape)[0]
            self.nll['rls'] = self.rls.reciprocal().sum(dtype=torch.double)
Example 7
    def forward(self, x, affine=None):
        """

        Parameters
        ----------
        x : (X, Y, Z) tensor or str
        affine : (4, 4) tensor, optional

        Returns
        -------
        seg : (32, oX, oY, oZ) tensor
            Segmentation
        resliced : (oX, oY, oZ) tensor
            Input resliced to 1 mm RAS
        affine : (4, 4) tensor
            Output orientation matrix

        """
        if self.verbose:
            print('Preprocessing... ', end='', flush=True)
        if isinstance(x, str):
            x = io.map(x)
        if isinstance(x, io.MappedArray):
            if affine is None:
                affine = x.affine
            x = x.fdata()
            x = x.reshape(x.shape[:3])
        x = SynthPreproc.addnoise(x)
        if affine is not None:
            affine, x = spatial.affine_reorient(affine, x, 'RAS')
            vx = spatial.voxel_size(affine)
            fwhm = 0.25 * vx.reciprocal()
            fwhm[vx > 1] = 0
            x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
            x, affine = spatial.resize(x[None, None],
                                       vx.tolist(),
                                       affine=affine)
            x = x[0, 0]
        oshape = x.shape
        x, crop = SynthPreproc.crop(x)
        x = SynthPreproc.preproc(x)[None, None]
        if self.verbose:
            print('done.', flush=True)
            print('Segmenting... ', end='', flush=True)
        s, x = super().forward(x)[0], x[0, 0]
        if self.verbose:
            print('done.', flush=True)
            print('Postprocessing... ', end='', flush=True)
        s = self.relabel(s.argmax(0))
        x = SynthPreproc.pad(x, oshape, crop)
        s = SynthPreproc.pad(s, oshape, crop)
        if self.verbose:
            print('done.', flush=True)
        return s, x, affine
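
The preprocessing stage of this forward pass (reorient to RAS, pre-smooth to limit aliasing, resample to roughly 1 mm) is useful on its own. A minimal sketch assuming x is a 3D tensor and affine its 4x4 voxel-to-world matrix; the helper name reslice_to_1mm is made up:

from nitorch import spatial

def reslice_to_1mm(x, affine):
    # reorient the voxel grid to RAS
    affine, x = spatial.affine_reorient(affine, x, 'RAS')
    vx = spatial.voxel_size(affine)
    # anti-aliasing smoothing, only along axes that get downsampled (vx < 1 mm)
    fwhm = 0.25 * vx.reciprocal()
    fwhm[vx > 1] = 0
    x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
    # resizing by the voxel size yields an output grid of ~1 mm isotropic
    x, affine = spatial.resize(x[None, None], vx.tolist(), affine=affine)
    return x[0, 0], affine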
Example 8
    def resize(cls, x, affine, target_vx=1):
        target_vx = utils.make_vector(target_vx, x.dim(),
                                      **utils.backend(affine))
        vx = spatial.voxel_size(affine)
        factor = vx / target_vx
        fwhm = 0.25 * factor.reciprocal()
        fwhm[factor > 1] = 0
        x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
        x, affine = spatial.resize(x[None, None],
                                   factor.tolist(),
                                   affine=affine)
        x = x[0, 0]
        return x, affine
Example 9
    def load(self, x, affine=None):
        if isinstance(x, str):
            x = io.map(x)
        if isinstance(x, io.MappedArray):
            if affine is None:
                affine = x.affine
            x = x.fdata()
            x = x.reshape(x.shape[:3])
        affine_original = affine
        x_original = x.shape
        if affine is not None:
            affine, x = spatial.affine_reorient(affine, x, 'RAS')
            vx = spatial.voxel_size(affine)
            x, affine = spatial.resize(x[None, None],
                                       vx.tolist(),
                                       affine=affine)
            x = x[0, 0]
        return x, affine, x_original, affine_original
Example 10
    def forward(self, batch=1, **overload):
        """

        Parameters
        ----------
        batch : int, default=1
            Batch size

        Other Parameters
        ----------------
        shape : sequence[int], optional
        channel : int, optional
        device : torch.device, optional
        dtype : torch.dtype, optional

        Returns
        -------
        field : (batch, channel, *shape) tensor
            Generated random field

        """

        # get arguments
        shape = overload.get('shape', self.shape)
        channel = overload.get('channel', self.channel)
        dtype = overload.get('dtype', self.dtype)
        device = overload.get('device', self.device)
        backend = dict(dtype=dtype, device=device)

        # broadcast parameters to per-channel / per-dimension vectors
        nb_dim = len(shape)
        mean = utils.make_vector(self.mean, channel, **backend)
        amplitude = utils.make_vector(self.amplitude, channel, **backend)
        fwhm = utils.make_vector(self.fwhm, nb_dim, **backend)

        # sample spline coefficients
        nodes = [(s/f).ceil().int().item()
                 for s, f in zip(shape, fwhm)]
        sample = torch.randn([batch, channel, *nodes], **backend)
        sample *= utils.unsqueeze(amplitude, -1, nb_dim)
        sample = spatial.resize(sample, shape=shape, interpolation=self.basis,
                                bound='dct2', prefilter=False)
        sample += utils.unsqueeze(mean, -1, nb_dim)
        return sample
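
The idea is to draw white noise on a coarse grid with roughly one node per FWHM, scale it by the amplitude, and upsample it to the full shape with spline interpolation, which yields a smooth random field. A stripped-down 2D sketch of the same recipe (shape, fwhm and amplitude values are arbitrary):

import math
import torch
from nitorch import spatial

shape, fwhm, amplitude = (64, 64), 8.0, 2.0
nodes = [math.ceil(s / fwhm) for s in shape]      # one control point per FWHM
coarse = amplitude * torch.randn(1, 1, *nodes)    # coarse white noise
field = spatial.resize(coarse, shape=shape, interpolation=3,
                       bound='dct2', prefilter=False)  # (1, 1, 64, 64) smooth field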
Example 11
def write_data(files, options):

    device = options.gpu if isinstance(options.gpu,
                                       str) else f'cuda:{options.gpu}'
    backend = dict(dtype=torch.float32, device=device)
    ofiles = py.make_list(options.output, len(files))

    for file, ofile in zip(files, ofiles):

        ofile = ofile.format(dir=file.dir,
                             base=file.base,
                             ext=file.ext,
                             sep=os.sep)
        print(f'Resizing:   {file.fname}\n' f'         -> {ofile}')

        dat = io.loadf(file.fname, **backend)
        dat = dat.reshape([*file.shape, file.channels])

        # compute resizing factor
        input_vx = spatial.voxel_size(file.affine)
        if options.voxel_size:
            if options.factor:
                raise ValueError('Cannot use both factor and voxel size')
            factor = input_vx / utils.make_vector(options.voxel_size, 3)
        elif options.factor:
            factor = utils.make_vector(options.factor, 3)
        elif options.shape:
            input_shape = utils.make_vector(dat.shape[:-1],
                                            3,
                                            dtype=torch.float32)
            output_shape = utils.make_vector(options.shape,
                                             3,
                                             dtype=torch.float32)
            factor = output_shape / input_shape
        else:
            raise ValueError('Need at least one of factor/voxel_size/shape')
        factor = factor.tolist()

        # check if output shape is provided
        if options.shape:
            output_shape = py.ensure_list(options.shape, 3)
        else:
            output_shape = None

        # Perform resize
        opt = dict(
            anchor=options.anchor,
            bound=options.bound,
            interpolation=options.interpolation,
            prefilter=options.prefilter,
        )
        if options.grid:
            dat, affine = spatial.resize_grid(dat[None],
                                              factor,
                                              output_shape,
                                              type=options.grid,
                                              affine=file.affine,
                                              **opt)
            dat = dat[0]
        else:
            dat = utils.movedim(dat, -1, 0)
            dat, affine = spatial.resize(dat[None],
                                         factor,
                                         output_shape,
                                         affine=file.affine,
                                         **opt)
            dat = utils.movedim(dat[0], 0, -1)

        # Write output file
        io.volumes.savef(dat, ofile, like=file.fname, affine=affine)
Example 12
def resize(maps, rls, aff, shape):
    maps.volume = spatial.resize(maps.volume, shape=shape)
    maps.affine = aff
    if rls is not None:
        rls = spatial.resize(rls, shape=shape)
    return maps, rls
Example 13
def correct_smooth(x,
                   sigma=None,
                   lam=10,
                   gamma=10,
                   downsample=None,
                   max_iter=16,
                   max_rls=8,
                   tol=1e-6,
                   verbose=False,
                   device=None):
    """Correct the intensity non-uniformity in a SPIM image.

    The signal is modelled as: f = exp(s + b) + eps, with a penalty on
    the (squared) gradients of s and on the (squared) curvature of b.

    Parameters
    ----------
    x : tensor
        SPIM image with the z dimension last and the z=0 plane first
    sigma : float, optional
        Noise standard deviation. Default: educated guess.
    lam : float, default=10
        Regularisation on the signal.
    gamma : float, default=10
        Regularisation on the bias field.
    downsample : int or sequence[int], optional
        Downsample the image by this factor before fitting; the fitted
        signal and bias fields are resized back to the input shape.
    max_iter : int, default=16
        Maximum number of Newton iterations.
    max_rls : int, default=8
        Maximum number of reweighting iterations.
        If 1, this is effectively an l2 regularisation.
    tol : float, default=1e-6
        Tolerance for early stopping.
    verbose : int or bool, default=False
        Verbosity level
    device : torch.device, default=x.device
        Use this device during fitting.

    Returns
    -------
    y : tensor
        Fitted image
    bias : tensor
        Fitted bias
    x : tensor
        Corrected image

    """

    x = torch.as_tensor(x)
    if not x.dtype.is_floating_point:
        x = x.to(dtype=torch.get_default_dtype())
    dim = x.dim()

    # downsampling
    if downsample:
        x0 = x
        downsample = py.make_list(downsample, dim)
        x = spatial.pool(dim, x, downsample)
    shape = x.shape
    x = x.to(device)

    # noise educated guess: assume SNR=5 at z=1/2
    center = tuple(slice(s // 3, 2 * s // 3) for s in shape)
    sigma = sigma or x[center].median() / 5
    lam = lam**2 * sigma**2
    gamma = gamma**2 * sigma**2
    regy = lambda y, w: spatial.regulariser(
        y[None], membrane=lam, dim=dim, weights=w)[0]
    regb = lambda b: spatial.regulariser(b[None], bending=gamma, dim=dim)[0]
    solvey = lambda h, g, w: spatial.solve_field_sym(
        h[None], g[None], membrane=lam, dim=dim, weights=w)[0]
    solveb = lambda h, g: spatial.solve_field_sym(
        h[None], g[None], bending=gamma, dim=dim)[0]

    # init
    l1 = max_rls > 1
    if l1:
        w = torch.ones_like(x)[None]
        llw = w.sum()
        max_rls = 10
    else:
        w = None
        llw = 0
        max_rls = 1
    logb = torch.zeros_like(x)
    logy = x.clamp_min(1e-3).log_()
    y = logy.exp()
    b = logb.exp()
    fit = y * b
    res = fit - x
    llx = res.square().sum()
    lly = (regy(logy, w).mul_(logy)).sum()
    llb = (regb(logb).mul_(logb)).sum()
    ll0 = llx + lly + llb + llw
    ll1 = ll0

    for it_ls in range(max_rls):
        for it in range(max_iter):

            # update bias
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regb(logb)
            logb -= solveb(h, g)
            logb0 = logb.mean()
            logb -= logb0
            logy += logb0

            # update fit / ll
            llb = (regb(logb).mul_(logb)).sum()
            b = torch.exp(logb, out=b)
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x

            # update y
            g = h = fit
            h = (h * res).abs_()
            h.addcmul_(g, g)
            g *= res
            g += regy(logy, w)
            logy -= solvey(h, g, w)

            # update fit / ll
            y = torch.exp(logy, out=y)
            fit = y * b
            res = fit - x
            lly = (regy(logy, w).mul_(logy)).sum()

            # compute objective
            llx = res.square().sum()
            ll = llx + lly + llb + llw
            gain = (ll1 - ll) / ll0
            ll1 = ll
            if verbose:
                end = '\n' if verbose > 1 else '\r'
                pre = f'{it_ls:3d} | ' if l1 else ''
                print(pre + f'{it:3d} | {ll:12.6g} | gain = {gain:12.6g}',
                      end=end)
            if it > 0 and abs(gain) < tol:
                break

        if l1:
            w, llw = spatial.membrane_weights(logy[None],
                                              lam,
                                              dim=dim,
                                              return_sum=True)
            ll0 = ll
    if verbose:
        print('')

    if downsample:
        b = spatial.resize(logb.to(x0.device)[None, None],
                           downsample,
                           shape=x0.shape,
                           anchor='f')[0, 0].exp_()
        y = spatial.resize(logy.to(x0.device)[None, None],
                           downsample,
                           shape=x0.shape,
                           anchor='f')[0, 0].exp_()
        x = x0
    else:
        y = torch.exp(logy, out=y)
    x = x / b
    return y, b, x
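
As the docstring says, the image is modelled as f = exp(s + b) + eps and the routine alternates Newton updates of the log-signal s and the log-bias b, optionally on a downsampled copy of the volume. A minimal call sketch with a random placeholder volume (all values are arbitrary):

import torch

vol = 1000 * torch.rand(128, 128, 64)   # placeholder SPIM stack, z last
y, bias, corrected = correct_smooth(vol, downsample=2, max_rls=4, verbose=True)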