Ejemplo n.º 1
0
def _get_image_data(pths, device=None, dtype=None):
    """Formats image data (on disk) to tensor compatible with the denoise_mri function

    OBS: Assumes that all input images are 3D volumes with the same dimensions.

    Parameters
    ----------
    pths : [nchannels] sequence
        Paths to image data
    device : torch.device, optional
        Torch device
    dtype : torch.dtype, optional
        Torch data type

    Returns
    ----------
    dat : (dmx, dmy, dmz, nchannels) tensor
        Image tensor

    """
    for i, p in enumerate(pths):
        # read nifti
        nii = map(p)
        # get data
        if i == 0:
            dat = nii.fdata(dtype=dtype, device=device, rand=False)[..., None]
        else:
            dat = torch.cat(
                (dat, nii.fdata(dtype=dtype, device=device, rand=False)[...,
                                                                        None]),
                dim=-1)

    return dat
Ejemplo n.º 2
0
    def _map_image(self):
        image = self.image
        self._map = None
        self._fdata = None
        self._affine = None
        self._shape = None
        if isinstance(image, str):
            self._map = io.map(image)
        else:
            self._map = None

        if self._map is None:
            if not isinstance(image, (list, tuple)):
                image = [image]
            if len(image) < 2:
                image = [*image, None, None]
            dat, aff, *_ = image
            dat = torch.as_tensor(dat)
            if aff is None:
                aff = spatial.affine_default(dat.shape[-3:])
            self._fdata = dat
            self._affine = aff
            self._shape = tuple(dat.shape[-3:])
        else:
            self._affine = self._map.affine
            self._shape = tuple(self._map.shape[-3:])
Ejemplo n.º 3
0
def get_spm_prior(**backend):
    """Load the SPM tissue prior, without its last (background) class.

    Parameters
    ----------
    **backend : dtype, device
        Forwarded to `fdata` when loading the prior.

    Returns
    -------
    dat : tensor
        Prior, with the class axis moved to the front.
    aff : tensor
        Orientation matrix, moved to the device/dtype of `dat`.
    """
    # Class axis is last on disk: move it first, then discard the
    # final class (the background).
    tpm = io.map(path_spm_prior()).movedim(-1, 0)[:-1]
    orientation = tpm.affine
    prior = tpm.fdata(**backend)
    return prior, orientation.to(**utils.backend(prior))
Ejemplo n.º 4
0
def _map_image(fnames, dim=None):
    """
    Load a N-D image from disk.
    Returns:
        image : (C, *spatial) MappedTensor
        affine: (D+1, D+1) tensor
    """
    affine = None
    imgs = []
    for fname in fnames:
        img = io.map(fname)
        if affine is None:
            affine = img.affine
        if dim is None:
            dim = img.affine.shape[-1] - 1
        # img = img.fdata(rand=True, device=device)
        if img.dim > dim:
            img = img.movedim(-1, 0)
        else:
            img = img[None]
        img = img.unsqueeze(-1, dim + 1 - img.dim)
        if img.dim > dim + 1:
            raise ValueError(f'Don\'t know how to deal with an image of '
                             f'shape {tuple(img.shape)}')
        imgs.append(img)
        del img
    imgs = io.cat(imgs, dim=0)
    return imgs, affine
Ejemplo n.º 5
0
def _format_input(img, device='cpu', rand=False, cutoff=None):
    """Format preprocessing input data.
    """
    if isinstance(img, str):
        img = [img]
    if isinstance(img, list) and isinstance(img[0], torch.Tensor):
        img = [img]
    file = []
    dat = []
    mat = []
    for n in range(len(img)):
        if isinstance(img[n], str):
            # Input are nitorch.io compatible paths
            file.append(map(img[n]))
            dat.append(file[n].fdata(dtype=torch.float32,
                                     device=device,
                                     rand=rand,
                                     cutoff=cutoff))
            mat.append(file[n].affine.to(dtype=torch.float64, device=device))
        else:
            # Input are tensors (clone so to not modify input data)
            file.append(None)
            dat.append(img[n][0].clone().to(dtype=torch.float32,
                                            device=device))
            mat.append(img[n][1].clone().to(dtype=torch.float64,
                                            device=device))

    return dat, mat, file
Ejemplo n.º 6
0
 def set_dat(self, dat, affine=None, **backend):
     """Replace the stored displacement data.

     Parameters
     ----------
     dat : str or SpatialTensor or other data
         Path to a file, or data passed to `Displacement.make`.
     affine : tensor, optional
         Orientation matrix. If not provided, it is taken from the mapped
         file (when `dat` is a path) or from `dat` (when it is a
         `SpatialTensor`).
     **backend : dtype, device
         Forwarded to `Displacement.make`.

     Returns
     -------
     self
     """
     if isinstance(dat, str):
         dat = io.map(dat)
         if affine is None:
             # Use the orientation of the file being loaded, not that of
             # the data currently stored in `self.dat`.
             affine = dat.affine
     if isinstance(dat, SpatialTensor) and affine is None:
         affine = dat.affine
     self.dat = Displacement.make(dat, affine=affine, **backend)
     return self
Ejemplo n.º 7
0
def _bb_atlas(name, fov, dtype=torch.float64, device='cpu'):
    """Bounding-box NITorch atlas data to specific field-of-view.

    Parameters
    ----------
    name : str
        Name of nitorch data, available are:
        * atlas_t1: MRI T1w intensity atlas, 1 mm resolution.
        * atlas_t2: MRI T2w intensity atlas, 1 mm resolution.
        * atlas_pd: MRI PDw intensity atlas, 1 mm resolution.
        * atlas_t1_mni: MRI T1w intensity atlas, in MNI space, 1 mm resolution.
        * atlas_t2_mni: MRI T2w intensity atlas, in MNI space, 1 mm resolution.
        * atlas_pd_mni: MRI PDw intensity atlas, in MNI space, 1 mm resolution.
    fov : str
        Field-of-view, specific to 'name':
        * 'atlas_t1' | 'atlas_t2' | 'atlas_pd':
            * 'brain': Brain FOV.
            * 'head': Head FOV (larger than 'brain' in the third axis).
        For other atlases or other values, the full atlas extent is kept.
    dtype : torch.dtype, default=torch.float64
        Currently unused; outputs are always float64.
    device : torch.device or str, default='cpu'
        Device of the returned tensors.

    Returns
    ----------
    mat_mu : (4, 4) tensor, dtype=float64
        Output affine matrix.
    dim_mu : (3, ) tensor, dtype=float64
        Output dimensions.

    """
    # Get atlas information
    file_mu = map(fetch_data(name))
    dim_mu = file_mu.shape
    mat_mu = file_mu.affine.type(torch.float64).to(device)
    # Number of voxels cropped at the start (lo) and end (hi) of each axis
    lo = [0, 0, 0]
    hi = [0, 0, 0]
    if name in ('atlas_t1', 'atlas_t2', 'atlas_pd') and fov in ('brain', 'head'):
        lo = [18, 52, 25 if fov == 'head' else 120]
        hi = [18, 48, 58]
    # Bounding box, in 1-based voxel coordinates: first row = start,
    # second row = end (inclusive).
    bb = torch.tensor([[1 + lo[d] for d in range(3)],
                       [dim_mu[d] - hi[d] for d in range(3)]],
                      dtype=torch.float64, device=device)
    # Output dimensions
    dim_mu = bb[1] - bb[0] + 1
    # Bounding-box atlas affine (shift by the 0-based start of the box)
    mat_bb = affine_matrix_classic(bb[0] - 1)
    # Modulate atlas affine with bb affine
    mat_mu = mat_mu.mm(mat_bb)

    return mat_mu, dim_mu
