Example #1
0
def _write_image(dat,
                 fname,
                 bids=False,
                 mat=torch.eye(4),
                 file=None,
                 dtype='float32'):
    """Write data to a NIfTI file.

    Parameters
    ----------
    dat : tensor
        Image data to write.
    fname : str
        Output filename.
    bids : bool, default=False
        If True, insert a ``space-unires`` entity before the last
        underscore-separated part of the basename, e.g.
        ``sub-01_T1w.nii`` -> ``sub-01_space-unires_T1w.nii``.
    mat : (4, 4) tensor, default=identity
        Orientation matrix written to the file.
        NOTE(review): the default tensor is created once at definition time
        and shared across calls; it is only read here, never modified.
    file : optional
        Reference passed to ``savef`` as ``like`` (header/metadata template).
    dtype : str, default='float32'
        On-disk data type, forwarded to ``savef`` (it was previously
        accepted but silently ignored).
    """
    if bids:
        p, n = os.path.split(fname)
        s = n.split('_')
        fname = os.path.join(p, '_'.join(s[:-1] + ['space-unires'] + [s[-1]]))

    savef(dat, fname, like=file, affine=mat, dtype=dtype)
Example #2
0
def _write_output(dat, mat, file=None, prefix='', odir=None, nam=None):
    """Write preprocessed output to disk.

    Parameters
    ----------
    dat : sequence of tensors
        One image per output file.
    mat : sequence of (4, 4) tensors
        Orientation matrix for each image in ``dat``.
    file : sequence, optional
        Reference object for each image, or ``None`` entries. A non-None
        entry provides the base filename (via its ``filename()`` method) and
        is used as the ``like`` template for ``savef``. If ``file`` itself is
        omitted, every image is treated as having no reference (previously
        this raised ``TypeError``).
    prefix, odir, nam
        Forwarded to ``file_mod`` to build each output path. ``odir`` is
        created if it does not exist.

    Returns
    -------
    pth : str or list of str
        The single written path if ``dat`` has one element, else the list
        of written paths.
    """
    if odir is not None:
        os.makedirs(odir, exist_ok=True)
    if file is None:
        file = [None] * len(dat)
    pth = []
    for n, img in enumerate(dat):
        ref = file[n]
        filename = ref.filename() if ref is not None else 'nitorch_file'
        pth.append(file_mod(filename, odir=odir, prefix=prefix, nam=nam))
        savef(img, pth[n], like=ref, affine=mat[n])
    if len(dat) == 1:
        pth = pth[0]

    return pth
Example #3
0
    def do_apply(fnames, phi, jac):
        """Unwarp every file in `fnames` with the 1D displacement `phi`.

        Each volume is resampled along the readout axis with `_deform1d`,
        optionally multiplied by `jac`, and written to the path obtained by
        formatting the `options.output` template.
        """
        for in_name in fnames:
            folder, stem, suffix = py.fileparts(in_name)
            out_name = options.output.format(
                dir=folder or '.', sep=os.sep, base=stem, ext=suffix)
            if options.verbose:
                print(f'unwarp {in_name} \n'
                      f'    -> {out_name}')

            vol = io.map(in_name).fdata(device=device)
            # bring the readout axis last (presumably the axis `_deform1d`
            # acts on), deform, then restore the original layout
            vol = utils.movedim(vol, readout, -1)
            vol = _deform1d(vol, phi)
            if jac is not None:
                vol *= jac
            vol = utils.movedim(vol, -1, readout)

            io.savef(vol, out_name, like=in_name)
Example #4
0
def _cli(args):
    """Command-line interface for `smooth` without exception handling.

    Parses `args` (falling back to ``sys.argv[1:]`` when `args` is empty or
    None), smooths each input file and writes it to the matching output
    template.
    """
    args = args or sys.argv[1:]

    options = parser(args)
    if options.help:
        print(help)
        return

    # The FWHM list may end with a unit string; the default unit is 'mm'.
    fwhm, unit = options.fwhm, 'mm'
    if isinstance(fwhm[-1], str):
        *fwhm, unit = fwhm
    fwhm = make_list(fwhm, 3)

    options.output = make_list(options.output, len(options.files))
    for in_name, out_template in zip(options.files, options.output):
        img = io.map(in_name)
        vx = voxel_size(img.affine).tolist()
        ndim = len(vx)
        if unit == 'mm':
            # physical FWHM -> voxels along each axis
            kernel = [width / size for width, size in zip(fwhm, vx)]
        else:
            kernel = fwhm[:ndim]

        dat = movedim_front2back(img.fdata(), ndim)
        dat = smooth(dat, type=options.method, fwhm=kernel,
                     basis=options.basis, bound=options.padding, dim=ndim)
        dat = movedim_back2front(dat, ndim)

        folder, base, ext = fileparts(in_name)
        out_name = out_template.format(dir=folder or '.', base=base,
                                       ext=ext, sep=os.path.sep)
        io.savef(dat, out_name, like=img)
