Example #1
0
def convert(inp,
            meta=None,
            dtype=None,
            casting='unsafe',
            format=None,
            output=None):
    """Convert a volume to another file format and/or data type.

    Parameters
    ----------
    inp : str
        Path to the input volume file.
    meta : dict or sequence of (key, value), optional
        Metadata fields to set in the output file.
    dtype : str or dtype, optional
        Output data type. Takes precedence over a 'dtype' entry in `meta`.
    casting : {'unsafe', 'rescale', 'rescale_zero'}, default='unsafe'
        Casting method.
    format : {'nii', 'nii.gz', 'mgh', 'mgz'}, optional
        Output format ('nifti' is accepted as an alias of 'nii').
        By default, the extension of the input file is kept.
    output : str, optional
        Output filename pattern, in which the fields `{dir}`, `{sep}`,
        `{base}` and `{ext}` are substituted.
        By default, '{dir}{sep}{base}{ext}'.

    """
    fields = dict(meta or {})
    if dtype:
        fields['dtype'] = dtype

    vol = io.volumes.map(inp)
    dat = vol.data(numpy=True)

    # Build the output path from the input path and the requested format
    dirname, basename, ext = py.fileparts(inp)
    if format:
        ext = 'nii' if format == 'nifti' else format
        if ext[0] != '.':
            ext = '.' + ext
    pattern = output if output else '{dir}{sep}{base}{ext}'
    output = pattern.format(dir=dirname or '.', sep=os.sep,
                            base=basename, ext=ext)

    out_dtype = fields.get('dtype', None) or vol.dtype
    if ext in ('.mgh', '.mgz'):
        # Promote to the first data type listed in nibabel's MGH table
        # that is at least as large as the requested one (if any).
        from nibabel.freesurfer.mghformat import _dtdefs
        out_dtype = dtypes.dtype(out_dtype)
        supported = [dtypes.dtype(entry[2]) for entry in _dtdefs]
        out_dtype = next((cand for cand in supported if out_dtype <= cand),
                         out_dtype)
        fields['dtype'] = out_dtype.numpy

    io.save(dat, output, like=vol, casting=casting, **fields)
Example #2
0
    def do_apply(fnames, phi, jac):
        """Unwarp each file with a displacement `phi` and, if given,
        modulate it by `jac`. Results are written to `options.output`."""
        for fname in fnames:
            dirname, basename, ext = py.fileparts(fname)
            ofname = options.output.format(dir=dirname or '.', sep=os.sep,
                                           base=basename, ext=ext)
            if options.verbose:
                print(f'unwarp {fname} \n'
                      f'    -> {ofname}')

            # the 1D deformation is applied along the last dimension,
            # so the readout axis is moved there and back
            dat = io.map(fname).fdata(device=device)
            dat = utils.movedim(dat, readout, -1)
            dat = _deform1d(dat, phi)
            if jac is not None:
                dat *= jac
            dat = utils.movedim(dat, -1, readout)

            io.savef(dat, ofname, like=fname)
Example #3
0
def inpaint(*inputs,
            missing='nan',
            output=None,
            device=None,
            verbose=1,
            max_iter_rls=10,
            max_iter_cg=32,
            tol_rls=1e-5,
            tol_cg=1e-5):
    """Fill-in missing values by minimizing Joint Total Variation.

    Parameters
    ----------
    *inputs : str or tensor or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    missing : 'nan' or scalar or callable, default='nan'
        Mask of the missing data. If a scalar, all voxels with that value
        are considered missing. If a function, it should return the mask
        of missing values when applied to the multi-channel data.
        Non-finite values are always considered missing.
    output : [sequence of] str, optional
        Output filename(s).
        If an input is not a path, its result is not written on disk
        by default.
        If an input is a path, its default output filename is
        '{dir}{sep}{base}.inpaint{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    device : torch.device, optional
    verbose : int, default=1
    max_iter_rls : int, default=10
    max_iter_cg : int, default=32
    tol_rls : float, default=1e-5
    tol_cg : float, default=1e-5

    Returns
    -------
    *output : str or tensor or (tensor, tensor)
        For each input that is a path, the output path.
        Else, the inpainted data, along with the orientation matrix
        when one is known.
        A single item (not a list) is returned for a single input.

    """
    # --- Preprocess: load each input and remember where it came from ---
    dirs, bases, exts, fnames = [], [], [], []
    nchannels = []
    dat = []
    aff = None  # first orientation matrix encountered
    for i, inp in enumerate(inputs):
        if isinstance(inp, str):
            dirname, basename, ext = py.fileparts(inp)
            fnames.append(inp)
            dirs.append(dirname)
            bases.append(basename)
            exts.append(ext)

            vol = io.volumes.map(inp)
            if aff is None:
                aff = vol.affine
            dat.append(ensure_4d(vol).fdata(device=device))
        else:
            fnames.append(None)
            dirs.append('')
            bases.append(f'{i+1}')
            exts.append('.nii.gz')
            if isinstance(inp, (list, tuple)):
                dat1, aff1 = inp
                if aff is None:
                    aff = aff1
            else:
                dat1 = inp
            dat.append(torch.as_tensor(dat1, device=device))
            del dat1
        nchannels.append(dat[-1].shape[-1])

    dat = utils.to(*dat, dtype=torch.float, device=utils.max_device(dat))
    if not torch.is_tensor(dat):
        dat = torch.cat(dat, dim=-1)
    dat = utils.movedim(dat, -1, 0)  # (channels, *spatial)

    # --- Flag missing voxels with NaNs ---
    if missing != 'nan':
        if callable(missing):
            is_missing = missing
        else:
            missingval = utils.make_vector(missing,
                                           dtype=dat.dtype,
                                           device=dat.device)

            def is_missing(x):
                return utils.isin(x, missingval)

        dat[is_missing(dat)] = nan
    dat[~torch.isfinite(dat)] = nan

    # --- Inpaint ---
    vx = spatial.voxel_size(aff) if aff is not None else 1
    dat = do_inpaint(dat,
                     voxel_size=vx,
                     verbose=verbose,
                     max_iter_rls=max_iter_rls,
                     tol_rls=tol_rls,
                     max_iter_cg=max_iter_cg,
                     tol_cg=tol_cg)

    # --- Postprocess: split channels back per input and write ---
    dat = utils.movedim(dat, 0, -1)
    dat = dat.split(nchannels, dim=-1)
    output = py.make_list(output, len(dat))
    for i, (dat1, fname) in enumerate(zip(dat, fnames)):
        if fname and not output[i]:
            output[i] = '{dir}{sep}{base}.inpaint{ext}'
        if not output[i]:
            continue
        if fname:
            output[i] = output[i].format(dir=dirs[i] or '.',
                                         base=bases[i],
                                         ext=exts[i],
                                         sep=os.path.sep)
            io.volumes.save(dat1, output[i], like=fname, affine=aff)
        else:
            output[i] = output[i].format(sep=os.path.sep)
            io.volumes.save(dat1, output[i], affine=aff)

    # --- Build return value ---
    results = []
    for i, dat1 in enumerate(dat):
        if fnames[i]:
            results.append(output[i])
        elif aff is not None:
            results.append((dat1, aff))
        else:
            results.append(dat1)
    if len(results) == 1:
        return results[0]
    return results