Ejemplo n.º 8
0
def get_spm_prior(**backend):
    """Load the SPM12 tissue probability maps, downloading them if needed.

    The file is cached as `SPM12_TPM.nii` in `cache_dir`.

    Parameters
    ----------
    **backend : dtype, device
        Forwarded to `fdata` when loading the maps.

    Returns
    -------
    dat : tensor
        Tissue maps with the class axis first. All classes are kept,
        including the last one (background).
    aff : tensor
        Orientation matrix, moved to the device/dtype of `dat`.
    """
    source_url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
    local_path = os.path.join(cache_dir, 'SPM12_TPM.nii')
    if not os.path.exists(local_path):
        os.makedirs(cache_dir, exist_ok=True)
        local_path = download(source_url, local_path)
    # NOTE: the background class is *not* dropped here (no `[:-1]`).
    tpm = io.map(local_path).movedim(-1, 0)
    orientation = tpm.affine
    maps = tpm.fdata(**backend)
    return maps, orientation.to(**utils.backend(maps))
Ejemplo n.º 9
0
    def forward(self, x, affine=None):
        """Segment an image.

        Parameters
        ----------
        x : (X, Y, Z) tensor or str or io.MappedArray
            Input image, or a path to / mapped view of an image file.
        affine : (4, 4) tensor, optional
            Orientation matrix. If not provided and `x` is a file, it is
            read from the file.

        Returns
        -------
        seg : (oX, oY, oZ) tensor
            Segmentation: arg-max over the network's output classes,
            passed through `self.relabel`.
        resliced : (oX, oY, oZ) tensor
            Input resliced to 1 mm RAS (only when an affine is available;
            otherwise the input grid is kept).
        affine : (4, 4) tensor or None
            Output orientation matrix

        """
        if self.verbose:
            print('Preprocessing... ', end='', flush=True)
        if isinstance(x, str):
            x = io.map(x)
        if isinstance(x, io.MappedArray):
            if affine is None:
                affine = x.affine
            # Load the data whether or not an affine was provided.
            x = x.fdata()
            x = x.reshape(x.shape[:3])
        x = SynthPreproc.addnoise(x)
        if affine is not None:
            affine, x = spatial.affine_reorient(affine, x, 'RAS')
            vx = spatial.voxel_size(affine)
            # Light smoothing, only along axes with sub-millimetre voxels
            fwhm = 0.25 * vx.reciprocal()
            fwhm[vx > 1] = 0
            x = spatial.smooth(x, fwhm=fwhm.tolist(), dim=3)
            x, affine = spatial.resize(x[None, None],
                                       vx.tolist(),
                                       affine=affine)
            x = x[0, 0]
        oshape = x.shape
        x, crop = SynthPreproc.crop(x)
        x = SynthPreproc.preproc(x)[None, None]
        if self.verbose:
            print('done.', flush=True)
            print('Segmenting... ', end='', flush=True)
        s, x = super().forward(x)[0], x[0, 0]
        if self.verbose:
            print('done.', flush=True)
            print('Postprocessing... ', end='', flush=True)
        s = self.relabel(s.argmax(0))
        x = SynthPreproc.pad(x, oshape, crop)
        s = SynthPreproc.pad(s, oshape, crop)
        if self.verbose:
            print('done.', flush=True)
        return s, x, affine
Ejemplo n.º 10
0
def get_data(x, w, affine, dim, **backend):
    """Load an image (and optional weights) as a channel-first tensor.

    Parameters
    ----------
    x : tensor or str or sequence of str
        Image data, a path to an image, or a list of paths (one per channel).
    w : tensor or str, optional
        Weight map, or a path to one.
    affine : tensor, optional
        Orientation matrix. Read from the file(s) if not provided.
    dim : int
        Number of spatial dimensions.
    **backend : dtype, device
        Forwarded to the file loaders (in-memory tensors are not converted).

    Returns
    -------
    x : (C, *spatial) tensor, contiguous
    w : tensor or None, contiguous
    affine : tensor or None

    Raises
    ------
    ValueError
        If `x` or `w` has too many dimensions.
    """
    if not torch.is_tensor(x):
        if isinstance(x, str):
            f = io.map(x)
            if affine is None:
                affine = f.affine
            if f.dim > dim:
                if f.shape[dim] == 1:
                    f = f.squeeze(dim)
                if f.dim > dim + 1:
                    raise ValueError('Too many dimensions')
            if f.dim > dim:
                f = f.movedim(-1, 0)
            else:
                f = f[None]
            x = f.fdata(**backend, rand=True, missing=0)
        else:
            f = io.stack([io.map(x1) for x1 in x])
            if affine is None:
                affine = f.affine[0]
            x = f.fdata(**backend, rand=True, missing=0)

    # Drop a trailing singleton dimension (e.g. (C, *spatial, 1)).
    if x.dim() > dim + 1:
        x = x.squeeze(-1)
    if x.dim() > dim + 1:
        raise ValueError('Too many dimensions')
    if x.dim() == dim:
        x = x[None]

    if not torch.is_tensor(w) and w is not None:
        w = io.loadf(w, **backend)
        # Checks apply to the weights themselves, not to `x`.
        if w.dim() > dim:
            w = w.squeeze(-1)
        if w.dim() > dim:
            raise ValueError('Too many dimensions')

    x = x.contiguous()
    if w is not None:
        w = w.contiguous()
    return x, w, affine
Ejemplo n.º 11
0
def get_prior(prior, affine_prior, **backend):
    """Resolve a tissue prior and its orientation matrix.

    Parameters
    ----------
    prior : None or str or tensor
        None -> the SPM prior (`get_spm_prior`); str -> path to a file
        whose last axis holds the classes; tensor -> used as is.
    affine_prior : tensor, optional
        Orientation matrix. If None, it is taken from the SPM prior or
        from the file; for tensor priors it is returned unchanged.
    **backend : dtype, device
        Backend of the returned prior.

    Returns
    -------
    prior : tensor
    affine_prior : tensor or None
    """
    if isinstance(prior, str):
        mapped = io.map(prior).movedim(-1, 0)
        if affine_prior is None:
            affine_prior = mapped.affine
        return mapped.fdata(**backend), affine_prior
    if prior is None:
        spm_prior, spm_affine = get_spm_prior(**backend)
        return spm_prior, (spm_affine if affine_prior is None else affine_prior)
    return prior.to(**backend), affine_prior
Ejemplo n.º 12
0
def _read_label(x, pth, sett):
    """Read a label volume from disk and attach it to an input struct.

    Args:
        x: Input struct; its `dim` must match the label volume's shape.
        pth: Path to the label image.
        sett: Settings; `sett.device` is the device the labels are loaded to.

    Returns:
        x, with `x.label = [labels (float32 tensor), mapped file]`.

    Raises:
        ValueError: If the label shape differs from `x.dim`.
    """
    label_file = map(pth)
    labels = label_file.fdata(dtype=torch.float32, device=sett.device)
    shapes_match = torch.equal(torch.as_tensor(x.dim),
                               torch.as_tensor(labels.shape))
    if not shapes_match:
        raise ValueError('Incorrect label dimensions.')
    x.label = [labels, label_file]
    return x
Ejemplo n.º 13
0
 def _init_from_fname(new,
                      fnames,
                      permission='r',
                      keep_open=False,
                      **attributes):
     """Initialise `new` from one or several image files.

     Each file is mapped and padded with trailing singleton dimensions up
     to 4D. The files are concatenated along the 4th axis, which is then
     moved to the front before calling `new._init_from_mapped`.
     """
     volumes = []
     for path in py.make_list(fnames):
         mapped = io.map(path, permission=permission, keep_open=keep_open)
         while mapped.dim < 4:
             mapped = mapped.unsqueeze(-1)
         volumes.append(mapped)
     stacked = io.cat(volumes, -1)
     new._init_from_mapped(stacked.permute([-1, 0, 1, 2]), **attributes)