Example #5
0
def _main(options):
    """Fit ESTATICS to multi-echo contrasts and write the resulting maps.

    Selects the compute device, translates the command-line `options` into
    an `ESTATICSOptions`, maps every contrast (echo times, optional readout
    direction and B0 fieldmap), runs `estatics`, and writes the outputs.

    Files go to ``options.odir`` when set, otherwise next to the
    corresponding input:

    * ``<base>_TE0<ext>``        intercept of each contrast
    * ``R2star<ext>``            decay map (extension of the first contrast)
    * ``<base>_B0<ext>``         fieldmap of each contrast (if returned)
    * ``<base>_unwrapped<ext>``  distortion-corrected echoes (if returned)
    * ``<base>_registered<ext>`` registered echoes (if ``options.register``)

    Raises
    ------
    ValueError
        On an unknown TE / echo-spacing unit, or when a fieldmap in Hz is
        given without a bandwidth.
    """
    if isinstance(options.gpu, str):
        device = torch.device(options.gpu)
    else:
        assert isinstance(options.gpu, int)
        device = torch.device(f'cuda:{options.gpu}')
    if not torch.cuda.is_available():
        device = 'cpu'

    # prepare options
    estatics_opt = ESTATICSOptions()
    estatics_opt.likelihood = options.likelihood
    estatics_opt.verbose = options.verbose >= 1
    estatics_opt.plot = options.verbose >= 2
    estatics_opt.recon.space = options.space
    if isinstance(options.space, str) and options.space != 'mean':
        # a contrast name: convert it to the contrast index
        for c, contrast in enumerate(options.contrast):
            if contrast.name == options.space:
                estatics_opt.recon.space = c
                break
    estatics_opt.backend.device = device
    estatics_opt.optim.nb_levels = options.levels
    estatics_opt.optim.max_iter_rls = options.iter
    estatics_opt.optim.tolerance = options.tol
    estatics_opt.regularization.norm = options.regularization
    estatics_opt.regularization.factor = [*options.lam_intercept, options.lam_decay]
    estatics_opt.distortion.enable = options.meetup
    estatics_opt.distortion.bending = options.lam_meetup
    estatics_opt.preproc.register = options.register

    # prepare files
    contrasts = []
    distortion = []
    for c in options.contrast:

        # read meta-parameters (echo times are converted to seconds)
        meta = {}
        if c.te:
            te, unit = c.te, ''
            if isinstance(te[-1], str):
                *te, unit = te
            if unit:
                if unit == 'ms':
                    te = [t * 1e-3 for t in te]
                elif unit not in ('s', 'sec'):
                    raise ValueError(f'TE unit: {unit}')
            if c.echo_spacing:
                # regularly spaced echoes: TE_e = TE_0 + e * spacing
                delta, *unit = c.echo_spacing
                unit = unit[0] if unit else ''
                if unit == 'ms':
                    delta = delta * 1e-3
                elif unit not in ('', 's', 'sec'):
                    # no unit means seconds, as for TE above (a unitless
                    # spacing used to raise "echo spacing unit: ")
                    raise ValueError(f'echo spacing unit: {unit}')
                ne = sum(io.map(f).unsqueeze(-1).shape[3] for f in c.echoes)
                te = [te[0] + e*delta for e in range(ne)]
            meta['te'] = te

        # map volumes
        contrasts.append(qio.GradientEchoMulti.from_fname(c.echoes, **meta))

        if c.readout:
            # index (counted from the end, 3D) of the layout axis matching
            # the requested readout direction
            layout = spatial.affine_to_layout(contrasts[-1].affine)
            layout = spatial.volume_layout_to_name(layout)
            readout = None
            for j, l in enumerate(layout):
                if l.lower() in c.readout.lower():
                    readout = j - 3
            contrasts[-1].readout = readout

        if c.b0:
            bw = c.bandwidth
            b0, *unit = c.b0
            unit = unit[-1] if unit else 'vx'
            fb0 = io.map(b0)  # `b0` is a filename here
            b0 = fb0.fdata(device=device)
            b0 = spatial.reslice(b0, fb0.affine, contrasts[-1][0].affine,
                                 contrasts[-1][0].shape)
            if unit.lower() == 'hz':
                if not bw:
                    raise ValueError('Bandwidth required to convert fieldmap '
                                     'from Hz to voxel')
                b0 /= bw
            b0 = DenseDistortion(b0)
            distortion.append(b0)
        else:
            distortion.append(None)

    # run algorithm
    [te0, r2s, *b0] = estatics(contrasts, distortion, opt=estatics_opt)

    # write results
    odir0 = options.odir

    def output_path(ifname, suffix):
        # <odir0 or input dir>/<input base><suffix><input ext>
        odir, obase, oext = py.fileparts(ifname)
        return os.path.join(odir0 or odir, obase + suffix + oext)

    # --- intercepts ---
    for i, te1 in enumerate(te0):
        ifname = contrasts[i].echo(0).volume.fname
        ofname = output_path(ifname, '_TE0')
        io.savef(te1.volume, ofname, affine=te1.affine, like=ifname, te=0, dtype='float32')

    # --- decay ---
    ifname = contrasts[0].echo(0).volume.fname
    odir, _, oext = py.fileparts(ifname)
    odir = odir0 or odir
    io.savef(r2s.volume, os.path.join(odir, 'R2star' + oext), affine=r2s.affine, dtype='float32')

    # --- fieldmap + undistorted ---
    if b0:
        b0 = b0[0]
        for i, b01 in enumerate(b0):
            if b01 is None:
                # presumably a contrast given without a fieldmap (its
                # distortion entry was None); nothing to write
                continue
            ifname = contrasts[i].echo(0).volume.fname
            ofname = output_path(ifname, '_B0')
            io.savef(b01.volume, ofname, affine=b01.affine, like=ifname, te=0, dtype='float32')
        for c, b in zip(contrasts, b0):
            if b is None:
                continue
            readout = c.readout
            grid_up, grid_down, jac_up, jac_down = b.exp2(
                add_identity=True, jacobian=True)
            for j, e in enumerate(c):
                # explicit blip polarity if known, else alternate -1, +1, ...
                blip = e.blip or (2*(j % 2) - 1)
                grid_blip = grid_down if blip > 0 else grid_up  # inverse of
                jac_blip = jac_down if blip > 0 else jac_up     # forward model
                ifname = e.volume.fname
                ofname = output_path(ifname, '_unwrapped')
                d = e.fdata(device=device)
                d, _ = pull1d(d, grid_blip, readout)
                d *= jac_blip
                io.savef(d, ofname, affine=e.affine, like=ifname)
                del d
            del grid_up, grid_down, jac_up, jac_down

    # --- registered echoes ---
    if options.register:
        for c in contrasts:
            for e in c:
                ifname = e.volume.fname
                ofname = output_path(ifname, '_registered')
                io.save(e.volume, ofname, affine=e.affine)