Example #4
0
def crop(inp,
         size=None,
         center=None,
         space='vx',
         like=None,
         output=None,
         transform=None):
    """Extract a patch from a ND volume and update its orientation matrix.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, optional
        Size of the patch to extract, expressed in `space`.
    center : [sequence of] int, optional
        Coordinate of the center of the patch, expressed in `space`.
        By default, the center of the FOV is used.
    space : [sequence of] {'vx', 'ras'}, default='vx'
        The space in which the `size` and `center` parameters are
        expressed. If two values are provided, the first one applies to
        `size` and the second one to `center`. Any value other than
        'ras' (case-insensitive) is handled as voxels.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data (or its shape) and
        the second contains the orientation matrix.
    output : str, optional
        Output filename.
        If the input is not a path, the cropped data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}{sep}{base}.crop{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    transform : str, optional
        Input or output filename of the corresponding transform.
        Not written by default.
        If a transform is provided and all other parameters
        (i.e., `size` and `like`) are None, the transform is considered
        as an input transform to apply.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and orientation matrix are returned.

    Raises
    ------
    ValueError
        If both `size` and `like` are provided, or if none of
        `size`, `like` and `transform` is provided.
    TypeError
        If the input transform is not an LTA file.

    """
    dirname = basename = ext = ''
    fname = None
    transform_in = False

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        vol = io.volumes.map(fname)
        inp = (vol.data(numpy=True), vol.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dirname, basename, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]

    if size and like:
        raise ValueError('Cannot use both `size` and `like`.')

    if like:
        # --- Open reference and compute size/center ---
        if isinstance(like, str):
            ref = io.volumes.map(like)
            like = (ref.shape, ref.affine)
        like_shape, like_aff = like
        if torch.is_tensor(like_shape):
            like_shape = like_shape.shape
        size, center, _, _ = _crop_to_param(aff0, like_aff, like_shape)
    elif not size:
        # --- Open transformation file and compute size/center ---
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        lta = io.transforms.map(transform)
        if not isinstance(lta, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = lta.destination_space()
        size, center, _, _ = _crop_to_param(aff0, like_aff, like_shape)

    # --- use center of the FOV ---
    if not (torch.is_tensor(center) or center):
        center = torch.as_tensor(shape0[:dim], dtype=torch.float) * 0.5

    # --- convert size/center to voxels ---
    size = utils.make_vector(size, dim, dtype=torch.long)
    center = utils.make_vector(center, dim, dtype=torch.float)
    space_size, space_center = py.make_list(space, 2)
    if space_center.lower() == 'ras':
        center = spatial.affine_matvec(spatial.affine_inv(aff0), center)
    if space_size.lower() == 'ras':
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()] / spatial.voxel_size(aff0)

    # --- compute first/last, clipped to the field of view ---
    center = center.float()
    size = size.ceil().long()
    lower = (center - size.float() / 2).round().long()
    upper = (lower + size).tolist()
    lower = [max(lo, 0) for lo in lower.tolist()]
    upper = [min(up, mx) for up, mx in zip(upper, shape0[:dim])]
    message = ('Cropping patch ['
               + ', '.join([f'{lo}:{up}' for lo, up in zip(lower, upper)])
               + f'] from volume with shape {shape0[:dim]}')
    print(message)
    slicer = tuple(slice(lo, up) for lo, up in zip(lower, upper))

    # --- do crop ---
    dat = dat[slicer]
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    # fields that can be substituted in filename patterns
    fields = dict(sep=os.path.sep)
    if is_file:
        fields.update(dir=dirname or '.', base=basename, ext=ext)

    if output:
        output = output.format(**fields)
        if is_file:
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            io.volumes.save(dat, output, affine=aff)

    if transform and not transform_in:
        transform = transform.format(**fields)
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff0, shape0)
        trf.set_destination_space(aff, shape)
        metadata = {
            'src': {'filename': fname},
            'dst': {'filename': output},
            'type': 1,  # RAS_TO_RAS
        }
        trf.set_metadata(metadata)
        trf.set_fdata(torch.eye(4))
        trf.save()

    return output if is_file else (dat, aff)
Example #5
0
def extract_patches(inp, size=64, stride=None, output=None, transform=None):
    """Extract patches from a 3D volume.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, default=64
        Patch size.
    stride : [sequence of] int, default=size
        Stride between patches.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the patches are not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}{sep}{base}.{i}_{j}_{k}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file,
        and `i`, `j`, `k` are the coordinates (starting at 1) of the
        patch. If the input is not a path, only `sep`, `i`, `j` and `k`
        are available.
    transform : [sequence of] str, optional
        Output filename(s) of the corresponding transforms.
        Not written by default. The same fields as in `output`
        are substituted.

    Returns
    -------
    output : list[str] or (tensor, tensor)
        If the input is a path, the output paths are returned.
        Else, the unfolded data and orientation matrices are returned.
            Data will have shape (nx, ny, nz, *size, *channels).
            Affines will have shape (nx, ny, nz, 4, 4).

    """
    from itertools import product

    dirname = basename = ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.fdata(), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{i}_{j}_{k}{ext}'
        dirname, basename, ext = py.fileparts(fname)

    dat, aff0 = inp

    shape = dat.shape[:3]
    size = py.make_list(size, 3)
    stride = py.make_list(stride, 3)
    stride = [st or sz for st, sz in zip(stride, size)]

    # `unfold` works on the trailing dimensions: move the spatial
    # dimensions last, unfold, then bring patch index + patch voxels first
    dat = utils.movedim(dat, [0, 1, 2], [-3, -2, -1])
    dat = utils.unfold(dat, size, stride)
    dat = utils.movedim(dat, [-6, -5, -4, -3, -2, -1], [0, 1, 2, 3, 4, 5])

    # Patch indices, with `i` varying slowest and `k` fastest. This is
    # the order in which filenames are consumed.
    indices = list(product(*(range(n) for n in dat.shape[:3])))

    # Orientation matrix of each patch
    aff = aff0.new_empty(dat.shape[:3] + aff0.shape)
    for index in indices:
        sub = tuple(slice(st*idx, st*idx + sz)
                    for st, sz, idx in zip(stride, size, index))
        aff[index], _ = spatial.affine_sub(aff0, shape, sub)

    def _fields(index):
        """Substitution fields of a filename pattern (1-based indices)."""
        i, j, k = index
        fields = dict(sep=os.path.sep, i=i+1, j=j+1, k=k+1)
        if is_file:
            fields.update(dir=dirname or '.', base=basename, ext=ext)
        return fields

    formatted_output = []
    if output:
        # iterate rather than `pop` so that a list provided by the
        # caller is never emptied
        output = py.make_list(output, len(indices))
        for index, out1 in zip(indices, output):
            # patch coordinates start at 1 for all kinds of input,
            # consistently with the transform filenames below
            out1 = out1.format(**_fields(index))
            if is_file:
                io.volumes.savef(dat[index], out1, like=fname,
                                 affine=aff[index])
            else:
                io.volumes.savef(dat[index], out1, affine=aff[index])
            formatted_output.append(out1)

    if transform:
        transform = py.make_list(transform, len(indices))
        for index, trf1 in zip(indices, transform):
            trf1 = trf1.format(**_fields(index))
            io.transforms.savef(torch.eye(4), trf1,
                                source=aff0, target=aff[index])

    if is_file:
        return formatted_output
    else:
        return dat, aff