Ejemplo n.º 14
0
    def from_fname(cls, fname, permission='r', keep_open=False, **attributes):
        """Build an MRI object from a file name.

        We accept paths of the form 'path/to/file.nii,1,2', which
        mean that only the subvolume `[:, :, :, 1, 2]` should be read.
        The first three (spatial) dimensions are always read.
        """
        path, *extra = str(fname).split(',')
        mapped = io.map(path, permission=permission, keep_open=keep_open)
        if extra:
            # Keep the three spatial axes whole; index the remaining ones.
            subvolume = (slice(None),) * 3 + tuple(int(k) for k in extra)
            mapped = mapped[subvolume]
        return cls.from_mapped(mapped, **attributes)
Ejemplo n.º 15
0
 def prepare_one(inp):
     """Load one input as a ``[data, affine]`` list.

     Parameters
     ----------
     inp : str or io.MappedArray or tensor-like or sequence
         A path, a mapped file, in-memory data, or a sequence whose first
         element is any of those (possibly nested) and whose optional
         second element is an affine overriding the one read from file or
         built by default. A ``None`` affine counts as "not provided".

     Returns
     -------
     [dat, aff] : list
         Loaded data and affine. Affines read from file or built by
         default get a leading singleton dimension; a user-provided
         affine is returned as is.
     """
     def _load(obj):
         # Return (data, affine or None) without building a default affine.
         if isinstance(obj, (list, tuple)):
             dat, aff = _load(obj[0])
             if len(obj) > 1 and obj[1] is not None:
                 aff = obj[1]
             return dat, aff
         if isinstance(obj, str):
             obj = io.map(obj)[None, None]
         if isinstance(obj, io.MappedArray):
             return obj.fdata(rand=True), obj.affine[None]
         return torch.as_tensor(obj), None

     dat, aff = _load(inp)
     if aff is None:
         # Only build a default affine when none is available.
         aff = spatial.affine_default(dat.shape)[None]
     return [dat, aff]
Ejemplo n.º 16
0
 def load(self, x, affine=None):
     """Load an image and reslice it to RAS with its native voxel size.

     Parameters
     ----------
     x : (X, Y, Z) tensor or str or io.MappedArray
         Image data, or a path to / mapped view of an image file.
     affine : tensor, optional
         Orientation matrix. If not provided and `x` is a file, it is read
         from the file.

     Returns
     -------
     x : tensor
         Image, reoriented to RAS and resized (only if an affine is
         available).
     affine : tensor or None
         Orientation matrix of the returned image.
     x_original : torch.Size
         Shape of the image before reorientation/resizing.
     affine_original : tensor or None
         Orientation matrix before reorientation/resizing.
     """
     if isinstance(x, str):
         x = io.map(x)
     if isinstance(x, io.MappedArray):
         if affine is None:
             affine = x.affine
         # Load the data whether or not an affine was provided.
         x = x.fdata()
         x = x.reshape(x.shape[:3])
     affine_original = affine
     x_original = x.shape
     if affine is not None:
         affine, x = spatial.affine_reorient(affine, x, 'RAS')
         vx = spatial.voxel_size(affine)
         x, affine = spatial.resize(x[None, None],
                                    vx.tolist(),
                                    affine=affine)
         x = x[0, 0]
     return x, affine, x_original, affine_original
Ejemplo n.º 17
0
    def do_apply(fnames, phi, jac):
        """Correct files with a given polarity.

        Each file is loaded on `device`, deformed along the `readout` axis
        with `_deform1d(..., phi)`, multiplied by `jac` (if given) and saved
        to the path built from `options.output`.
        """
        for fname in fnames:
            folder, base, ext = py.fileparts(fname)
            ofname = options.output.format(dir=folder or '.', sep=os.sep,
                                           base=base, ext=ext)
            if options.verbose:
                print(f'unwarp {fname} \n'
                      f'    -> {ofname}')

            # Work with the readout axis last, then restore it.
            volume = io.map(fname).fdata(device=device)
            volume = _deform1d(utils.movedim(volume, readout, -1), phi)
            if jac is not None:
                volume *= jac
            io.savef(utils.movedim(volume, -1, readout), ofname, like=fname)
Ejemplo n.º 18
0
 def __init__(self, dat, affine=None, dim=None, **backend):
     """
     Parameters
     ----------
     dat : str or io.MappedArray or ([C], *spatial) tensor
         Image data, or a path to / mapped view of an image file. Paths
         get a leading singleton (channel) dimension.
     affine : tensor, optional
         Orientation matrix. Read from the file when `dat` is mapped;
         otherwise a default one is built from `self.shape`.
     dim : int, default=`dat.dim() - 1`
     **backend : dtype, device
         Backend of the data: used when loading files, and applied to
         in-memory tensors as well.
     """
     if isinstance(dat, str):
         dat = io.map(dat)[None]
     if isinstance(dat, io.MappedArray):
         if affine is None:
             affine = dat.affine
         dat = dat.fdata(rand=True, **backend)
     elif backend:
         # Honour the requested backend for in-memory tensors too.
         dat = dat.to(**backend)
     self.dim = dim or dat.dim() - 1
     self.dat = dat
     if affine is None:
         affine = spatial.affine_default(self.shape, **utils.backend(dat))
     self.affine = affine.to(utils.backend(self.dat)['device'])
Ejemplo n.º 19
0
def _cli(args):
    """Command-line interface for `smooth` without exception handling.

    The last FWHM value may be a unit string; 'mm' (the default) means
    the FWHM is converted to voxels using each file's voxel size.
    """
    args = args or sys.argv[1:]

    options = parser(args)
    if options.help:
        print(help)
        return

    fwhm = options.fwhm
    unit = 'mm'
    if isinstance(fwhm[-1], str):
        *fwhm, unit = fwhm
    fwhm = make_list(fwhm, 3)

    options.output = make_list(options.output, len(options.files))
    for fname, ofname in zip(options.files, options.output):
        mapped = io.map(fname)
        vx = voxel_size(mapped.affine).tolist()
        ndim = len(vx)
        if unit != 'mm':
            fwhm_vox = fwhm[:ndim]
        else:
            # mm -> voxels
            fwhm_vox = [width / size for width, size in zip(fwhm, vx)]

        # Spatial axes go last for `smooth`, then are moved back.
        volume = movedim_front2back(mapped.fdata(), ndim)
        volume = smooth(volume,
                        type=options.method,
                        fwhm=fwhm_vox,
                        basis=options.basis,
                        bound=options.padding,
                        dim=ndim)
        volume = movedim_back2front(volume, ndim)

        folder, base, ext = fileparts(fname)
        ofname = ofname.format(dir=folder or '.',
                               base=base,
                               ext=ext,
                               sep=os.path.sep)
        io.savef(volume, ofname, like=mapped)
Ejemplo n.º 20
0
def _get_image_data(pths, device=None, dtype=None):
    """Formats image data (on disk) to tensor compatible with the denoise_mri function

    OBS: Assumes that all input images are 3D volumes with the same dimensions.

    Parameters
    ----------
    pths : [nchannels] sequence
        Paths to image data
    device : torch.device, optional
        Torch device
    dtype : torch.dtype, optional
        Torch data type

    Returns
    ----------
    dat : (nchannels, dmx, dmy, dmz) tensor
        Image tensor
    affine : (4, 4) tensor
        Affine matrix (assumed the same across channels)
    nii : nitorch.BabelArray
        BabelArray object.

    """
    for i, p in enumerate(pths):
        # read nifti
        nii = map(p)
        # get data
        if i == 0:
            dat = nii.fdata(dtype=dtype, device=device, rand=False)[None]
            affine = nii.affine.type(dat.dtype).to(dat.device)
            shape = nii.shape
        else:
            if not torch.equal(torch.as_tensor(shape),
                               torch.as_tensor(nii.shape)):
                raise ValueError('All images must have same dimensions!')
            dat = torch.cat(
                (dat, nii.fdata(dtype=dtype, device=device, rand=False)[None]),
                dim=0)

    return dat, affine, nii