Example #6
0
def _warp_image(option,
                affine=None,
                nonlin=None,
                dim=None,
                device=None,
                odir=None):
    """Warp and save the moving and fixed images of one loss term.

    Each side (`option.mov`, `option.fix`) can request two outputs:

    * ``output``   -- "minimal" reslice, written on the image's own grid
      (its orientation possibly composed with the affine transform);
    * ``resliced`` -- "full" reslice onto the grid of the other image.

    Output names are built by formatting the side's template with
    ``dir``, ``base``, ``sep`` and ``ext``; ``dir`` is `odir`, else the
    input's directory, else ``'.'``. Nothing is loaded if no output is
    requested.
    """
    mov_opt, fix_opt = option.mov, option.fix
    if not (mov_opt.output or mov_opt.resliced or fix_opt.output
            or fix_opt.resliced):
        return

    fix, fix_affine = _map_image(fix_opt.files, dim=dim)
    mov, mov_affine = _map_image(mov_opt.files, dim=dim)
    fix_affine = fix_affine.float()
    mov_affine = mov_affine.float()
    dim = dim or (fix.dim - 1)

    def read_matrix(path):
        # load a transformation matrix stored on disk
        return io.transforms.map(path).fdata().squeeze()

    def orientation(side, mat):
        # `world` overwrites the orientation matrix; each extra transform
        # listed in `side.affine` is then left-divided into it
        if side.world:
            mat = read_matrix(side.world)
        for path in (side.affine or []):
            mat = spatial.affine_lmdiv(read_matrix(path), mat)
        return mat

    fix_affine = orientation(fix_opt, fix_affine)
    mov_affine = orientation(mov_opt, mov_affine)

    def save_warped(image, template, ifname, folder, base, ext, header,
                    target_affine, target_shape, **warp_opt):
        # build the output name, warp `image` onto the target grid and save
        fname = template.format(dir=folder, base=base, sep=os.path.sep,
                                ext=ext)
        print(f'{header}: {ifname} -> {fname} ...', end=' ')
        warped = _warp_image1(image, target_affine, target_shape,
                              affine=affine, nonlin=nonlin, **warp_opt)
        io.savef(warped, fname, like=ifname, affine=target_affine)
        print('done.')

    # moving
    if mov_opt.output or mov_opt.resliced:
        ifname = mov_opt.files[0]
        idir, base, ext = py.fileparts(ifname)
        folder = odir or idir or '.'
        image = objects.Image(mov.fdata(rand=True, device=device),
                              dim=dim,
                              affine=mov_affine,
                              bound=mov_opt.bound,
                              extrapolate=mov_opt.extrapolate)

        if mov_opt.output:
            target_affine = mov_affine
            if affine and affine.position[0].lower() in 'ms':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_lmdiv(aff, target_affine)
            save_warped(image, mov_opt.output, ifname, folder, base, ext,
                        'Minimal reslice', target_affine, image.shape)

        if mov_opt.resliced:
            save_warped(image, mov_opt.resliced, ifname, folder, base, ext,
                        'Full reslice', fix_affine, fix.shape[1:],
                        reslice=True)

    # fixed
    if fix_opt.output or fix_opt.resliced:
        ifname = fix_opt.files[0]
        idir, base, ext = py.fileparts(ifname)
        folder = odir or idir or '.'
        image = objects.Image(fix.fdata(rand=True, device=device),
                              dim=dim,
                              affine=fix_affine,
                              bound=fix_opt.bound,
                              extrapolate=fix_opt.extrapolate)

        if fix_opt.output:
            target_affine = fix_affine
            if affine and affine.position[0].lower() in 'fs':
                aff = affine.exp(recompute=False, cache_result=True)
                target_affine = spatial.affine_matmul(aff, target_affine)
            save_warped(image, fix_opt.output, ifname, folder, base, ext,
                        'Minimal reslice', target_affine, image.shape,
                        backward=True)

        if fix_opt.resliced:
            save_warped(image, fix_opt.resliced, ifname, folder, base, ext,
                        'Full reslice', mov_affine, mov.shape[1:],
                        backward=True, reslice=True)
Example #7
0
def _main(options):
    """Run a registration described by parsed command-line `options`.

    Builds the pyramid, losses, affine and nonlinear components, optimizes
    them with `_do_register`, then saves the requested transforms and the
    warped images of every loss term.

    Raises
    ------
    ValueError
        If neither an affine nor a nonlinear component is built.
    """
    device = setup_device(*options.device)
    dim = 3

    # ---------------------------- model setup -------------------------
    pyramids = _prepare_pyramid_levels(options.loss, options.pyramid, dim)
    loss_list, image_dict = _build_losses(options, pyramids, device)

    # second-order optimization is allowed only if every loss has order >= 2
    second_order_ok = all(term.loss.order >= 2 for term in loss_list)

    affine, affine_optim = _build_affine(options, second_order_ok)
    nonlin, nonlin_optim = _build_nonlin(options, second_order_ok, affine,
                                         image_dict)
    if not (affine or nonlin):
        raise ValueError('At least one of @affine or @nonlin must be used.')

    # ---------------------------- backend -----------------------------
    if options.verbose > 1:
        import matplotlib
        matplotlib.use('TkAgg')

    # local losses may benefit from selecting the best conv
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # ---------------------------- optimize ----------------------------
    _do_register(loss_list, affine, nonlin, affine_optim, nonlin_optim,
                 options)

    # ---------------------------- write -------------------------------
    if affine:
        affine = affine[-1]

    def default_odir():
        # explicit output dir, else the first fixed image's dir, else cwd
        return (options.odir
                or py.fileparts(options.loss[0].fix.files[0])[0]
                or '.')

    if affine and options.affine.output:
        fname = options.affine.output.format(dir=default_odir(),
                                             sep=os.path.sep,
                                             name=options.affine.name)
        print('Affine ->', fname)
        aff = affine.exp(cache_result=True, recompute=False)
        io.transforms.savef(aff.cpu(), fname, type=1)  # 1 = RAS_TO_RAS

    if nonlin and options.nonlin.output:
        fname = options.nonlin.output.format(dir=default_odir(),
                                             sep=os.path.sep,
                                             name=options.nonlin.name)
        io.savef(nonlin.dat.dat, fname, affine=nonlin.affine)
        print('Nonlin ->', fname)

    for term in options.loss:
        _warp_image(term, affine=affine, nonlin=nonlin, dim=dim,
                    device=device, odir=options.odir)