Example #6
0
def orient(inp, affine=None, layout=None, voxel_size=None, center=None,
           like=None, output=None, output_transform=None):
    """Overwrite the orientation matrix

    Each of `affine`, `layout`, `voxel_size` and `center` also accepts
    the keywords 'like' (take it from `like`, or from the input if no
    `like` is provided), 'self' (take it from the input) and 'standard'
    (identity affine, 'RAS' layout, unit voxel size, zero center).
    `None` or an empty value is equivalent to 'like'.

    Parameters
    ----------
    inp : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    affine : {'self', 'like', 'standard'} or (4, 4) tensor_like, default='like'
        Target affine matrix
    layout : {'self', 'like', 'standard'} or layout-like, default='like'
        Target orientation.
    voxel_size : {'self', 'like', 'standard'} or [sequence of] float, default='like'
        Target voxel size.
    center : {'self', 'like', 'standard'} or [sequence of] float, default='like'
        World coordinate of the center of the field of view.
    like : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    output : str, optional
        Output filename. Can only be used if the input is a path.
        If the input is a path, the default output filename is
        '{dir}{sep}{base}.{layout}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    output_transform : str, optional
        Filename of output transform.
        If the input is not a path, the transform is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}{sep}{base}_to_{layout}.lta', where `dir` and `base`
        are the directory and base name of the input file.

    Returns
    -------
    output : str or (tuple, tensor)
        If the input is a path, the output path is returned.
        Else, the shape and new orientation matrix are returned.

    Raises
    ------
    ValueError
        If the affine does not have a compatible number of elements,
        or if `output` is set but the input is not a path.

    """
    def _is_unset(value):
        """True if `value` asks for the default ('like') behaviour."""
        if value is None:
            return True
        if isinstance(value, str):
            return value in ('', 'like')
        try:
            return len(value) == 0
        except TypeError:
            # plain scalars (and 0-dim arrays) have no length
            return False

    def _is_keyword(value, keyword):
        """True if `value` is the string `keyword` (array-safe)."""
        return isinstance(value, str) and value == keyword

    dirname = basename = ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        dim = f.affine.shape[-1] - 1
        inp = (f.shape[:dim], f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{layout}{ext}'
        if output_transform is None:
            output_transform = '{dir}{sep}{base}_to_{layout}.lta'
        dirname, basename, ext = py.fileparts(fname)

    if isinstance(like, str) and like:
        f = io.volumes.map(like)
        dim = f.affine.shape[-1] - 1
        like = (f.shape[:dim], f.affine)

    shape, aff0 = inp
    dim = aff0.shape[-1] - 1
    if like:
        shape_like, aff_like = like
    else:
        shape_like, aff_like = (shape, aff0)

    if _is_unset(voxel_size):
        voxel_size = spatial.voxel_size(aff_like)
    elif _is_keyword(voxel_size, 'self'):
        voxel_size = spatial.voxel_size(aff0)
    elif _is_keyword(voxel_size, 'standard'):
        voxel_size = 1.
    voxel_size = utils.make_vector(voxel_size, dim)

    if not layout or layout == 'like':
        layout = spatial.affine_to_layout(aff_like)
    elif layout == 'self':
        layout = spatial.affine_to_layout(aff0)
    elif layout == 'standard':
        layout = 'RAS'
    layout = spatial.volume_layout(layout)

    if _is_unset(center):
        center = (torch.as_tensor(shape_like, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff_like, center)
    elif _is_keyword(center, 'self'):
        center = (torch.as_tensor(shape, dtype=torch.float) - 1) * 0.5
        center = spatial.affine_matvec(aff0, center)
    elif _is_keyword(center, 'standard'):
        center = 0.
    center = utils.make_vector(center, dim)

    if _is_unset(affine):
        affine = aff_like
    elif _is_keyword(affine, 'self'):
        affine = aff0
    elif _is_keyword(affine, 'standard'):
        affine = torch.eye(dim+1, dim+1)
    affine = torch.as_tensor(affine, dtype=torch.float)
    # accept both compact (dim, dim+1) and square (dim+1, dim+1) affines
    if affine.numel() == dim*(dim+1):
        affine = spatial.affine_make_rect(affine.reshape(dim, dim+1))
    elif affine.numel() == (dim+1)**2:
        affine = affine.reshape(dim+1, dim+1)
    else:
        raise ValueError(f'Input affine should have {dim*(dim+1)} or '
                         f'{(dim+1)**2} element but got {affine.numel()}.')

    affine = spatial.affine_modify(affine, shape, voxel_size=voxel_size,
                                   layout=layout, center=center)
    affine = affine.double()

    # The layout *name* is needed by both filename patterns: compute it
    # once so that `output_transform` gets it even when `output` is unset.
    layout_name = None
    if output or output_transform:
        layout_name = spatial.volume_layout_to_name(layout)

    if output:
        if not is_file:
            # only the shape and affine are known: there is no data to save
            raise ValueError('`output` can only be used when the input '
                             'is a path to a volume file.')
        dat = io.volumes.load(fname, numpy=True)
        output = output.format(dir=dirname or '.', base=basename, ext=ext,
                               sep=os.path.sep, layout=layout_name)
        io.volumes.save(dat, output, like=fname, affine=affine)

    if output_transform:
        transform = spatial.affine_rmdiv(affine, aff0)
        output_transform = output_transform.format(
            dir=dirname or '.', base=basename, sep=os.path.sep,
            layout=layout_name)
        io.transforms.savef(transform.cpu(), output_transform, type=2)

    if is_file:
        return output
    else:
        return shape, affine
Example #7
0
def _main(options):
    """Run an ESTATICS fit from parsed command-line options.

    The function builds the fit options, maps the input echoes (and
    optional fieldmaps), calls `estatics` and writes the results, either
    in `options.odir` or next to the corresponding input files:
    - '{base}_TE0{ext}' : intercept of each contrast,
    - 'R2star{ext}' : decay map,
    - '{base}_B0{ext}' and '{base}_unwrapped{ext}' : fieldmaps and
      distortion-corrected echoes (if fieldmaps are returned),
    - '{base}_registered{ext}' : echoes with their registered
      orientation matrix (if `options.register`).

    Parameters
    ----------
    options : object
        Parsed options. See the attribute accesses below for the fields
        that are read.

    Raises
    ------
    ValueError
        If a TE or echo spacing unit is not supported, or if a fieldmap
        is provided in Hz without a bandwidth.

    """
    if isinstance(options.gpu, str):
        device = torch.device(options.gpu)
    else:
        assert isinstance(options.gpu, int)
        device = torch.device(f'cuda:{options.gpu}')
    if not torch.cuda.is_available():
        device = 'cpu'

    # prepare options
    estatics_opt = ESTATICSOptions()
    estatics_opt.likelihood = options.likelihood
    estatics_opt.verbose = options.verbose >= 1
    estatics_opt.plot = options.verbose >= 2
    estatics_opt.recon.space = options.space
    if isinstance(options.space, str) and options.space != 'mean':
        # the reconstruction space may be the name of a contrast:
        # replace it with the index of that contrast
        for c, contrast in enumerate(options.contrast):
            if contrast.name == options.space:
                estatics_opt.recon.space = c
                break
    estatics_opt.backend.device = device
    estatics_opt.optim.nb_levels = options.levels
    estatics_opt.optim.max_iter_rls = options.iter
    estatics_opt.optim.tolerance = options.tol
    estatics_opt.regularization.norm = options.regularization
    estatics_opt.regularization.factor = [*options.lam_intercept,
                                          options.lam_decay]
    estatics_opt.distortion.enable = options.meetup
    estatics_opt.distortion.bending = options.lam_meetup
    estatics_opt.preproc.register = options.register

    # prepare files
    contrasts = []
    distortion = []
    for c in options.contrast:

        # read meta-parameters
        meta = {}
        if c.te:
            # echo times, optionally followed by their unit
            te, unit = c.te, ''
            if isinstance(te[-1], str):
                *te, unit = te
            if unit:
                if unit == 'ms':
                    te = [t * 1e-3 for t in te]
                elif unit not in ('s', 'sec'):
                    raise ValueError(f'TE unit: {unit}')
            if c.echo_spacing:
                # NOTE(review): unlike TE, an echo spacing without unit
                # raises here -- confirm whether this is intended.
                delta, *unit = c.echo_spacing
                unit = unit[0] if unit else ''
                if unit == 'ms':
                    delta = delta * 1e-3
                elif unit not in ('s', 'sec'):
                    raise ValueError(f'echo spacing unit: {unit}')
                # regularly spaced echoes starting at the first TE
                ne = sum(io.map(f).unsqueeze(-1).shape[3] for f in c.echoes)
                te = [te[0] + e*delta for e in range(ne)]
            meta['te'] = te

        # map volumes
        contrasts.append(qio.GradientEchoMulti.from_fname(c.echoes, **meta))

        if c.readout:
            # convert the readout letter(s) into a (negative) dimension
            layout = spatial.affine_to_layout(contrasts[-1].affine)
            layout = spatial.volume_layout_to_name(layout)
            readout = None
            for j, axis in enumerate(layout):
                if axis.lower() in c.readout.lower():
                    readout = j - 3
            contrasts[-1].readout = readout

        if c.b0:
            # fieldmap filename, optionally followed by its unit
            bw = c.bandwidth
            b0, *unit = c.b0
            unit = unit[-1] if unit else 'vx'
            fb0 = io.map(b0)
            b0 = fb0.fdata(device=device)
            b0 = spatial.reslice(b0, fb0.affine, contrasts[-1][0].affine,
                                 contrasts[-1][0].shape)
            if unit.lower() == 'hz':
                if not bw:
                    raise ValueError('Bandwidth required to convert fieldmap '
                                     'from Hz to voxel')
                b0 /= bw
            b0 = DenseDistortion(b0)
            distortion.append(b0)
        else:
            distortion.append(None)

    # run algorithm
    [te0, r2s, *b0] = estatics(contrasts, distortion, opt=estatics_opt)

    # write results

    # --- intercepts ---
    odir0 = options.odir
    for i, te1 in enumerate(te0):
        ifname = contrasts[i].echo(0).volume.fname
        odir, obase, oext = py.fileparts(ifname)
        odir = odir0 or odir
        obase = obase + '_TE0'
        ofname = os.path.join(odir, obase + oext)
        io.savef(te1.volume, ofname, affine=te1.affine, like=ifname,
                 te=0, dtype='float32')

    # --- decay ---
    ifname = contrasts[0].echo(0).volume.fname
    odir, obase, oext = py.fileparts(ifname)
    odir = odir0 or odir
    io.savef(r2s.volume, os.path.join(odir, 'R2star' + oext),
             affine=r2s.affine, dtype='float32')

    # --- fieldmap + undistorted ---
    if b0:
        b0 = b0[0]
        for i, b01 in enumerate(b0):
            ifname = contrasts[i].echo(0).volume.fname
            odir, obase, oext = py.fileparts(ifname)
            odir = odir0 or odir
            obase = obase + '_B0'
            ofname = os.path.join(odir, obase + oext)
            io.savef(b01.volume, ofname, affine=b01.affine, like=ifname,
                     te=0, dtype='float32')
        for c, b in zip(contrasts, b0):
            readout = c.readout
            grid_up, grid_down, jac_up, jac_down = b.exp2(
                add_identity=True, jacobian=True)
            for j, e in enumerate(c):
                # without an explicit polarity, alternate between echoes
                blip = e.blip or (2*(j % 2) - 1)
                grid_blip = grid_down if blip > 0 else grid_up  # inverse of
                jac_blip = jac_down if blip > 0 else jac_up     # forward model
                ifname = e.volume.fname
                odir, obase, oext = py.fileparts(ifname)
                odir = odir0 or odir
                obase = obase + '_unwrapped'
                ofname = os.path.join(odir, obase + oext)
                d = e.fdata(device=device)
                d, _ = pull1d(d, grid_blip, readout)
                d *= jac_blip
                io.savef(d, ofname, affine=e.affine, like=ifname)
                del d
            del grid_up, grid_down, jac_up, jac_down
    if options.register:
        for c in contrasts:
            for e in c:
                ifname = e.volume.fname
                odir, obase, oext = py.fileparts(ifname)
                odir = odir0 or odir
                obase = obase + '_registered'
                ofname = os.path.join(odir, obase + oext)
                io.save(e.volume, ofname, affine=e.affine)
Example #8
0
def pool(inp,
         window=3,
         stride=None,
         method='mean',
         dim=3,
         output=None,
         device=None):
    """Pool a ND volume and update its orientation matrix accordingly.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    window : [sequence of] int, default=3
        Window size
    stride : [sequence of] int, optional
        Stride between output elements.
        By default, it is the same as `window`.
    method : {'mean', 'sum', 'min', 'max', 'median'}, default='mean'
        Pooling function.
    dim : int, default=3
        Number of spatial dimensions.
        If falsy, it is guessed from the orientation matrix.
    output : str, optional
        Output filename.
        If the input is not a path, the pooled data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}{sep}{base}.pool{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    device : torch.device, optional
        Device on which the data is moved before pooling.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the pooled data and orientation matrix are returned.

    """
    dirname = basename = ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        vol = io.volumes.map(fname)
        inp = (vol.fdata(device=device), vol.affine)
        if output is None:
            output = '{dir}{sep}{base}.pool{ext}'
        dirname, basename, ext = py.fileparts(fname)

    dat, aff0 = inp
    dat = dat.to(device)
    dim = dim or aff0.shape[-1] - 1

    # `pool` needs the spatial dimensions at the end: collapse all
    # non-spatial dimensions into a single leading one
    spatial_in = dat.shape[:dim]
    batch = dat.shape[dim:]
    dat = utils.movedim(dat.reshape([*spatial_in, -1]), -1, 0)
    dat, aff = spatial.pool(dim, dat, kernel_size=window, stride=stride,
                            reduction=method, affine=aff0)
    # restore the original layout of the non-spatial dimensions
    dat = utils.movedim(dat, 0, -1)
    dat = dat.reshape([*dat.shape[:dim], *batch])

    if output:
        fields = dict(sep=os.path.sep)
        if is_file:
            fields.update(dir=dirname or '.', base=basename, ext=ext)
        output = output.format(**fields)
        if is_file:
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            io.volumes.save(dat, output, affine=aff)

    return output if is_file else (dat, aff)
Example #9
0
def main_fit(options):
    """Estimate a displacement field from opposite polarity images.

    The fit runs over a sequence of levels. Each level has its own
    penalty, maximum number of iterations, tolerance and downsampling
    factor, and the velocity estimated at one level is passed on to the
    next one. The final velocity is upsampled to the lattice of the
    positive-polarity image and written to disk.

    Parameters
    ----------
    options : object
        Parsed options. The attributes read here are `gpu`, `pos_file`,
        `neg_file`, `mask`, `readout`, `penalty`, `downsample`,
        `max_iter`, `tolerance`, `kernel`, `loss`, `bins`, `diffeo`,
        `modulation`, `verbose` and `output`.

    Raises
    ------
    ValueError
        If the penalty type starts with neither 'b' nor 'm'.

    """
    device = get_device(options.gpu)

    # map input files
    pos = io.map(options.pos_file)
    neg = io.map(options.neg_file)
    dim = pos.affine.shape[-1] - 1

    # map the (optional) mask
    fmask = io.map(options.mask) if options.mask else None

    # detect readout direction
    readout = get_readout(options.readout, pos.affine, pos.shape[-dim:])

    # detect penalty: an optional trailing string names the penalty type
    penalties = options.penalty
    penalty_type = 'bending'
    if penalties and isinstance(penalties[-1], str):
        *penalties, penalty_type = penalties
    penalties = penalties or [1]
    known_types = {'b': 'bending', 'm': 'membrane'}
    initial = penalty_type[0]
    if initial not in known_types:
        raise ValueError('Unknown penalty type', penalty_type)
    penalty_type = known_types[initial]

    # broadcast all per-level settings to the same number of levels
    downs = options.downsample
    max_iter = options.max_iter
    tolerance = options.tolerance
    nb_levels = max(map(len, (penalties, max_iter, tolerance, downs)))
    penalties = py.make_list(penalties, nb_levels)
    tolerance = py.make_list(tolerance, nb_levels)
    max_iter = py.make_list(max_iter, nb_levels)
    downs = py.make_list(downs, nb_levels)

    # load full-resolution data on the CPU
    full0 = pos.fdata(device='cpu')
    full1 = neg.fdata(device='cpu')
    full_mask = fmask.fdata(device='cpu') if fmask else None

    # fit
    vel = mask = None
    aff = last_aff = pos.affine
    last_dwn = None
    for penalty, n, tol, dwn in zip(penalties, max_iter, tolerance, downs):
        if dwn != last_dwn:
            # new resolution: resample the images (and the mask), and
            # bring the current velocity onto the new lattice
            img0, aff = downsample(full0.to(device), pos.affine, dwn)
            img1, _ = downsample(full1.to(device), neg.affine, dwn)
            vx = spatial.voxel_size(aff)
            if vel is not None:
                vel = upsample_vel(vel, last_aff, aff,
                                   img0.shape[-dim:], readout)
            last_aff = aff
            if fmask:
                mask, _ = downsample(full_mask.to(device), neg.affine, dwn)
        last_dwn = dwn

        # scale the penalty by the ratio of full to downsampled sizes
        penalty = penalty * (py.prod(full0.shape) / py.prod(img0.shape))

        kernel = get_kernel(options.kernel, aff, img0.shape[-dim:], dwn)

        # prepare loss
        loss_name = options.loss
        if loss_name == 'mse':
            prm0, _ = estimate_noise(img0)
            prm1, _ = estimate_noise(img1)
            # geometric mean of both noise standard deviations
            log_sd = (prm0['sd'].log() + prm1['sd'].log())/2
            sd = log_sd.exp()
            print(sd.item())
            loss = MSE(lam=1/(sd*sd), dim=dim)
        elif loss_name == 'lncc' or (loss_name == 'lgmm'
                                     and options.bins == 1):
            loss = LNCC(dim=dim, patch=kernel)
        elif loss_name == 'lgmm':
            loss = LGMMH(dim=dim, patch=kernel, bins=options.bins)
        elif loss_name == 'gmm' and options.bins != 1:
            loss = GMMH(dim=dim, bins=options.bins)
        else:
            # 'gmm' with a single bin, or any other loss name
            loss = NCC(dim=dim)

        # fit this level
        model = 'svf' if options.diffeo else 'smalldef'
        vel = topup_fit(img0, img1, loss=loss, dim=readout, vx=vx, ndim=dim,
                        model=model,
                        lam=penalty, penalty=penalty_type, vel=vel,
                        modulation=options.modulation, max_iter=n,
                        tolerance=tol, verbose=options.verbose, mask=mask)

    del img0, img1, full0, full1

    # upsample to the lattice of the positive-polarity image
    vel = upsample_vel(vel, aff, pos.affine, pos.shape[-dim:], readout)

    # save
    dir, base, ext = py.fileparts(options.pos_file)
    fname = options.output.format(dir=dir or '.', sep=os.sep,
                                  base=base, ext=ext)
    io.savef(vel, fname, like=options.pos_file, dtype='float32')
Example #10
0
def _main(options):
    """Build a registration problem from `options`, run it, write outputs.

    The steps are: compute the pyramid levels, build the losses, build
    the affine and nonlinear components with their optimizers, run the
    registration, then save the fitted transforms and the warped images.

    Parameters
    ----------
    options : object
        Parsed options. The attributes read here are `device`, `loss`,
        `pyramid`, `verbose`, `odir`, `affine` (`output`, `name`) and
        `nonlin` (`output`, `name`); the whole object is also forwarded
        to the builder helpers.

    Raises
    ------
    ValueError
        If neither an affine nor a nonlinear component was built.

    """
    device = setup_device(*options.device)
    dim = 3

    # --- pyramid -------------------------------------------------------
    pyramids = _prepare_pyramid_levels(options.loss, options.pyramid, dim)

    # --- losses --------------------------------------------------------
    loss_list, image_dict = _build_losses(options, pyramids, device)
    can_use_2nd_order = all(term.loss.order >= 2 for term in loss_list)

    # --- affine component ----------------------------------------------
    affine, affine_optim = _build_affine(options, can_use_2nd_order)

    # --- dense component -----------------------------------------------
    nonlin, nonlin_optim = _build_nonlin(options, can_use_2nd_order, affine,
                                         image_dict)

    if not (affine or nonlin):
        raise ValueError('At least one of @affine or @nonlin must be used.')

    # --- backend -------------------------------------------------------
    if options.verbose > 1:
        import matplotlib
        matplotlib.use('TkAgg')

    # local losses may benefit from selecting the best conv
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # --- registration --------------------------------------------------
    _do_register(loss_list, affine, nonlin, affine_optim, nonlin_optim,
                 options)

    # --- write results -------------------------------------------------
    def _output_dir():
        # explicit output directory, else the directory of the first
        # fixed file of the first loss, else the current directory
        return (options.odir
                or py.fileparts(options.loss[0].fix.files[0])[0]
                or '.')

    if affine:
        # keep only the last element of the affine sequence
        affine = affine[-1]

    if affine and options.affine.output:
        odir = _output_dir()
        fname = options.affine.output.format(dir=odir,
                                             sep=os.path.sep,
                                             name=options.affine.name)
        print('Affine ->', fname)
        aff = affine.exp(cache_result=True, recompute=False)
        io.transforms.savef(aff.cpu(), fname, type=1)  # 1 = RAS_TO_RAS

    if nonlin and options.nonlin.output:
        odir = _output_dir()
        fname = options.nonlin.output.format(dir=odir,
                                             sep=os.path.sep,
                                             name=options.nonlin.name)
        io.savef(nonlin.dat.dat, fname, affine=nonlin.affine)
        print('Nonlin ->', fname)

    for loss_option in options.loss:
        _warp_image(loss_option,
                    affine=affine,
                    nonlin=nonlin,
                    dim=dim,
                    device=device,
                    odir=options.odir)
Example #11
0
def vexp(inp,
         type='displacement',
         unit='voxel',
         inverse=False,
         bound='dft',
         steps=8,
         device=None,
         output=None):
    """Exponentiate a stationary velocity field.

    Parameters
    ----------
    inp : str or tensor or (tensor, tensor)
        Either a path to a volume file, or a tuple `(dat, affine)` where
        the first element contains the velocity data and the second the
        orientation matrix. A bare tensor is also accepted: it is cloned
        and paired with a default orientation matrix built from its
        first three dimensions.
    type : {'displacement', 'transformation'}, default='displacement'
        Whether to return a displacement field (coord-to-shift) or a
        transformation field (coord-to-coord). Only the first letter is
        inspected (case-insensitive).
    unit : {'voxel', 'mm'}, default='voxel'
        Whether to return displacement/coordinates in voxel or in mm.
        If mm, the input orientation matrix is applied to the field.
    inverse : bool, default=False
        Whether to return the inverse field.
    bound : str, default='dft'
        Boundary conditions.
    steps : int, default=8
        Number of scaling and squaring steps.
    device : str, optional
        Device to use.
    output : str, optional
        Output filename.
        If the input is not a path, the exponentiated field is not
        written on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.vexp{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the output tensor and orientation matrix are returned.

    """
    dir = base = ext = ''
    fname = None

    # --- open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        mapped = io.volumes.map(inp)
        inp = (mapped.fdata(device=device), mapped.affine)
        if output is None:
            output = '{dir}{sep}{base}.vexp{ext}'
        dir, base, ext = py.fileparts(fname)
    elif torch.is_tensor(inp):
        # bare tensor: clone it and attach a default orientation matrix
        inp = (inp.clone(), spatial.affine_default(shape=inp.shape[:3]))
    dat, aff = inp
    dat = dat.to(device=device)
    aff = aff.to(device=device)

    # --- exponentiate (with a temporary leading batch dimension) ---
    exponentiated = spatial.exp(dat[None],
                                inverse=inverse,
                                steps=steps,
                                bound=bound,
                                inplace=True,
                                displacement=(type.lower()[0] == 'd'))
    dat = exponentiated[0]

    # --- voxel -> mm ---
    # NOTE: the orientation matrix is applied whatever `type` is.
    if unit == 'mm':
        dat = spatial.affine_matvec(aff, dat)

    # --- write ---
    if output:
        fields = {'sep': os.path.sep}
        extra = {}
        if is_file:
            fields.update(dir=dir or '.', base=base, ext=ext)
            extra['like'] = fname
        output = output.format(**fields)
        io.volumes.save(dat, output, **extra, affine=aff.cpu())

    return output if is_file else (dat, aff)
Example #12
0
def main(options):
    """Run `epic` on the input echoes and save the fitted volumes.

    Output patterns may contain the fields `{dir}`, `{sep}`, `{base}`,
    `{ext}` (directory, path separator, base name and extension of an
    input file) and `{n}` (echo or file index). An input file without a
    directory component resolves `{dir}` to '.', so that outputs are
    written next to the input rather than at the filesystem root.

    Three output layouts are supported:

    - one output: either one file per echo (if the pattern contains
      `{n}`) or a single file with the echoes stacked along the last
      dimension;
    - as many outputs as inputs: each output receives as many echoes as
      its input file holds;
    - as many outputs as fitted echoes: one file per echo.

    Parameters
    ----------
    options : object
        Parsed options. The attributes read here are `echoes`,
        `direction`, `verbose`, `reversed`, `synth`, `fieldmap`,
        `extrapolate`, `bandwidth`, `polarity`, `slicewise`, `penalty`,
        `maxiter`, `tolerance`, `gpu` and `output`.

    Raises
    ------
    ValueError
        If the number of outputs matches none of the supported layouts.

    """

    def _resolve(pattern, like_fname, **fields):
        """Fill `pattern` with the path components of `like_fname`."""
        dir, base, ext = py.fileparts(like_fname)
        # `dir` is empty for a bare filename: fall back to '.' so that
        # '{dir}{sep}{base}' does not expand to an absolute path
        return pattern.format(dir=dir or '.', sep=os.sep, base=base,
                              ext=ext, **fields)

    # find readout direction
    f = io.map(options.echoes[0])
    affine, shape = f.affine, f.shape
    readout = get_readout(options.direction, affine, shape, options.verbose)

    # fall back to `synth` when no reversed echoes are provided
    reversed_echoes = options.reversed or options.synth

    # do EPIC
    fit = epic(options.echoes,
               reverse_echoes=reversed_echoes,
               fieldmap=options.fieldmap,
               extrapolate=options.extrapolate,
               bandwidth=options.bandwidth,
               polarity=options.polarity,
               readout=readout,
               slicewise=options.slicewise,
               lam=options.penalty,
               max_iter=options.maxiter,
               tol=options.tolerance,
               verbose=options.verbose,
               device=get_device(options.gpu))

    # save volumes
    inputs, outputs = options.echoes, options.output
    if len(outputs) != len(inputs):
        if len(outputs) == 1:
            # a single pattern that depends on the input base name is
            # reused for every input file
            if '{base}' in outputs[0]:
                outputs = [outputs[0]] * len(inputs)
        elif len(outputs) != len(fit):
            raise ValueError(f'There should be either one output file, '
                             f'or as many output files as input files, '
                             f'or as many output files as echoes. Got '
                             f'{len(outputs)} output files, {len(inputs)} '
                             f'input files, and {len(fit)} echoes.')

    if len(outputs) == 1:
        pattern = outputs[0]
        if '{n}' in pattern:
            # one file per echo
            for n, echo in enumerate(fit):
                io.savef(echo, _resolve(pattern, inputs[0], n=n),
                         like=inputs[0])
        else:
            # all echoes stacked along the last dimension of one file
            io.savef(torch.movedim(fit, 0, -1),
                     _resolve(pattern, inputs[0]),
                     like=inputs[0])
    elif len(outputs) == len(inputs):
        for i, (inp, out) in enumerate(zip(inputs, outputs)):
            out = _resolve(out, inp, n=i)
            # number of echoes in this file: size of its 4th dimension,
            # or 1 if it has none
            ne = [*io.map(inp).shape, 1][3]
            io.savef(fit[:ne].movedim(0, -1), out, like=inp)
            fit = fit[ne:]
    else:
        # guaranteed by the validation above
        assert len(outputs) == len(fit)
        for n, (echo, out) in enumerate(zip(fit, outputs)):
            io.savef(echo, _resolve(out, inputs[0], n=n), like=inputs[0])
Example #13
0
def orient(inp,
           layout=None,
           voxel_size=None,
           center=None,
           like=None,
           output=None):
    """Overwrite the orientation matrix

    Parameters
    ----------
    inp : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix.
    layout : str or layout-like, default=None (= preserve)
        Target orientation. 'like' (or nothing) takes it from `like`,
        'self' takes it from the input.
    voxel_size : [sequence of] float, default=None (= preserve)
        Target voxel size. None, 'like' or an empty sequence take it
        from `like`; 'self' takes it from the input.
    center : [sequence of] float, default=None (= preserve)
        World coordinate of the center of the field of view. None,
        'like' or an empty sequence take it from `like`; 'self' takes it
        from the input.
    like : str or (tuple, tensor)
        Either a path to a volume file or a tuple `(shape, affine)`, where
        the first element contains the volume shape and the second contains
        the orientation matrix. When not provided, the input itself
        plays the role of `like`.
    output : str, optional
        Output filename.
        If the input is not a path, the reoriented data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{layout}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.

    Returns
    -------
    output : str or (tuple, tensor)
        If the input is a path, the output path is returned.
        Else, the new shape and orientation matrix are returned.

    """

    def _is_default(value):
        """Whether `value` means "take it from `like`"."""
        if value is None:
            return True
        if isinstance(value, str):
            return value in ('like', '')
        try:
            return len(value) == 0
        except TypeError:
            # scalars have no length: they are explicit values
            return False

    def _is_self(value):
        # guard with isinstance so that arrays/tensors are never
        # compared element-wise against a string
        return isinstance(value, str) and value == 'self'

    dir = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        dim = f.affine.shape[-1] - 1
        inp = (f.shape[:dim], f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{layout}{ext}'
        dir, base, ext = py.fileparts(fname)

    like_is_file = isinstance(like, str) and like
    if like_is_file:
        f = io.volumes.map(like)
        dim = f.affine.shape[-1] - 1
        like = (f.shape[:dim], f.affine)

    shape, aff0 = inp
    dim = aff0.shape[-1] - 1
    if like:
        shape_like, aff_like = like
    else:
        shape_like, aff_like = (shape, aff0)

    if _is_default(voxel_size):
        voxel_size = spatial.voxel_size(aff_like)
    elif _is_self(voxel_size):
        voxel_size = spatial.voxel_size(aff0)
    voxel_size = utils.make_vector(voxel_size, dim)

    if not layout or layout == 'like':
        layout = spatial.affine_to_layout(aff_like)
    elif layout == 'self':
        layout = spatial.affine_to_layout(aff0)
    layout = spatial.volume_layout(layout)

    # NOTE: the emptiness test must look at `center` itself (it used to
    # look at `voxel_size`, which is never empty at this point)
    if _is_default(center):
        center = torch.as_tensor(shape_like, dtype=torch.float) * 0.5
        center = spatial.affine_matvec(aff_like, center)
    elif _is_self(center):
        center = torch.as_tensor(shape, dtype=torch.float) * 0.5
        center = spatial.affine_matvec(aff0, center)

    center = utils.make_vector(center, dim)

    aff = spatial.affine_default(shape,
                                 voxel_size=voxel_size,
                                 layout=layout,
                                 center=center,
                                 dtype=torch.double)

    if output:
        # NOTE(review): when the input is not a path, `fname` is '' and
        # there is no data to load -- presumably this call fails; confirm
        # whether `output` should be rejected for non-file inputs.
        dat = io.volumes.load(fname, numpy=True)
        layout = spatial.volume_layout_to_name(layout)
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep,
                                   layout=layout)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep, layout=layout)
            io.volumes.save(dat, output, affine=aff)

    if is_file:
        return output
    else:
        return shape, aff
Example #14
0
def crop(inp,
         size=None,
         center=None,
         space='vx',
         like=None,
         bbox=False,
         output=None,
         transform=None):
    """Crop a ND volume, while preserving the orientation matrices.

    Exactly one of `size`, `like` and `bbox` may be used. If none is
    used, `transform` must point to an existing transform file whose
    destination space defines the patch.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    size : [sequence of] int, optional
        Size of the patch to extract, expressed in `space`.
    center : [sequence of] int, optional
        Coordinate of the center of the patch, expressed in `space`.
        By default, the center of the FOV is used.
    space : [sequence of] str, default='vx'
        The space in which the `size` and `center` parameters are
        expressed. If two values are given, the first applies to `size`
        and the second to `center`. 'ras' (case-insensitive) means world
        space; any other value means voxels.
    like : str or (tensor, tensor), optional
        Reference patch.
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    bbox : bool or float, default=False
        Crop at the bounding box of `inp > threshold`.
            If `bbox` is a float, it is the threshold to use.
            If `bbox` is `True`, the threshold is 0.
    output : str, optional
        Output filename.
        If the input is not a path, the cropped data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.crop{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file.
    transform : str, optional
        Input or output filename of the corresponding transform.
        Not written by default.
        If a transform is provided and all other parameters
        (i.e., `size`, `like` and `bbox`) are unset, the transform is
        considered as an input transform to apply.

    Returns
    -------
    output : str or (tensor, tensor)
        If the input is a path, the output path is returned.
        Else, the cropped data and its orientation matrix are returned.

    Raises
    ------
    ValueError
        If more than one of `size`, `like` and `bbox` is used, or if
        none of `size`, `like`, `bbox` and `transform` is provided.
    TypeError
        If the input transform is not a linear transform array.

    """
    dir = ''
    base = ''
    ext = ''
    fname = None
    transform_in = False
    # a float threshold of 0. is falsy but still requests a bounding box
    use_bbox = bool(bbox or isinstance(bbox, float))

    # --- Open input ---
    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        # thresholding needs the actual values; otherwise keep the
        # mapped volume so that only the cropped patch gets loaded
        inp = (f.data(numpy=True) if use_bbox else f, f.affine)
        if output is None:
            output = '{dir}{sep}{base}.crop{ext}'
        dir, base, ext = py.fileparts(fname)
    dat, aff0 = inp
    dim = aff0.shape[-1] - 1
    shape0 = dat.shape[:dim]
    layout0 = spatial.affine_to_layout(aff0)

    # save input space in case we reorient later
    aff00 = aff0
    shape00 = shape0

    if bool(size) + bool(like) + use_bbox > 1:
        raise ValueError('Can only use one of `size`, `like` and `bbox`.')

    # --- Open reference and compute size/center ---
    if like:
        like_is_file = isinstance(like, str)
        if like_is_file:
            f = io.volumes.map(like)
            like = (f.shape, f.affine)
        like_shape, like_aff = like
        like_layout = spatial.affine_to_layout(like_aff)
        if (layout0 != like_layout).any():
            aff0, dat = spatial.affine_reorient(aff0, dat, like_layout)
            shape0 = dat.shape[:dim]
        if torch.is_tensor(like_shape):
            like_shape = like_shape.shape
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)
        space = 'vox'

    # --- Compute size/center of the bounding box ---
    elif use_bbox:
        if bbox is True:
            bbox = 0.
        mask = torch.as_tensor(dat > bbox)
        # collapse non-spatial (trailing) dimensions
        while mask.dim() > dim:
            mask = mask.any(dim=-1)
        mins = []
        maxs = []
        for d in range(dim):
            n = mask.shape[d]
            # indices along `d` whose slice holds at least one voxel
            idx = utils.movedim(mask, d, 0).reshape([n, -1])
            idx = idx.any(-1).nonzero(as_tuple=False)
            mins.append(idx.min())
            maxs.append(idx.max())
        mins = utils.as_tensor(mins)
        maxs = utils.as_tensor(maxs)
        size = maxs + 1 - mins
        # Center in voxel coordinates, consistent with the formula
        # `first = center - (size - 1) / 2` used below, which then gives
        # `first == mins` exactly. (Using `maxs + 1 + mins` would give
        # `mins + 0.5`, which rounds half-to-even and shifts the patch
        # by one voxel whenever `mins` is odd.)
        center = (maxs + mins).float() / 2
        space = 'vox'
        del mask

    # --- Open transformation file and compute size/center ---
    elif not size:
        if not transform:
            raise ValueError('At least one of size/like/transform must '
                             'be provided')
        transform_in = True
        t = io.transforms.map(transform)
        if not isinstance(t, io.transforms.LinearTransformArray):
            raise TypeError('Expected an LTA file')
        like_aff, like_shape = t.destination_space()
        size, center, unit, layout = _crop_to_param(aff0, like_aff, like_shape)

    # --- use center of the FOV ---
    if not torch.is_tensor(center) and not center:
        center = torch.as_tensor(shape0[:dim], dtype=torch.float)
        center = center.sub_(1).mul_(0.5)

    # --- convert size/center to voxels ---
    size = utils.make_vector(size, dim, dtype=torch.long)
    center = utils.make_vector(center, dim, dtype=torch.float)
    space_size, space_center = py.make_list(space, 2)
    if space_center.lower() == 'ras':
        center = spatial.affine_matvec(spatial.affine_inv(aff0), center)
    if space_size.lower() == 'ras':
        perm = spatial.affine_to_layout(aff0)[:, 0]
        size = size[perm.long()]
        size = size / spatial.voxel_size(aff0)

    # --- compute first/last ---
    center = center.float()
    size = (size.ceil() if size.dtype.is_floating_point else size).long()
    first = center - size.float().sub_(1).mul_(0.5)
    first = first.round().long()
    last = (first + size).tolist()
    # clamp the patch to the field of view
    first = [max(lo, 0) for lo in first.tolist()]
    last = [min(hi, s) for hi, s in zip(last, shape0[:dim])]
    verb = 'Cropping patch ['
    verb += ', '.join([f'{lo}:{hi}' for lo, hi in zip(first, last)])
    verb += f'] from volume with shape {shape0[:dim]}'
    print(verb)
    slicer = tuple(slice(lo, hi) for lo, hi in zip(first, last))

    # --- do crop ---
    dat = dat[slicer]
    if not torch.is_tensor(dat) and callable(getattr(dat, 'data', None)):
        # still a mapped volume: load the cropped patch. Plain arrays
        # (bounding-box path) expose a non-callable `data` attribute and
        # are already loaded, so they are left untouched.
        dat = dat.data(numpy=True)
    aff, _ = spatial.affine_sub(aff0, shape0[:dim], slicer)
    shape = dat.shape[:dim]

    if output:
        if is_file:
            output = output.format(dir=dir or '.',
                                   base=base,
                                   ext=ext,
                                   sep=os.path.sep)
            io.volumes.save(dat, output, like=fname, affine=aff)
        else:
            output = output.format(sep=os.path.sep)
            io.volumes.save(dat, output, affine=aff)

    if transform and not transform_in:
        if is_file:
            transform = transform.format(dir=dir or '.',
                                         base=base,
                                         ext=ext,
                                         sep=os.path.sep)
        else:
            transform = transform.format(sep=os.path.sep)
        # identity transform between the input and the cropped spaces
        trf = io.transforms.LinearTransformArray(transform, 'w')
        trf.set_source_space(aff00, shape00)
        trf.set_destination_space(aff, shape)
        trf.set_metadata({
            'src': {
                'filename': fname
            },
            'dst': {
                'filename': output
            },
            'type': 1
        })  # RAS_TO_RAS
        trf.set_fdata(torch.eye(4))
        trf.save()

    if is_file:
        return output
    else:
        return dat, aff