Ejemplo n.º 21
0
def estimate_noise(dat,
                   show_fit=False,
                   fig_num=1,
                   num_class=2,
                   mu_noise=None,
                   max_iter=10000,
                   verbose=0,
                   bins=1024,
                   chi=False):
    """Estimate the noise distribution in an image by fitting either a
    Gaussian, Rician or noncentral Chi mixture model to the image's
    intensity histogram.

    The Gaussian model is only used if negative values are found in the
    image (e.g., if it is a CT scan).

    Parameters
    ----------
    dat : str or io.MappedArray or tensor
        Tensor or path to nifti file.
    show_fit : bool, default=False
        Show a plot of the histogram fit at the end
    fig_num : int, default=1
        ID of matplotlib figure to use
    num_class : int, default=2
        Number of mixture classes (only for GMM).
    mu_noise : float, optional
        Mean of noise class. If provided, the class with mean value closest
        to `mu_noise` is assumed to be the background class. Otherwise, it is
        the class with smallest standard deviation.
    max_iter : int, default=10000
        Maximum number of EM iterations.
    verbose : int, default=0
        Display progress.
            * 0: None.
            * 1: Print summary when finished.
            * 2: 1 + Log-likelihood plot.
            * 3: 1 + 2 + print convergence.
    bins : int, default=1024
        Number of histogram bins.
    chi : bool, default=False
        Fit a noncentral Chi rather than a Rice model.

    Returns
    -------
    prm_noise : dict
        Parameters of the distribution of the background (noise) class
        With fields 'sd', 'mean' and (if `chi`) 'dof'
    prm_not_noise : dict
        Parameters of the distribution of the foreground (tissue) class
        With fields 'sd', 'mean' and (if `chi`) 'dof'

    """
    DTYPE = torch.double  # use double for accuracy (maybe single would work?)

    slope = None
    if isinstance(dat, str):
        dat = io.map(dat)
    if isinstance(dat, io.MappedArray):
        slope = dat.slope
        if not slope and not dtype_info(dat.dtype).is_floating_point:
            slope = 1
        dat = dat.fdata(rand=True, missing=0, dtype=DTYPE)
    dat = torch.as_tensor(dat)
    # Integer-valued inputs: align histogram bins with integers. This must
    # be checked *before* the conversion to floating point.
    if not slope and not dat.dtype.is_floating_point:
        slope = 1
    dat = dat.to(DTYPE).flatten()
    device = dat.device

    # exclude missing values
    dat = dat[torch.isfinite(dat)]
    dat = dat[dat != 0]

    # Mask and get min/max
    mn = dat.min()
    mx = dat.max()
    dat = dat[dat != mn]
    dat = dat[dat != mx]
    mn = mn.round()
    mx = mx.round()
    if slope:
        # ensure bin width aligns with integer width
        width = (mx - mn) / bins
        width = (width / slope).ceil() * slope
        mx = mn + bins * width

    # Histogram bin data
    dat = torch.histc(dat, bins=bins, min=mn, max=mx).to(DTYPE)
    x = torch.linspace(mn, mx, steps=bins, device=device, dtype=DTYPE)

    # fit mixture model
    if mn < 0:  # Make GMM model
        model = GMM(num_class=num_class)
    elif chi:
        model = CMM(num_class=num_class)
    else:  # Make RMM model
        model = RMM(num_class=num_class)

    # Fit GMM/RMM/CMM to the histogram
    model.fit(x,
              W=dat,
              verbose=verbose,
              max_iter=max_iter,
              show_fit=show_fit,
              fig_num=fig_num)

    # Get means and mixing proportions
    mu, _ = model.get_means_variances()
    mu = mu.squeeze()
    mp = model.mp
    if mn < 0:  # GMM
        sd = torch.sqrt(model.Cov).squeeze()
    else:  # RMM/CMM
        sd = model.sig.squeeze()

    # Get std and mean of noise class
    if mu_noise is not None:
        # Closest to mu_noise (a value of 0 is a valid request)
        _, ix_noise = torch.min(torch.abs(mu - mu_noise), dim=0)
    else:
        # With smallest sd
        _, ix_noise = torch.min(sd, dim=0)
    mu_noise = mu[ix_noise]
    sd_noise = sd[ix_noise]
    if chi:
        dof_noise = model.dof[ix_noise]

    # Get std and mean of other classes (means and sds weighted by mps)
    rng = list(range(num_class))
    del rng[ix_noise]
    mu = mu[rng]
    sd = sd[rng]
    w = mp[rng]
    w = w / torch.sum(w)
    mu_not_noise = sum(w * mu)
    sd_not_noise = sum(w * sd)
    if chi:
        dof = model.dof[rng]
        dof_not_noise = sum(w * dof)

    # return dictionaries of parameters
    prm_noise = dict(sd=sd_noise, mean=mu_noise)
    prm_not_noise = dict(sd=sd_not_noise, mean=mu_not_noise)
    if chi:
        prm_noise['dof'] = dof_noise
        prm_not_noise['dof'] = dof_not_noise
    return prm_noise, prm_not_noise
Ejemplo n.º 22
0
    def __init__(self, dat, levels=1, affine=None, dim=None,
                 mask=None, preview=None,
                 bound='dct2', extrapolate=False, method='gauss', **backend):
        """Build a multi-resolution pyramid of images.

        Parameters
        ----------
        dat : [list of] (..., *shape) tensor or Image or str or io.MappedArray
            Input data. A single image is turned into a pyramid; a list
            is used as is, one element per level.
        levels : int or list[int] or range, default=1
            If an int, it is the number of levels.
            If a range or list, they are the indices of levels to compute.
            `0` is the native resolution, `1` is half of it, etc.
        affine : [list of] tensor, optional
        dim : int, optional
            Number of spatial dimensions.
        mask : [list of] tensor, optional
        preview : [list of] tensor, optional
        bound : str, default='dct2'
        extrapolate : bool, default=False
        method : {'gauss', 'average', 'median', 'stride'}, default='gauss'
        **backend : dtype, device
            Only used when loading data from a file.
        """
        # I don't call super().__init__() on purpose
        self.method = method

        if isinstance(dat, Image):
            if affine is None:
                affine = dat.affine
            dim = dat.dim
            mask = dat.mask
            preview = dat._preview
            bound = dat.bound
            extrapolate = dat.extrapolate
            dat = dat.dat
        if isinstance(dat, str):
            dat = io.map(dat)
        if isinstance(dat, io.MappedArray):
            if affine is None:
                affine = dat.affine
            dat = dat.fdata(rand=True, **backend)[None]

        # Normalise `levels` to a list of level indices.
        if isinstance(levels, int):
            levels = range(levels)
        if isinstance(levels, range):
            levels = list(levels)

        if not isinstance(dat, (list, tuple)):
            dim = dim or dat.dim() - 1
            if not levels:
                raise ValueError('levels required to compute pyramid')
            if affine is None:
                shape = dat.shape[-dim:]
                affine = spatial.affine_default(shape, **utils.backend(dat[0]))
            dat, mask, preview = self._build_pyramid(
                dat, levels, method, dim, bound, mask, preview)
        dat = list(dat)
        if not mask:
            mask = [None] * len(dat)
        if not preview:
            preview = [None] * len(dat)

        if all(isinstance(d, Image) for d in dat):
            self._dat = dat
            return

        dim = dim or (dat[0].dim() - 1)
        if affine is None:
            shape = dat[0].shape[-dim:]
            affine = spatial.affine_default(shape, **utils.backend(dat[0]))
        if not isinstance(affine, (list, tuple)):
            if not levels:
                raise ValueError('levels required to compute affine pyramid')
            shape = dat[0].shape[-dim:]
            affine = self._build_affine_pyramid(affine, shape, levels, method)
        self._dat = [Image(d, aff, dim=dim,  mask=m, preview=p,
                           bound=bound, extrapolate=extrapolate)
                     for d, m, p, aff in zip(dat, mask, preview, affine)]