Example #8
0
def denoise_mri(*dat_x,
                affine_x=None,
                lam_scl=opt.lam_scl,
                lr=opt.learning_rate,
                max_iter=opt.max_iter,
                tolerance=opt.tolerance,
                verbose=opt.verbose,
                device=opt.device,
                do_write=opt.do_write,
                dir_out=opt.dir_out):
    """Denoise a multi-channel MR image with joint total variation (JTV).

    Solves::

        dat_y_hat = argmin_y 0.5*sum_c(tau_c*sum_i((dat_x_ci - y_ci)^2))
                             + jtv(y_1, ..., y_C; lam)

    using PyTorch's auto-diff (Adam + ReduceLROnPlateau).

    If the input is given as paths to files, outputs prefixed 'den_' are
    written according to the options 'dir_out' and 'do_write'.

    Reference:
    Brudfors, Mikael, et al. "MRI super-resolution using multi-channel total variation."
    Annual Conference on Medical Image Understanding and Analysis. Springer, Cham, 2018.

    Parameters
    ----------
    dat_x : (nchannels, dmx, dmy, dmz) tensor, or str, or sequence[str]
        Input noisy image data. Accepted forms:
        - one tensor with channels first;
        - several (dmx, dmy, dmz) tensors, stacked as channels;
        - one or more file paths (positionally or as a single list/tuple).
    affine_x : (4, 4) tensor, optional
        Input images' affine matrix. If not given, assumes identity.
        Ignored (replaced by the files' affine) when reading from disk.
    lam_scl : float, default=10.0
        Scaling of regularisation values
    lr : float, default=1e1
        Optimiser learning rate
    max_iter : int, default=10000
        Maximum number of fitting iterations
    tolerance : float, default=1e-8
        Convergence threshold: stop after 5 consecutive iterations with
        |gain| < tolerance (checked after iteration 10).
    verbose : bool, default=True
        Print progress every 25 iterations.
    device : torch.device, default='cuda'
        Device used when reading from disk; tensor inputs stay on their
        own device.
    do_write : bool, default=True
        If input is given as paths to files, output is written to disk,
        prefixed 'den_' to 'dir_out'. Forced to False for tensor input.
    dir_out : str, optional
        Directory where to write output, default is same as input.

    Returns
    ----------
    dat_y_hat : (nchannels, dmx, dmy, dmz) tensor
        Denoised image data (a `torch.nn.Parameter`).

    Raises
    ------
    ValueError
        If no input is given.
    """
    # `*dat_x` always packs the positional inputs into a tuple.
    if len(dat_x) == 1 and isinstance(dat_x[0], (list, tuple)):
        dat_x = tuple(dat_x[0])  # a single sequence of paths/tensors
    if len(dat_x) == 0:
        raise ValueError('denoise_mri: no input image given.')
    if all(isinstance(x, str) for x in dat_x):
        # read data from disk
        dat_x, affine_x, nii = _get_image_data(dat_x, device=device)
    else:
        do_write = False  # input is tensor, do not write to disk
        # Unpack the tuple: the code below needs a tensor (`.device`,
        # `.shape`); the tuple previously caused an AttributeError.
        dat_x = dat_x[0] if len(dat_x) == 1 else torch.stack(dat_x)
    # backend
    device = dat_x.device
    dtype = dat_x.dtype
    # estimate hyper-parameters
    nchannels = dat_x.shape[0]
    tau = torch.zeros(nchannels, device=device, dtype=dtype)
    lam = torch.zeros(nchannels, device=device, dtype=dtype)
    for i in range(nchannels):
        prm0, prm1 = estimate_noise(dat_x[i, ...], show_fit=False)
        sd_bg = prm0['sd']
        mean_fg = prm1['mean']
        # noise precision from the background standard deviation
        tau[i] = 1 / sd_bg.float()**2
        # modulates with number of channels (as in JTV reg)
        lam[i] = math.sqrt(1 / nchannels) / mean_fg.float()
    # affine matrices
    if affine_x is None:
        affine_x = torch.eye(4, device=device, dtype=dtype)
    # voxel size
    vx = voxel_size(affine_x)
    # scale regularisation
    lam = lam_scl * lam
    # initial estimate of reconstruction
    dat_y_hat = torch.zeros_like(dat_x)
    dat_y_hat = torch.nn.Parameter(dat_y_hat, requires_grad=True)
    # prepare optimiser and scheduler
    optim = torch.optim.Adam([dat_y_hat], lr=lr)
    scheduler = ReduceLROnPlateau(optim)
    # optimisation loop (index 0 of loss_vals stays at its initial value)
    loss_vals = torch.zeros(max_iter + 1, dtype=torch.float64)
    cnt_conv = 0
    for n_iter in range(1, max_iter + 1):
        # set gradients to zero (PyTorch accumulates the gradients on subsequent backward passes)
        optim.zero_grad()
        # compute reconstruction loss
        loss_val = _loss_ssqd_jtv(dat_x, dat_y_hat, tau, lam, vx=vx)
        # differentiate reconstruction loss w.r.t. dat_y_hat
        loss_val.backward()
        # store loss
        loss_vals[n_iter] = loss_val.item()
        # update reconstruction
        optim.step()
        # compute gain
        gain = get_gain(loss_vals[:n_iter + 1], monotonicity='decreasing')
        if verbose and n_iter % 25 == 0:
            with torch.no_grad():
                print('n_iter={:4d}, loss={:12.6f}, gain={:0.10}, lr={:g}'.format(
                    n_iter, loss_val.item(), gain, optim.param_groups[0]['lr']), end='\n')
        # converged after 5 consecutive small gains
        if n_iter > 10 and gain.abs() < tolerance:
            cnt_conv += 1
            if cnt_conv == 5:
                break
        else:
            cnt_conv = 0
        # incorporate scheduler
        if scheduler is not None and n_iter % 10 == 0:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(loss_val)
            else:
                scheduler.step()
    if do_write:
        # write output to disk
        if dir_out is not None:
            os.makedirs(dir_out, exist_ok=True)
        for i in range(dat_y_hat.shape[0]):
            fname = file_replace(nii.fname,
                                 prefix='den_',
                                 dir=dir_out,
                                 suffix='_' + str(i))
            savef(dat_y_hat[i, ...], fname, like=nii)

    return dat_y_hat