Example #15
0
def _warp_image(option,
                affine=None,
                nonlin=None,
                dim=None,
                device=None,
                odir=None):
    """Warp and save the moving and fixed images from a loss object.

    For each side (moving, fixed), up to two images are written:

    - `output`: the image warped on its own lattice ("minimal reslice");
    - `resliced`: the image warped to the lattice of the other side
      ("full reslice").

    Nothing is done if none of these four outputs is requested.

    Parameters
    ----------
    option : object
        Loss option with `fix` and `mov` members. The attributes read on
        each of them are `files`, `world`, `affine`, `bound`,
        `extrapolate`, `output` and `resliced`.
    affine : object, optional
        Fitted affine component.
    nonlin : object, optional
        Fitted nonlinear component.
    dim : int, optional
        Number of spatial dimensions. Derived from the fixed image if
        not provided.
    device : optional
        Device on which image data is loaded.
    odir : str, optional
        Output directory. By default, the directory of the first input
        file of each side (or '.').

    """
    wants_mov = option.mov.output or option.mov.resliced
    wants_fix = option.fix.output or option.fix.resliced
    if not (wants_mov or wants_fix):
        return

    def _effective_affine(side, native):
        """Orientation matrix of one side, after user overrides."""
        out = native
        if side.world:  # overwrite orientation matrix
            out = io.transforms.map(side.world).fdata().squeeze()
        for transform in (side.affine or []):
            transform = io.transforms.map(transform).fdata().squeeze()
            out = spatial.affine_lmdiv(transform, out)
        return out

    fix, fix_affine = _map_image(option.fix.files, dim=dim)
    mov, mov_affine = _map_image(option.mov.files, dim=dim)
    fix_affine = fix_affine.float()
    mov_affine = mov_affine.float()
    dim = dim or (fix.dim - 1)

    fix_affine = _effective_affine(option.fix, fix_affine)
    mov_affine = _effective_affine(option.mov, mov_affine)

    # ------------------------------ moving ------------------------------
    if wants_mov:
        ifname = option.mov.files[0]
        idir, base, ext = py.fileparts(ifname)
        fields = dict(dir=odir or idir or '.',
                      base=base,
                      sep=os.path.sep,
                      ext=ext)

        image = objects.Image(mov.fdata(rand=True, device=device),
                              dim=dim,
                              affine=mov_affine,
                              bound=option.mov.bound,
                              extrapolate=option.mov.extrapolate)

        if option.mov.output:
            # keep the moving lattice, but update its orientation matrix
            # when the affine is (partly) applied to the moving side
            target_affine = mov_affine
            target_shape = image.shape
            if affine and affine.position[0].lower() in 'ms':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_lmdiv(aff, target_affine)

            fname = option.mov.output.format(**fields)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

        if option.mov.resliced:
            # resample onto the lattice of the fixed image
            target_affine = fix_affine
            target_shape = fix.shape[1:]

            fname = option.mov.resliced.format(**fields)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin,
                                  reslice=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

    # ------------------------------- fixed ------------------------------
    if wants_fix:
        ifname = option.fix.files[0]
        idir, base, ext = py.fileparts(ifname)
        fields = dict(dir=odir or idir or '.',
                      base=base,
                      sep=os.path.sep,
                      ext=ext)

        image = objects.Image(fix.fdata(rand=True, device=device),
                              dim=dim,
                              affine=fix_affine,
                              bound=option.fix.bound,
                              extrapolate=option.fix.extrapolate)

        if option.fix.output:
            # keep the fixed lattice, but update its orientation matrix
            # when the affine is (partly) applied to the fixed side
            target_affine = fix_affine
            target_shape = image.shape
            if affine and affine.position[0].lower() in 'fs':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_matmul(aff, target_affine)

            fname = option.fix.output.format(**fields)
            print(f'Minimal reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin,
                                  backward=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped

        if option.fix.resliced:
            # resample onto the lattice of the moving image
            target_affine = mov_affine
            target_shape = mov.shape[1:]

            fname = option.fix.resliced.format(**fields)
            print(f'Full reslice: {ifname} -> {fname} ...', end=' ')
            warped = _warp_image1(image,
                                  target_affine,
                                  target_shape,
                                  affine=affine,
                                  nonlin=nonlin,
                                  backward=True,
                                  reslice=True)
            io.savef(warped, fname, like=ifname, affine=target_affine)
            print('done.')
            del warped
Example #16
0
def unstack(inp, dim=-1, output=None, transform=None):
    """Unstack a ND volume, while preserving the orientation matrices.

    Parameters
    ----------
    inp : str or (tensor, tensor)
        Either a path to a volume file or a tuple `(dat, affine)`, where
        the first element contains the volume data and the second contains
        the orientation matrix.
    dim : int, default=-1
        Dimension along which to unstack. Negative values index the
        dimensions of the data from the end.
    output : [sequence of] str, optional
        Output filename(s).
        If the input is not a path, the unstacked data is not written
        on disk by default.
        If the input is a path, the default output filename is
        '{dir}/{base}.{i}{ext}', where `dir`, `base` and `ext`
        are the directory, base name and extension of the input file,
        `i` is the coordinate (starting at 1) of the slice.
    transform : [sequence of] str, optional
        Output filename(s) of the corresponding transforms.
        Not written by default.

    Returns
    -------
    output : list[str] or (list[(tensor, tensor)], list[tensor])
        If the input is a path, the formatted output paths are returned
        (an empty list if nothing was written).
        Else, a tuple `(pairs, affines)` is returned, where `pairs` holds
        one `(dat, affine)` pair per slice and `affines` holds the
        per-slice orientation matrices.

    """
    dirname = ''
    base = ''
    ext = ''
    fname = ''

    is_file = isinstance(inp, str)
    if is_file:
        fname = inp
        f = io.volumes.map(inp)
        inp = (f.data(), f.affine)
        if output is None:
            output = '{dir}{sep}{base}.{i}{ext}'
        dirname, base, ext = py.fileparts(fname)

    def fill(pattern, index):
        """Fill a filename pattern for the slice with 1-based `index`."""
        keys = dict(sep=os.path.sep, i=index)
        if is_file:
            keys.update(dir=dirname or '.', base=base, ext=ext)
        return pattern.format(**keys)

    dat, aff0 = inp
    ndim = aff0.shape[-1] - 1
    if dim < 0:
        # Normalise so that `dim` means the same thing when indexing the
        # data and when indexing the (shorter) list of spatial slicers.
        dim += len(dat.shape)
    nb_slices = dat.shape[dim]
    if dim >= ndim:
        # we didn't touch the spatial dimensions
        aff = [aff0.clone() for _ in range(nb_slices)]
    else:
        aff = []
        slicer = [slice(None)] * ndim
        shape = dat.shape[:ndim]
        for z in range(nb_slices):
            slicer[dim] = slice(z, z + 1)
            aff1, _ = spatial.affine_sub(aff0, shape, tuple(slicer))
            aff.append(aff1)
    # Keep the unstacked dimension as a singleton in each slice.
    dat = [d.unsqueeze(dim) for d in torch.unbind(dat, dim)]
    dat = list(zip(dat, aff))

    formatted_output = []
    if output:
        output = py.make_list(output, len(dat))
        for i, ((dat1, aff1), out1) in enumerate(zip(dat, output), start=1):
            out1 = fill(out1, i)
            if is_file:
                io.volumes.save(dat1, out1, like=fname, affine=aff1)
            else:
                io.volumes.save(dat1, out1, affine=aff1)
            formatted_output.append(out1)

    if transform:
        transform = py.make_list(transform, len(dat))
        for i, ((_, aff1), trf1) in enumerate(zip(dat, transform), start=1):
            trf1 = fill(trf1, i)
            io.transforms.savef(torch.eye(4), trf1, source=aff0, target=aff1)

    if is_file:
        return formatted_output
    else:
        return dat, aff
Example #17
0
def get_format_dict(fname, dir0):
    """Build the keyword dictionary used to fill a filename pattern.

    Parameters
    ----------
    fname : str
        Path that is split with `py.fileparts` into three parts, used
        here as directory, base name and extension.
    dir0 : str or None
        Directory that takes precedence over the one found in `fname`.

    Returns
    -------
    dict
        Keys 'dir', 'sep', 'base' and 'ext'. 'dir' is `dir0` if truthy,
        else the directory of `fname` if truthy, else '.'.

    """
    parent, stem, suffix = py.fileparts(fname)
    folder = dir0 or parent or '.'
    return {'dir': folder, 'sep': os.sep, 'base': stem, 'ext': suffix}