Ejemplo n.º 23
0
def _read_data(data, sett):
    """ Parse input data into algorithm input struct(s).

    Args:
        data: One of
            * a path (str) to a 3D image, or to a 4D+ image whose 4th axis
              holds the channels;
            * an array-like (anything with a ``shape``) of shape (X, Y, Z)
              or (X, Y, Z, C) -- requires ``sett.mat``;
            * a list with one entry per channel, each entry being an input
              accepted by ``_read_image``, or a list of such inputs
              (several repeats of the same channel).
        sett: Settings. Fields read here: mat, device, ct, label.

    Returns:
        x (list[list[_input()]]): Algorithm input struct(s), indexed as
        x[channel][repeat].

    Raises:
        ValueError: If array data is given without an affine in sett.mat.

    """
    fields = ('dat', 'dim', 'mat', 'fname', 'direc', 'nam', 'file', 'ct')

    def _make_input(obj):
        # Read one image with `_read_image` and store its outputs in a
        # new `_input()` struct.
        inp = _input()
        values = _read_image(obj, sett.device, could_be_ct=sett.ct)
        for field, value in zip(fields, values):
            setattr(inp, field, value)
        return inp

    mat_vol = sett.mat
    if isinstance(data, str):
        nii = map(data)
        if len(nii.shape) > 3:
            # Data is path to 4D nifti
            data = nii.fdata()
            mat_vol = nii.affine

    is_array = hasattr(data, 'shape')
    if is_array:
        if mat_vol is None:
            raise ValueError('Image data given as array, please also provide affine matrix in sett.mat!')
        if len(data.shape) == 3:
            # Single-channel volume: add a channel axis
            data = data[..., None]
    elif isinstance(data, str):
        data = [data]

    # Number of channels
    n_channels = data.shape[3] if is_array else len(data)

    x = []
    for c in range(n_channels):  # Loop over channels
        if is_array:
            # One repeat per channel, sliced from the array
            x.append([_make_input([data[..., c], mat_vol])])
        elif (isinstance(data[c], list) and data[c]
              and isinstance(data[c][0], (str, list))):
            # Possibly multiple repeats per channel
            x.append([_make_input(obs) for obs in data[c]])
        else:
            # One repeat per channel
            x.append([_make_input(data[c])])

    # Add labels (if given) to the requested (channel, repeat)
    if sett.label is not None:
        pth_label = sett.label[0]
        ix_cr = sett.label[1]  # Index channel and repeat
        c, n = ix_cr[0], ix_cr[1]
        if 0 <= c < len(x) and 0 <= n < len(x[c]):
            x[c][n] = _read_label(x[c][n], pth_label, sett)

    # Print to screen
    _print_info('filenames', sett, x)

    return x
Ejemplo n.º 24
0
def _main(options):
    """Run ESTATICS from parsed command-line options and write the results.

    Outputs are written next to the inputs, or in `options.odir` if set:
    * `<base>_TE0` extrapolated intercept for each contrast;
    * `R2star` decay map;
    * when `estatics` returns fieldmaps: `<base>_B0` for each contrast and
      `<base>_unwrapped` for each (distortion-corrected) echo;
    * when `options.register`: `<base>_registered` for each echo.
    """
    if isinstance(options.gpu, str):
        device = torch.device(options.gpu)
    else:
        assert isinstance(options.gpu, int)
        device = torch.device(f'cuda:{options.gpu}')
    if not torch.cuda.is_available():
        device = 'cpu'

    # prepare options
    estatics_opt = ESTATICSOptions()
    estatics_opt.likelihood = options.likelihood
    estatics_opt.verbose = options.verbose >= 1
    estatics_opt.plot = options.verbose >= 2
    estatics_opt.recon.space = options.space
    if isinstance(options.space, str) and options.space != 'mean':
        # A contrast name: convert it to the contrast index
        for c, contrast in enumerate(options.contrast):
            if contrast.name == options.space:
                estatics_opt.recon.space = c
                break
    estatics_opt.backend.device = device
    estatics_opt.optim.nb_levels = options.levels
    estatics_opt.optim.max_iter_rls = options.iter
    estatics_opt.optim.tolerance = options.tol
    estatics_opt.regularization.norm = options.regularization
    estatics_opt.regularization.factor = [*options.lam_intercept, options.lam_decay]
    estatics_opt.distortion.enable = options.meetup
    estatics_opt.distortion.bending = options.lam_meetup
    estatics_opt.preproc.register = options.register

    # prepare files
    contrasts = []
    distortion = []
    for i, c in enumerate(options.contrast):

        # read meta-parameters
        meta = {}
        if c.te:
            te, unit = c.te, ''
            if isinstance(te[-1], str):
                *te, unit = te
            if unit:
                if unit == 'ms':
                    te = [t * 1e-3 for t in te]
                elif unit not in ('s', 'sec'):
                    raise ValueError(f'TE unit: {unit}')
            if c.echo_spacing:
                delta, *unit = c.echo_spacing
                unit = unit[0] if unit else ''
                if unit == 'ms':
                    delta = delta * 1e-3
                elif unit not in ('s', 'sec'):
                    raise ValueError(f'echo spacing unit: {unit}')
                # Total number of echoes across all files of this contrast
                ne = sum(io.map(f).unsqueeze(-1).shape[3] for f in c.echoes)
                te = [te[0] + e*delta for e in range(ne)]
            meta['te'] = te

        # map volumes
        contrasts.append(qio.GradientEchoMulti.from_fname(c.echoes, **meta))

        if c.readout:
            layout = spatial.affine_to_layout(contrasts[-1].affine)
            layout = spatial.volume_layout_to_name(layout)
            readout = None
            for j, l in enumerate(layout):
                if l.lower() in c.readout.lower():
                    readout = j - 3
            contrasts[-1].readout = readout

        if c.b0:
            bw = c.bandwidth
            b0, *unit = c.b0
            unit = unit[-1] if unit else 'vx'
            # `b0` is the fieldmap file name: map it with the IO module.
            fb0 = io.map(b0)
            b0 = fb0.fdata(device=device)
            b0 = spatial.reslice(b0, fb0.affine, contrasts[-1][0].affine,
                                 contrasts[-1][0].shape)
            if unit.lower() == 'hz':
                if not bw:
                    raise ValueError('Bandwidth required to convert fieldmap '
                                     'from Hz to voxel')
                b0 /= bw
            b0 = DenseDistortion(b0)
            distortion.append(b0)
        else:
            distortion.append(None)

    # run algorithm
    [te0, r2s, *b0] = estatics(contrasts, distortion, opt=estatics_opt)

    # write results

    # --- intercepts ---
    odir0 = options.odir
    for i, te1 in enumerate(te0):
        ifname = contrasts[i].echo(0).volume.fname
        odir, obase, oext = py.fileparts(ifname)
        odir = odir0 or odir
        obase = obase + '_TE0'
        ofname = os.path.join(odir, obase + oext)
        io.savef(te1.volume, ofname, affine=te1.affine, like=ifname, te=0, dtype='float32')

    # --- decay ---
    ifname = contrasts[0].echo(0).volume.fname
    odir, obase, oext = py.fileparts(ifname)
    odir = odir0 or odir
    io.savef(r2s.volume, os.path.join(odir, 'R2star' + oext), affine=r2s.affine, dtype='float32')

    # --- fieldmap + undistorted ---
    if b0:
        b0 = b0[0]
        for i, b01 in enumerate(b0):
            ifname = contrasts[i].echo(0).volume.fname
            odir, obase, oext = py.fileparts(ifname)
            odir = odir0 or odir
            obase = obase + '_B0'
            ofname = os.path.join(odir, obase + oext)
            io.savef(b01.volume, ofname, affine=b01.affine, like=ifname, te=0, dtype='float32')
        for c, b in zip(contrasts, b0):
            readout = c.readout
            grid_up, grid_down, jac_up, jac_down = b.exp2(
                add_identity=True, jacobian=True)
            for j, e in enumerate(c):
                # Default polarity alternates between echoes
                blip = e.blip or (2*(j % 2) - 1)
                grid_blip = grid_down if blip > 0 else grid_up  # inverse of
                jac_blip = jac_down if blip > 0 else jac_up     # forward model
                ifname = e.volume.fname
                odir, obase, oext = py.fileparts(ifname)
                odir = odir0 or odir
                obase = obase + '_unwrapped'
                ofname = os.path.join(odir, obase + oext)
                d = e.fdata(device=device)
                d, _ = pull1d(d, grid_blip, readout)
                d *= jac_blip
                io.savef(d, ofname, affine=e.affine, like=ifname)
                del d
            del grid_up, grid_down, jac_up, jac_down
    if options.register:
        for i, c in enumerate(contrasts):
            for j, e in enumerate(c):
                ifname = e.volume.fname
                odir, obase, oext = py.fileparts(ifname)
                odir = odir0 or odir
                obase = obase + '_registered'
                ofname = os.path.join(odir, obase + oext)
                io.save(e.volume, ofname, affine=e.affine)