Example #9
0
def reslice(moving, fname, like, inv=False, lin=None, nonlin=None,
           interpolation=1, bound='dct2', extrapolate=False, device=None,
           verbose=True):
    """Resample `moving` onto the grid of `like` and write the result.

    The sampling grid composes, from output voxels to input voxels: the
    optional linear transform `lin`, the optional non-linear displacement
    `nonlin`, and the orientation matrices of both images. Each input file
    of `moving` is resampled with `pull` and written to the matching entry
    of `fname` with the shape/orientation of `like`. Input files are never
    modified.

    Parameters
    ----------
    moving : ImageFile
        Image to resample (a `MovingImageFile`, or a `FixedImageFile`
        when `inv=True`). Label images (``type == 'labels'``) always use
        nearest-neighbour interpolation.
    fname : list of str
        Output filename for each input file of `moving` (an image may be
        stored across several volumes).
    like : ImageFile
        Target space (a `FixedImageFile`, or a `MovingImageFile` when
        `inv=True`).
    inv : bool, default=False
        If True, warp the fixed image into the moving space: `lin` is then
        composed with the target's orientation instead of the source's.
    lin : (4, 4) tensor, optional
        Linear (affine) transformation.
    nonlin : dict, optional
        Non-linear displacement field, with keys:
        disp : (..., 3) tensor
            Displacement field (in voxels)
        affine : (4, 4) tensor
            Orientation matrix of the displacement field
    interpolation : int, default=1
    bound : str, default='dct2'
    extrapolate : bool, default=False
    device : optional
        Device on which matrices and data are placed.
    verbose : bool, default=True
        Print each input -> output pair.
    """
    nonlin = nonlin or dict(disp=None, affine=None)
    prm = dict(interpolation=interpolation, bound=bound, extrapolate=extrapolate)

    moving_affine = moving.affine.to(device)
    fixed_affine = like.affine.to(device)

    # `target`: output voxels -> world; `source`: input voxels -> world.
    # The linear transform sits on the target side when inverting, on the
    # source side otherwise.
    if inv:
        target = fixed_affine if lin is None else affine_matmul(lin, fixed_affine)
        source = moving_affine
    else:
        target = fixed_affine
        source = moving_affine if lin is None else affine_matmul(lin, moving_affine)

    disp = nonlin['disp']
    if disp is None:
        g = affine_grid(affine_lmdiv(source, target), like.shape)
    else:
        disp = disp.to(device)
        nl_affine = nonlin['affine'].to(device)
        # output voxels -> displacement-field voxels
        out2nl = affine_lmdiv(nl_affine, target)
        if samespace(out2nl, disp.shape[:-1], like.shape):
            g = smalldef(disp)
        else:
            g = affine_grid(out2nl, like.shape)
            g += pull_grid(disp, g)
        # displacement-field voxels -> input voxels
        g = affine_matvec(affine_lmdiv(source, nl_affine), g)

    if moving.type == 'labels':
        prm['interpolation'] = 0
    for src_file, out_name in zip(moving.files, fname):
        if verbose:
            print(f'Resliced:   {src_file.fname}\n'
                  f'         -> {out_name}')
        dat = io.volumes.loadf(src_file.fname, rand=True, device=device)
        dat = dat.reshape([*src_file.shape, src_file.channels])
        if g is not None:
            # channels first for `pull`, then back to channels last
            dat = utils.movedim(pull(utils.movedim(dat, -1, 0), g, **prm), 0, -1)
        io.savef(dat.cpu(), out_name, like=src_file.fname, affine=like.affine.cpu())