Ejemplo n.º 25
0
def main(options):
    """Run EPIC on the input echoes and write the corrected volumes.

    Output templates (``options.output``) are formatted with the keys
    ``dir``, ``sep``, ``base``, ``ext`` and, where applicable, ``n``:

    * several inputs and a single template containing ``{base}``: the
      template is used once per input file;
    * a single template otherwise: all echoes are written to it, as one
      file (echoes along the last dimension) unless it contains ``{n}``,
      in which case each echo gets its own file;
    * as many templates as inputs: each input's echoes (its size along the
      4th dimension, or 1 when it has none) go to the matching template;
    * as many templates as echoes: one file per echo.

    Raises
    ------
    ValueError
        If the number of output templates matches none of the above.
    """
    # The readout direction is detected from the header of the first echo
    ref = io.map(options.echoes[0])
    readout = get_readout(options.direction, ref.affine, ref.shape,
                          options.verbose)

    # Without explicit reversed echoes, fall back on `options.synth`
    reversed_echoes = options.reversed or options.synth

    # do EPIC
    fit = epic(options.echoes,
               reverse_echoes=reversed_echoes,
               fieldmap=options.fieldmap,
               extrapolate=options.extrapolate,
               bandwidth=options.bandwidth,
               polarity=options.polarity,
               readout=readout,
               slicewise=options.slicewise,
               lam=options.penalty,
               max_iter=options.maxiter,
               tol=options.tolerance,
               verbose=options.verbose,
               device=get_device(options.gpu))

    # --- check the number of output templates ---
    inputs, outputs = options.echoes, options.output
    nb_inputs = len(inputs)
    if len(outputs) == 1 and nb_inputs != 1 and '{base}' in outputs[0]:
        # one templated output per input file
        outputs = [outputs[0]] * nb_inputs
    elif len(outputs) not in (1, nb_inputs) and len(outputs) != len(fit):
        raise ValueError(f'There should be either one output file, '
                         f'or as many output files as input files, '
                         f'or as many output files as echoes. Got '
                         f'{len(outputs)} output files, {len(inputs)} '
                         f'input files, and {len(fit)} echoes.')

    # --- save volumes ---
    if len(outputs) == 1:
        dirname, base, ext = py.fileparts(inputs[0])
        template = outputs[0]
        if '{n}' not in template:
            target = template.format(dir=dirname, sep=os.sep,
                                     base=base, ext=ext)
            io.savef(torch.movedim(fit, 0, -1), target, like=inputs[0])
        else:
            for n, echo in enumerate(fit):
                target = template.format(dir=dirname, sep=os.sep,
                                         base=base, ext=ext, n=n)
                io.savef(echo, target, like=inputs[0])
    elif len(outputs) == len(inputs):
        offset = 0
        for i, (inp, out) in enumerate(zip(inputs, outputs)):
            dirname, base, ext = py.fileparts(inp)
            target = out.format(dir=dirname, sep=os.sep, base=base,
                                ext=ext, n=i)
            # number of echoes stored in this input (4th dim, 1 if absent)
            nb_echoes = [*io.map(inp).shape, 1][3]
            chunk = fit[offset:offset + nb_echoes]
            io.savef(chunk.movedim(0, -1), target, like=inp)
            offset += nb_echoes
    else:
        assert len(outputs) == len(fit)
        dirname, base, ext = py.fileparts(inputs[0])
        for n, (echo, out) in enumerate(zip(fit, outputs)):
            target = out.format(dir=dirname, sep=os.sep, base=base,
                                ext=ext, n=n)
            io.savef(echo, target, like=inputs[0])
Ejemplo n.º 26
0
def _read_image(data, device='cpu', could_be_ct=False):
    """ Reads image data.

    Args:
        data (string|list): Path to file, or list with image data and affine matrix.
        device (string, optional): PyTorch on CPU or GPU? Defaults to 'cpu'.
        could_be_ct (bool, optional): Could the image be a CT scan?

    Returns:
        dat (torch.tensor()): Image data.
        dim (tuple(int)): Image dimensions.
        mat (torch.tensor(double)): Affine matrix.
        fname (string): File path
        direc (string): File directory path
        nam (string): Filename
        file (io.BabelArray)
        ct (bool): Is data CT

    """
    if isinstance(data, str):
        # =================================
        # Load from file
        # =================================
        file = map(data)
        dat = file.fdata(dtype=torch.float32,
                         device=device,
                         rand=False,
                         cutoff=None)
        mat = file.affine.to(device).type(torch.float64)
        fname = file.filename()
        direc, nam = os.path.split(os.path.abspath(fname))
    else:
        # =================================
        # Data and matrix given as list
        # =================================
        # Image data
        dat = data[0]
        if not isinstance(dat, torch.Tensor):
            dat = torch.tensor(dat)
        dat = dat.float()
        dat = dat.to(device)
        dat[~torch.isfinite(dat)] = 0
        # Add some random noise
        torch.manual_seed(0)
        dat[dat > 0] += torch.rand_like(dat[dat > 0]) - 1 / 2
        # Affine matrix
        mat = data[1]
        if not isinstance(mat, torch.Tensor):
            mat = torch.tensor(mat)
        mat = mat.double().to(device)
        file = None
        fname = None
        direc = None
        nam = None
    # Get dimensions
    dim = tuple(dat.shape)
    # CT?
    if could_be_ct and _is_ct(dat):
        ct = True
    else:
        ct = False
    # Mask
    dat[~torch.isfinite(dat)] = 0.0

    return dat, dim, mat, fname, direc, nam, file, ct
Ejemplo n.º 27
0
def estimate_fwhm(dat, vx=None, verbose=0, mn=-inf, mx=inf):
    """Estimates full width at half maximum (FWHM) and noise standard
    deviation (sd) of a 2D or 3D image.

    It is assumed that the image has been generated as:
        dat = Ky + n,
    where K is Gaussian smoothing with some FWHM and n is
    additive Gaussian noise. FWHM and n are estimated.

    Parameters
    ----------
    dat : str or (*spatial) tensor
        Image data or path to nifti file
    vx : [sequence of] float, optional
        Voxel size. Defaults to the voxel size of the affine when `dat`
        is a file (or mapped array), and to 1 otherwise.
    verbose : {0, 1, 2}, default=0
        Verbosity level:
            * 0: No verbosity
            * 1: Print FWHM and sd to screen
            * 2: 1 + show mask
    mn : float, optional
        Exclude values less than or equal to `mn`
    mx : float, optional
        Exclude values greater than `mx`

    Returns
    -------
    fwhm : (dim,) tensor
        Estimated FWHM
    sd : scalar tensor
        Estimated noise standard deviation.

    References
    ----------
    ..[1] "Linked independent component analysis for multimodal data fusion."
          Appendix A
          Groves AR, Beckmann CF, Smith SM, Woolrich MW.
          Neuroimage. 2011 Feb 1;54(3):2198-217.

    """
    if isinstance(dat, str):
        dat = io.map(dat)
    if isinstance(dat, io.MappedArray):
        if vx is None:
            vx = get_voxel_size(dat.affine)
        dat = dat.fdata(rand=True, missing=0)
    dat = torch.as_tensor(dat)

    dim = dat.dim()
    if vx is None:
        vx = 1
    vx = utils.make_vector(vx, dim)
    backend = utils.backend(dat)
    # Make mask: keep voxels with mn < value <= mx (NaNs fail both tests)
    msk = (dat > mn).bitwise_and_(dat <= mx)
    dat = dat.masked_fill(~msk, 0)
    # TODO: we should erode the mask so that only voxels whose neighbours
    #       are in the mask are considered when computing gradients.
    if verbose >= 2:
        show_slices(msk)
    # Compute image gradient
    g = diff(dat, dim=range(dim), side='central', voxel_size=vx,
             bound='dft').abs_()
    # Drop border voxels: with a periodic ('dft') boundary their central
    # differences mix opposite edges of the field of view
    slicer = (slice(1, -1), ) * dim
    g = g[(*slicer, None)]
    # Only voxels *inside* the mask contribute, i.e. the same voxels that
    # are used for the intensities below. (Zeroing inside the mask would
    # cancel every gradient with the default, unbounded, mn/mx and make
    # the FWHM infinite.)
    g[~msk[slicer], :] = 0
    g = g.reshape([-1, dim]).sum(0, dtype=torch.double)
    # Make dat have zero mean
    dat = dat[slicer]
    dat = dat[msk[slicer]]
    x0 = dat - dat.mean()
    # Compute FWHM
    fwhm = pymath.sqrt(4 * pymath.log(2)) * x0.abs().sum(dtype=torch.double)
    fwhm = fwhm / g
    if verbose >= 1:
        print(f'FWHM={fwhm.tolist()}')
    # Compute noise standard deviation
    sx = smooth('gauss', fwhm[0], x=0, **backend)[0][0, 0, 0]
    sy = smooth('gauss', fwhm[1], x=0, **backend)[0][0, 0, 0]
    sz = 1.0
    if dim == 3:
        sz = smooth('gauss', fwhm[2], x=0, **backend)[0][0, 0, 0]
    sc = (sx * sy * sz) / dim
    sc.clamp_min_(1)
    sd = torch.sqrt(x0.square().sum(dtype=torch.double) / (x0.numel() * sc))
    if verbose >= 1:
        print(f'sd={sd.tolist()}')
    return fwhm, sd