Example #10
0
def main(options):
    """Run EPIC distortion correction and save the corrected echoes.

    Output naming (`options.output`) follows one of three layouts:

    * a single template -- if it contains ``{n}`` one file per echo is
      written, otherwise all echoes go into one file (echoes last); a
      single template containing ``{base}`` is instead repeated for every
      input file;
    * one template per input file -- each receives that file's echoes;
    * one template per echo.

    Templates are formatted with ``dir``, ``sep``, ``base``, ``ext`` (and
    ``n``) taken from the corresponding (or first) input file.

    Raises
    ------
    ValueError
        If the number of output templates matches none of the layouts.
    """
    # find readout direction
    ref = io.map(options.echoes[0])
    readout = get_readout(options.direction, ref.affine, ref.shape,
                          options.verbose)

    reversed_echoes = options.reversed or options.synth

    # do EPIC
    fit = epic(options.echoes,
               reverse_echoes=reversed_echoes,
               fieldmap=options.fieldmap,
               extrapolate=options.extrapolate,
               bandwidth=options.bandwidth,
               polarity=options.polarity,
               readout=readout,
               slicewise=options.slicewise,
               lam=options.penalty,
               max_iter=options.maxiter,
               tol=options.tolerance,
               verbose=options.verbose,
               device=get_device(options.gpu))

    # save volumes
    inputs, outputs = options.echoes, options.output
    if len(outputs) != len(inputs):
        if len(outputs) == 1:
            if '{base}' in outputs[0]:
                outputs = [outputs[0]] * len(inputs)
        elif len(outputs) != len(fit):
            raise ValueError(f'There should be either one output file, '
                             f'or as many output files as input files, '
                             f'or as many output files as echoes. Got '
                             f'{len(outputs)} output files, {len(inputs)} '
                             f'input files, and {len(fit)} echoes.')

    if len(outputs) == 1:
        folder, base, ext = py.fileparts(inputs[0])
        template = outputs[0]
        if '{n}' not in template:
            target = template.format(dir=folder, sep=os.sep, base=base, ext=ext)
            io.savef(torch.movedim(fit, 0, -1), target, like=inputs[0])
        else:
            for n, echo in enumerate(fit):
                target = template.format(dir=folder, sep=os.sep, base=base,
                                         ext=ext, n=n)
                io.savef(echo, target, like=inputs[0])
    elif len(outputs) == len(inputs):
        remaining = fit
        for i, (inp, template) in enumerate(zip(inputs, outputs)):
            folder, base, ext = py.fileparts(inp)
            target = template.format(dir=folder, sep=os.sep, base=base,
                                     ext=ext, n=i)
            # number of echoes stored in this file (4th dim, 1 if absent)
            nb_echoes = [*io.map(inp).shape, 1][3]
            io.savef(remaining[:nb_echoes].movedim(0, -1), target, like=inp)
            remaining = remaining[nb_echoes:]
    else:
        assert len(outputs) == len(fit)
        folder, base, ext = py.fileparts(inputs[0])
        for n, (echo, template) in enumerate(zip(fit, outputs)):
            target = template.format(dir=folder, sep=os.sep, base=base,
                                     ext=ext, n=n)
            io.savef(echo, target, like=inputs[0])
Example #11
0
def main_fit(options):
    """
    Estimate a displacement field from opposite polarity images.

    The field is estimated coarse-to-fine. Each level is defined by one
    entry of (penalty, max_iter, tolerance, downsample); the option lists
    are brought to a common number of levels with `py.make_list`
    (presumably by repeating their last value). The estimate of each level
    initializes the next one. The final velocity is upsampled to the grid
    of the positive-polarity image and written, as float32, to
    `options.output` formatted with the `dir`, `sep`, `base` and `ext` of
    `options.pos_file`.

    Parameters
    ----------
    options : parsed command-line options
        Fields read: gpu, pos_file, neg_file, mask, readout, penalty,
        downsample, max_iter, tolerance, kernel, loss, bins, diffeo,
        modulation, verbose, output.

    Raises
    ------
    ValueError
        If the penalty type does not start with 'b' (bending) or
        'm' (membrane).
    """
    device = get_device(options.gpu)

    # map input files
    f0 = io.map(options.pos_file)
    f1 = io.map(options.neg_file)
    dim = f0.affine.shape[-1] - 1

    # map mask (optional)
    fm = io.map(options.mask) if options.mask else None

    # detect readout direction
    readout = get_readout(options.readout, f0.affine, f0.shape[-dim:])

    # detect penalty: a list of regularization values, optionally
    # terminated by a string naming the penalty type
    penalty_type = 'bending'
    penalties = options.penalty
    if penalties and isinstance(penalties[-1], str):
        *penalties, penalty_type = penalties
    if not penalties:
        penalties = [1]
    # `startswith` (rather than `[0]`) so that an empty string reaches the
    # ValueError below instead of raising an IndexError
    if penalty_type.startswith('b'):
        penalty_type = 'bending'
    elif penalty_type.startswith('m'):
        penalty_type = 'membrane'
    else:
        raise ValueError('Unknown penalty type', penalty_type)

    # multi-resolution schedule: one level per entry
    downs = options.downsample
    max_iter = options.max_iter
    tolerance = options.tolerance
    nb_levels = max(len(penalties), len(max_iter), len(tolerance), len(downs))
    penalties = py.make_list(penalties, nb_levels)
    tolerance = py.make_list(tolerance, nb_levels)
    max_iter = py.make_list(max_iter, nb_levels)
    downs = py.make_list(downs, nb_levels)

    # load full-resolution data on the CPU; each level moves a
    # downsampled copy to `device`
    d00 = f0.fdata(device='cpu')
    d11 = f1.fdata(device='cpu')
    dmask = fm.fdata(device='cpu') if fm is not None else None

    # fit
    vel = mask = None
    aff = last_aff = f0.affine
    # Private sentinel: guarantees that the first level always computes
    # `d0`, `d1` and `vx`, whatever its downsampling factor (comparing
    # against an initial `None` left them undefined when dwn was None).
    unset = object()
    last_dwn = unset
    for penalty, n, tol, dwn in zip(penalties, max_iter, tolerance, downs):
        if last_dwn is unset or dwn != last_dwn:
            d0, aff = downsample(d00.to(device), f0.affine, dwn)
            d1, _ = downsample(d11.to(device), f1.affine, dwn)
            vx = spatial.voxel_size(aff)
            if vel is not None:
                # carry the previous level's estimate over to the new grid
                vel = upsample_vel(vel, last_aff, aff, d0.shape[-dim:], readout)
            last_aff = aff
            if fm is not None:
                mask, _ = downsample(dmask.to(device), f1.affine, dwn)
        last_dwn = dwn
        # scale the penalty by the ratio of full- to current-resolution
        # voxel counts
        scl = py.prod(d00.shape) / py.prod(d0.shape)
        penalty = penalty * scl

        kernel = get_kernel(options.kernel, aff, d0.shape[-dim:], dwn)

        # prepare loss
        if options.loss == 'mse':
            prm0, _ = estimate_noise(d0)
            prm1, _ = estimate_noise(d1)
            # geometric mean of the two noise standard deviations
            sd = ((prm0['sd'].log() + prm1['sd'].log())/2).exp()
            if options.verbose:
                print(f'noise sd: {sd.item()}')
            loss = MSE(lam=1/(sd*sd), dim=dim)
        elif options.loss == 'lncc':
            loss = LNCC(dim=dim, patch=kernel)
        elif options.loss == 'lgmm':
            if options.bins == 1:
                loss = LNCC(dim=dim, patch=kernel)
            else:
                loss = LGMMH(dim=dim, patch=kernel, bins=options.bins)
        elif options.loss == 'gmm':
            if options.bins == 1:
                loss = NCC(dim=dim)
            else:
                loss = GMMH(dim=dim, bins=options.bins)
        else:
            loss = NCC(dim=dim)

        # fit
        vel = topup_fit(d0, d1, loss=loss, dim=readout, vx=vx, ndim=dim,
                        model=('svf' if options.diffeo else 'smalldef'),
                        lam=penalty, penalty=penalty_type, vel=vel,
                        modulation=options.modulation, max_iter=n,
                        tolerance=tol, verbose=options.verbose, mask=mask)

    # free image buffers before upsampling
    del d0, d1, d00, d11, dmask, mask

    # upsample to the grid of the positive-polarity image
    vel = upsample_vel(vel, aff, f0.affine, f0.shape[-dim:], readout)

    # save
    dir, base, ext = py.fileparts(options.pos_file)
    fname = options.output
    fname = fname.format(dir=dir or '.', sep=os.sep, base=base, ext=ext)
    io.savef(vel, fname, like=options.pos_file, dtype='float32')
Exemple #12
0
 def savef(self, fname, *args, **kwargs):
     """Write the wrapped volume to disk.

     Thin wrapper around ``io.savef``: ``self.volume`` is the data written,
     and ``fname`` plus any extra positional/keyword arguments are passed
     through unchanged. Returns nothing.
     """
     data = self.volume
     io.savef(data, fname, *args, **kwargs)
Exemple #13
0
def write_outputs(z, prm, options):
    """Write the requested segmentation outputs to disk.

    Outputs can be written in three spaces:
      * native (`*_nat`): the grid of the first input image;
      * MNI (`*_mni`): the template grid (optionally resized to
        `options.mni_vx`), reached by reslicing through the composition of
        `prm['affine']` with the native image affine;
      * warped (`*_wrp`): the template grid, reached through that affine
        and the inverse of the nonlinear displacement `prm['warp']`.

    Each output is written when its own option is set (a string value is
    used as filename template) or when the `all_*` option of its space is
    set. Filename templates are formatted with the dictionary returned by
    `get_format_dict` (default templates use `dir`, `sep`, `base`, `ext`).
    With several input channels, bias / bias-corrected outputs are written
    as one file per channel, each templated on its own input file.

    Parameters
    ----------
    z : tensor
        Posterior probabilities with the class dimension first
        (`z.argmax(0)` gives the hard labels).
    prm : dict
        Fitted parameters. Keys read: 'bias' (when `options.bias`),
        'warp' (when `options.warp`), 'affine' (MNI / warped outputs).
    options : parsed command-line options
    """

    # prepare filenames
    ref_native = options.input[0]
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    format_dict = get_format_dict(ref_native, options.output)

    # tensor options of the posterior (passed as `**backend` to `.to()` and
    # `get_data`), reused for every tensor built below
    backend = utils.backend(z)

    # the raw input data is only needed to build bias-corrected images
    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp
            or options.all_nat or options.all_mni or options.all_wrp):
        dat, _, _ = get_data(options.input, options.mask, None, 3,
                             **backend)

    # --- native space -------------------------------------------------

    if options.prob_nat or options.all_nat:
        fname = options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('prob.nat     ->', fname)
        io.savef(torch.movedim(z, 0, -1),
                 fname,
                 like=ref_native,
                 dtype='float32')

    if options.labels_nat or options.all_nat:
        fname = options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('labels.nat   ->', fname)
        io.save(z.argmax(0), fname, like=ref_native, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        bias = prm['bias']
        fname = options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}'
        if len(options.input) == 1:
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('bias.nat     ->', fname)
            io.savef(torch.movedim(bias, 0, -1),
                     fname,
                     like=ref_native,
                     dtype='float32')
        else:
            # one file per channel: format the *template* anew for each
            # input so that channels do not all get the same filename
            for c, (bias1, ref1) in enumerate(zip(bias, options.input)):
                fname1 = fname.format(**get_format_dict(ref1, options.output))
                if options.verbose > 0:
                    print(f'bias.nat.{c+1}   ->', fname1)
                io.savef(bias1, fname1, like=ref1, dtype='float32')
        del bias

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        fname = options.nobias_nat or '{dir}{sep}{base}.nobias.nat{ext}'
        if len(options.input) == 1:
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('nobias.nat   ->', fname)
            io.savef(torch.movedim(nobias, 0, -1), fname, like=ref_native)
        else:
            for c, (nobias1, ref1) in enumerate(zip(nobias, options.input)):
                fname1 = fname.format(**get_format_dict(ref1, options.output))
                if options.verbose > 0:
                    print(f'nobias.nat.{c+1} ->', fname1)
                io.savef(nobias1, fname1, like=ref1)
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        warp = prm['warp']
        fname = options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('warp.nat     ->', fname)
        io.savef(warp, fname, like=ref_native, dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    # compose the fitted affine with the native image affine
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        # resize the template grid to the requested voxel size
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(mni_affine,
                                                      mni_shape,
                                                      scl,
                                                      anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        if options.prob_mni or options.all_mni:
            fname = options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('prob.mni     ->', fname)
            io.savef(torch.movedim(z_mni, 0, -1),
                     fname,
                     like=ref_native,
                     affine=mni_affine,
                     dtype='float32')
        if options.labels_mni or options.all_mni:
            fname = options.labels_mni or '{dir}{sep}{base}.labels.mni{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('labels.mni   ->', fname)
            io.save(z_mni.argmax(0),
                    fname,
                    like=ref_native,
                    affine=mni_affine,
                    dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni
                         or options.all_mni):
        bias = spatial.reslice(prm['bias'],
                               dat_affine,
                               mni_affine,
                               mni_shape,
                               interpolation=3,
                               prefilter=False,
                               bound='dct2')

        if options.bias_mni or options.all_mni:
            fname = options.bias_mni or '{dir}{sep}{base}.bias.mni{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('bias.mni     ->', fname)
                io.savef(torch.movedim(bias, 0, -1),
                         fname,
                         like=ref_native,
                         affine=mni_affine,
                         dtype='float32')
            else:
                for c, (bias1, ref1) in enumerate(zip(bias, options.input)):
                    fname1 = fname.format(
                        **get_format_dict(ref1, options.output))
                    if options.verbose > 0:
                        print(f'bias.mni.{c+1}   ->', fname1)
                    io.savef(bias1,
                             fname1,
                             like=ref1,
                             affine=mni_affine,
                             dtype='float32')

        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            fname = options.nobias_mni or '{dir}{sep}{base}.nobias.mni{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('nobias.mni   ->', fname)
                io.savef(torch.movedim(nobias, 0, -1),
                         fname,
                         like=ref_native,
                         affine=mni_affine)
            else:
                for c, (nobias1, ref1) in enumerate(zip(nobias,
                                                        options.input)):
                    fname1 = fname.format(
                        **get_format_dict(ref1, options.output))
                    if options.verbose > 0:
                        print(f'nobias.mni.{c+1} ->', fname1)
                    io.savef(nobias1, fname1, like=ref1, affine=mni_affine)
            del nobias

        del bias

    need_iwarp = (options.warp_mni or options.prob_wrp or options.labels_wrp
                  or options.bias_wrp or options.nobias_wrp or options.all_mni
                  or options.all_wrp)
    need_iwarp = need_iwarp and options.warp
    if not need_iwarp:
        return

    # Invert the nonlinear displacement and reslice it onto the MNI grid.
    # The displacement vectors are then multiplied by the linear part of
    # `mni_affine^-1 @ dat_affine` (presumably native voxel units -> MNI
    # voxel units; NOTE(review): confirm).
    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp,
                            dat_affine,
                            mni_affine,
                            mni_shape,
                            interpolation=2,
                            bound='dft',
                            extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if (options.warp_mni or options.all_mni) and options.warp:
        fname = options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}'
        fname = fname.format(**format_dict)
        if options.verbose > 0:
            print('warp.mni     ->', fname)
        io.savef(iwarp,
                 fname,
                 like=ref_native,
                 affine=mni_affine,
                 dtype='float32')

    # --- Warped space -------------------------------------------------
    # Displacement -> sampling grid, then mapped through
    # `dat_affine^-1 @ mni_affine` so that it can be used by `grid_pull`
    # on native-space data.
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_mni = spatial.grid_pull(z, iwarp)
        if options.prob_wrp or options.all_wrp:
            fname = options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('prob.wrp     ->', fname)
            io.savef(torch.movedim(z_mni, 0, -1),
                     fname,
                     like=ref_native,
                     affine=mni_affine,
                     dtype='float32')
        if options.labels_wrp or options.all_wrp:
            fname = options.labels_wrp or '{dir}{sep}{base}.labels.wrp{ext}'
            fname = fname.format(**format_dict)
            if options.verbose > 0:
                print('labels.wrp   ->', fname)
            io.save(z_mni.argmax(0),
                    fname,
                    like=ref_native,
                    affine=mni_affine,
                    dtype='int16')
        del z_mni

    if options.bias and (options.bias_wrp or options.nobias_wrp
                         or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'],
                                 iwarp,
                                 interpolation=3,
                                 prefilter=False,
                                 bound='dct2')
        if options.bias_wrp or options.all_wrp:
            fname = options.bias_wrp or '{dir}{sep}{base}.bias.wrp{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('bias.wrp     ->', fname)
                io.savef(torch.movedim(bias, 0, -1),
                         fname,
                         like=ref_native,
                         affine=mni_affine,
                         dtype='float32')
            else:
                for c, (bias1, ref1) in enumerate(zip(bias, options.input)):
                    fname1 = fname.format(
                        **get_format_dict(ref1, options.output))
                    if options.verbose > 0:
                        print(f'bias.wrp.{c+1}   ->', fname1)
                    io.savef(bias1,
                             fname1,
                             like=ref1,
                             affine=mni_affine,
                             dtype='float32')

        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            fname = options.nobias_wrp or '{dir}{sep}{base}.nobias.wrp{ext}'
            if len(options.input) == 1:
                fname = fname.format(**format_dict)
                if options.verbose > 0:
                    print('nobias.wrp   ->', fname)
                io.savef(torch.movedim(nobias, 0, -1),
                         fname,
                         like=ref_native,
                         affine=mni_affine)
            else:
                for c, (nobias1, ref1) in enumerate(zip(nobias,
                                                        options.input)):
                    fname1 = fname.format(
                        **get_format_dict(ref1, options.output))
                    if options.verbose > 0:
                        print(f'nobias.wrp.{c+1} ->', fname1)
                    io.savef(nobias1, fname1, like=ref1, affine=mni_affine)
            del nobias

        del bias