Ejemplo n.º 28
0
def write_outputs(z, prm, options):
    """Write the requested segmentation outputs in native, MNI and warped
    spaces.

    Parameters
    ----------
    z : (K, *spatial) tensor
        Class probabilities, classes along the first dimension (moved last
        before saving; reduced by argmax for label maps).
    prm : dict
        Fitted parameters. Reads 'affine' and, when enabled by `options`,
        'bias' (channels along the first dimension) and 'warp'.
    options : namespace
        Command-line options. Each `<output>_{nat,mni,wrp}` attribute is a
        filename template (formatted with the keys from `get_format_dict`)
        or falsy; `all_{nat,mni,wrp}` enable every output of a space with
        its default template. MNI and warped outputs are skipped when
        `options.tpm is False`.

    """

    # prepare filenames
    ref_native = options.input[0]
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    format_dict = get_format_dict(ref_native, options.output)

    def save_channels(x, template, label, **kwargs):
        """Save a channel-first map: as one multi-channel file when there
        is a single input, else as one file per input channel."""
        if len(options.input) == 1:
            fname = template.format(**format_dict)
            if options.verbose > 0:
                print(label.ljust(12), '->', fname)
            io.savef(torch.movedim(x, 0, -1), fname, like=ref_native,
                     **kwargs)
            return
        for c, (x1, ref1) in enumerate(zip(x, options.input)):
            # Format the *template* for every input: re-formatting an
            # already formatted name would reuse the first input's name
            # for every channel.
            fname1 = template.format(**get_format_dict(ref1, options.output))
            if options.verbose > 0:
                print(f'{label}.{c+1}'.ljust(12), '->', fname1)
            io.savef(x1, fname1, like=ref1, **kwargs)

    # load input data (only needed for bias-corrected outputs)
    backend = utils.backend(z)
    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp
            or options.all_nat or options.all_mni or options.all_wrp):
        dat, _, _ = get_data(options.input, options.mask, None, 3, **backend)

    # --- native space -------------------------------------------------

    if options.prob_nat or options.all_nat:
        fname = options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('prob.nat     ->', fname)
        io.savef(torch.movedim(z, 0, -1),
                 fname,
                 like=ref_native,
                 dtype='float32')

    if options.labels_nat or options.all_nat:
        fname = options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('labels.nat   ->', fname)
        io.save(z.argmax(0), fname, like=ref_native, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        save_channels(prm['bias'],
                      options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}',
                      'bias.nat',
                      dtype='float32')

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        save_channels(nobias,
                      options.nobias_nat
                      or '{dir}{sep}{base}.nobias.nat{ext}',
                      'nobias.nat')
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        warp = prm['warp']
        fname = options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('warp.nat     ->', fname)
        io.savef(warp, fname, like=ref_native, dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(mni_affine,
                                                      mni_shape,
                                                      scl,
                                                      anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        if options.prob_mni or options.all_mni:
            fname = options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('prob.mni     ->', fname)
            io.savef(torch.movedim(z_mni, 0, -1),
                     fname,
                     like=ref_native,
                     affine=mni_affine,
                     dtype='float32')
        if options.labels_mni or options.all_mni:
            fname = options.labels_mni or '{dir}{sep}{base}.labels.mni{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('labels.mni   ->', fname)
            io.save(z_mni.argmax(0),
                    fname,
                    like=ref_native,
                    affine=mni_affine,
                    dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni
                         or options.all_mni):
        bias = spatial.reslice(prm['bias'],
                               dat_affine,
                               mni_affine,
                               mni_shape,
                               interpolation=3,
                               prefilter=False,
                               bound='dct2')

        if options.bias_mni or options.all_mni:
            save_channels(bias,
                          options.bias_mni
                          or '{dir}{sep}{base}.bias.mni{ext}',
                          'bias.mni',
                          affine=mni_affine,
                          dtype='float32')

        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            save_channels(nobias,
                          options.nobias_mni
                          or '{dir}{sep}{base}.nobias.mni{ext}',
                          'nobias.mni',
                          affine=mni_affine)
            del nobias

        del bias

    need_iwarp = (options.warp_mni or options.prob_wrp or options.labels_wrp
                  or options.bias_wrp or options.nobias_wrp or options.all_mni
                  or options.all_wrp)
    need_iwarp = need_iwarp and options.warp
    if not need_iwarp:
        return

    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp,
                            dat_affine,
                            mni_affine,
                            mni_shape,
                            interpolation=2,
                            bound='dft',
                            extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if (options.warp_mni or options.all_mni) and options.warp:
        fname = options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('warp.mni     ->', fname)
        io.savef(iwarp,
                 fname,
                 like=ref_native,
                 affine=mni_affine,
                 dtype='float32')

    # --- Warped space -------------------------------------------------
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_wrp = spatial.grid_pull(z, iwarp)
        if options.prob_wrp or options.all_wrp:
            fname = options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('prob.wrp     ->', fname)
            io.savef(torch.movedim(z_wrp, 0, -1),
                     fname,
                     like=ref_native,
                     affine=mni_affine,
                     dtype='float32')
        if options.labels_wrp or options.all_wrp:
            fname = options.labels_wrp or '{dir}{sep}{base}.labels.wrp{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('labels.wrp   ->', fname)
            io.save(z_wrp.argmax(0),
                    fname,
                    like=ref_native,
                    affine=mni_affine,
                    dtype='int16')
        del z_wrp

    if options.bias and (options.bias_wrp or options.nobias_wrp
                         or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'],
                                 iwarp,
                                 interpolation=3,
                                 prefilter=False,
                                 bound='dct2')
        if options.bias_wrp or options.all_wrp:
            save_channels(bias,
                          options.bias_wrp
                          or '{dir}{sep}{base}.bias.wrp{ext}',
                          'bias.wrp',
                          affine=mni_affine,
                          dtype='float32')

        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            save_channels(nobias,
                          options.nobias_wrp
                          or '{dir}{sep}{base}.nobias.wrp{ext}',
                          'nobias.wrp',
                          affine=mni_affine)
            del nobias

        del bias
Ejemplo n.º 29
0
def main_apply(options):
    """
    Unwarp distorted images using a pre-computed 1d displacement field.

    The displacement field (`options.dist_file`) is applied as-is to the
    positive-polarity files (`options.file_pos`) and negated for the
    negative-polarity files (`options.file_neg`). With `options.diffeo`,
    it is exponentiated as a 1D velocity field; with `options.modulation`,
    intensities are multiplied by the Jacobian of the transform.
    """
    device = get_device(options.gpu)

    # detect readout direction
    if options.file_pos:
        f0 = io.map(options.file_pos[0])
    else:
        f0 = io.map(options.file_neg[0])
    dim = f0.affine.shape[-1] - 1
    readout = get_readout(options.readout, f0.affine, f0.shape[-dim:])

    def do_apply(fnames, phi, jac):
        """Correct files with a given polarity"""
        for fname in fnames:
            dirname, base, ext = py.fileparts(fname)
            ofname = options.output
            ofname = ofname.format(dir=dirname or '.', sep=os.sep, base=base,
                                   ext=ext)
            if options.verbose:
                print(f'unwarp {fname} \n'
                      f'    -> {ofname}')

            f = io.map(fname)
            d = f.fdata(device=device)
            d = utils.movedim(d, readout, -1)
            d = _deform1d(d, phi)
            if jac is not None:
                d *= jac
            d = utils.movedim(d, -1, readout)

            io.savef(d, ofname, like=fname)

    def make_transform(disp):
        """Return the 1D sampling grid and optional Jacobian for a
        displacement/velocity field whose readout dimension is last."""
        if options.diffeo:
            phi, *jac = spatial.exp1d_forward(disp, bound='dct2',
                                              jacobian=options.modulation)
            jac = jac[0] if jac else None
        else:
            # clone: the identity grid is added in place below
            phi = disp.clone()
            jac = None
            if options.modulation:
                # The readout was moved to the last dimension, so the
                # derivative must be taken along dim=-1 (not `readout`).
                jac = spatial.diff1d(phi, dim=-1, bound='dct2', side='c')
                jac += 1
        phi = spatial.add_identity_grid_(phi.unsqueeze(-1)).squeeze(-1)
        return phi, jac

    # load and apply
    vel = io.loadf(options.dist_file, device=device)
    vel = utils.movedim(vel, readout, -1)

    if options.file_pos:
        phi, jac = make_transform(vel)
        do_apply(options.file_pos, phi, jac)

    if options.file_neg:
        phi, jac = make_transform(-vel)
        do_apply(options.file_neg, phi, jac)
Ejemplo n.º 30
0
def main_fit(options):
    """
    Estimate a displacement field from opposite polarity images

    The field is estimated along the readout direction over several levels
    (as many as the longest of `penalty`, `max_iter`, `tolerance` and
    `downsample`), each level with its own regularization, iteration
    budget, tolerance and downsampling factor. The result is upsampled to
    the grid of `options.pos_file` and saved to `options.output`.
    """
    device = get_device(options.gpu)

    # map input files
    f0 = io.map(options.pos_file)
    f1 = io.map(options.neg_file)
    dim = f0.affine.shape[-1] - 1

    # map mask
    fm = None
    if options.mask:
        fm = io.map(options.mask)

    # detect readout direction
    readout = get_readout(options.readout, f0.affine, f0.shape[-dim:])

    # detect penalty (an optional trailing string selects its type)
    penalty_type = 'bending'
    penalties = options.penalty
    if penalties and isinstance(penalties[-1], str):
        *penalties, penalty_type = penalties
    if not penalties:
        penalties = [1]
    if penalty_type[0] == 'b':
        penalty_type = 'bending'
    elif penalty_type[0] == 'm':
        penalty_type = 'membrane'
    else:
        raise ValueError('Unknown penalty type', penalty_type)

    downs = options.downsample
    max_iter = options.max_iter
    tolerance = options.tolerance
    nb_levels = max(len(penalties), len(max_iter), len(tolerance), len(downs))
    penalties = py.make_list(penalties, nb_levels)
    tolerance = py.make_list(tolerance, nb_levels)
    max_iter = py.make_list(max_iter, nb_levels)
    downs = py.make_list(downs, nb_levels)

    # load
    d00 = f0.fdata(device='cpu')
    d11 = f1.fdata(device='cpu')
    dmask = fm.fdata(device='cpu') if fm is not None else None

    # fit
    vel = mask = d0 = d1 = vx = None
    aff = last_aff = f0.affine
    last_dwn = None
    for penalty, n, tol, dwn in zip(penalties, max_iter, tolerance, downs):
        # (Re)sample the data at the first level and whenever the
        # downsampling factor changes.
        if d0 is None or dwn != last_dwn:
            d0, aff = downsample(d00.to(device), f0.affine, dwn)
            d1, _ = downsample(d11.to(device), f1.affine, dwn)
            vx = spatial.voxel_size(aff)
            if vel is not None:
                vel = upsample_vel(vel, last_aff, aff, d0.shape[-dim:], readout)
            last_aff = aff
            if dmask is not None:
                mask, _ = downsample(dmask.to(device), f1.affine, dwn)
        last_dwn = dwn
        # rescale the penalty by the downsampling ratio (in voxel count)
        scl = py.prod(d00.shape) / py.prod(d0.shape)
        penalty = penalty * scl

        kernel = get_kernel(options.kernel, aff, d0.shape[-dim:], dwn)

        # prepare loss
        if options.loss == 'mse':
            prm0, _ = estimate_noise(d0)
            prm1, _ = estimate_noise(d1)
            # geometric mean of the two noise standard deviations
            sd = ((prm0['sd'].log() + prm1['sd'].log())/2).exp()
            if options.verbose:
                print(sd.item())
            loss = MSE(lam=1/(sd*sd), dim=dim)
        elif options.loss == 'lncc':
            loss = LNCC(dim=dim, patch=kernel)
        elif options.loss == 'lgmm':
            if options.bins == 1:
                loss = LNCC(dim=dim, patch=kernel)
            else:
                loss = LGMMH(dim=dim, patch=kernel, bins=options.bins)
        elif options.loss == 'gmm':
            if options.bins == 1:
                loss = NCC(dim=dim)
            else:
                loss = GMMH(dim=dim, bins=options.bins)
        else:
            loss = NCC(dim=dim)

        # fit
        vel = topup_fit(d0, d1, loss=loss, dim=readout, vx=vx, ndim=dim,
                        model=('svf' if options.diffeo else 'smalldef'),
                        lam=penalty, penalty=penalty_type, vel=vel,
                        modulation=options.modulation, max_iter=n,
                        tolerance=tol, verbose=options.verbose, mask=mask)

    del d0, d1, d00, d11

    # upsample
    vel = upsample_vel(vel, aff, f0.affine, f0.shape[-dim:], readout)

    # save
    dirname, base, ext = py.fileparts(options.pos_file)
    fname = options.output
    fname = fname.format(dir=dirname or '.', sep=os.sep, base=base, ext=ext)
    io.savef(vel, fname, like=options.pos_file, dtype='float32')