예제 #1
0
    def exp(self, velocity, displacement=False):
        """Exponentiate a velocity field into a deformation grid.

        The velocity is first passed through `self.resize`, integrated
        with `self.velexp`, and the result is resized back to the
        spatial shape of the input velocity.

        Parameters
        ----------
        velocity : (batch, *spatial, nb_dim)
            Stationary velocity field
        displacement : bool, default=False
            If True, return a displacement field (voxel to shift).
            Otherwise, return a transformation field (voxel to voxel).

        Returns
        -------
        grid : (batch, *spatial, nb_dim)
            Deformation grid (transformation or displacement).

        """
        # spatial dimensions of the input (drop batch and component axes)
        full_shape = velocity.shape[1:-1]

        # integrate at the working resolution, then resample to input size
        warp = self.velexp(self.resize(velocity))
        warp = self.resize(warp, output_shape=full_shape, factor=None)

        if displacement:
            return warp
        # the identity is added in-place to the resized displacement
        return spatial.add_identity_grid_(warp)
예제 #2
0
    def forward(self,
                source,
                target,
                source_seg=None,
                target_seg=None,
                *,
                _loss=None,
                _metric=None):
        """Predict a velocity from (source, target) and warp the source.

        Parameters
        ----------
        source : tensor
            Moving image. It is concatenated with `target` along dim 1
            and its spatial shape is read from `shape[2:]`, so it is
            presumably (batch, channel, *spatial) -- TODO confirm.
        target : tensor
            Fixed image, concatenated with `source` along dim 1.
        source_seg : tensor, optional
            Segmentation of the moving image, warped with the same grid.
        target_seg : tensor, optional
            Segmentation of the fixed image. Only forwarded to
            `self.compute`.
        _loss, _metric : optional
            Forwarded untouched to `self.compute`.

        Returns
        -------
        deformed_source : tensor
            `source` sampled through the grid by `self.pull`.
        deformed_source_seg : tensor
            Only returned when `source_seg` is provided.
        vel : tensor
            Output of `self.unet` (after `self.resize_vel` if defined).
        grid : tensor
            Grid used for sampling (identity already added).

        """
        vel = self.unet(torch.cat([source, target], dim=1))
        if hasattr(self, 'resize_vel'):
            vel = self.resize_vel(vel)
        grid = self.exp(vel)
        if hasattr(self, 'resize_grid'):
            # resize the exponentiated field, not the raw velocity:
            # passing `vel` here would silently discard `self.exp(vel)`
            grid = self.resize_grid(grid, output_shape=source.shape[2:])
        # in-place: turn the displacement into a transformation
        grid = spatial.add_identity_grid_(grid)
        deformed_source = self.pull(source, grid)

        if source_seg is not None:
            deformed_source_seg = self.pull(source_seg, grid)
        else:
            deformed_source_seg = None

        # compute loss and metrics
        self.compute(_loss,
                     _metric,
                     image=[deformed_source, target],
                     velocity=[vel],
                     segmentation=[deformed_source_seg, target_seg])

        if source_seg is None:
            return deformed_source, vel, grid
        else:
            return deformed_source, deformed_source_seg, vel, grid
예제 #3
0
def derivatives_distortion(contrast,
                           distortion,
                           intercept,
                           decay,
                           opt,
                           do_grad=True):
    """Compute the gradient and Hessian of the distortion field.

    Parameters
    ----------
    contrast : (nb_echo, *obs_shape) GradientEchoMulti
        A single echo series (with the same weighting)
    distortion : ParameterizedDeformation
        A model of distortions caused by B0 inhomogeneities.
    intercept : (*recon_shape) ParameterMap
        Log-intercept of the contrast
    decay : (*recon_shape) ParameterMap
        Exponential decay
    opt : Options
    do_grad : bool, default=True
        If False, only the criterion is computed and returned.

    Returns
    -------
    crit : () tensor
        Criterion summed over echoes (values returned by `nll_chi` or
        `nll_gauss`).
    grad : (*shape, 3) tensor, if `do_grad`
        Reduced to the readout component if `contrast.readout` is set.
    hess : (*shape, 6) tensor, if `do_grad`
        Stored as [3 diagonal, 3 off-diagonal] components along the
        last dimension. Reduced to the readout component if
        `contrast.readout` is set.

    """

    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    obs_shape = contrast.volume.shape[1:]
    recon_shape = intercept.volume.shape
    aff = core.linalg.lmdiv(intercept.affine, contrast.affine)
    aff = aff.to(**backend)
    lam = 1 / contrast.noise
    df = contrast.dof
    chi = opt.likelihood[0].lower() == 'c'

    # pull parameter maps to observed space
    grid = smart_grid(aff, obs_shape, recon_shape)
    inter = smart_pull(intercept.fdata(**backend), grid)
    slope = smart_pull(decay.fdata(**backend), grid)
    readout = contrast.readout
    if opt.distortion.te_scaling != 'pre':
        # when TE scaling is enabled, the identity is added per echo
        # (after scaling), so only displacements are requested here
        grid_up, grid_down = distortion.exp2(
            add_identity=not opt.distortion.te_scaling)
    else:
        grid_up = grid_down = None

    crit = 0
    grad = torch.zeros(obs_shape + (3, ), **backend) if do_grad else None
    hess = torch.zeros(obs_shape + (6, ), **backend) if do_grad else None

    te0 = 0
    for e, echo in enumerate(contrast):

        te = echo.te
        te0 = te0 or te  # first echo time is the reference for scaling
        blip = echo.blip or (2 * (e % 2) - 1)  # default: alternate -1/+1
        grid_blip = grid_up if blip > 0 else grid_down
        vscl = te / te0
        if opt.distortion.te_scaling == 'pre':
            vexp = distortion.iexp if blip < 0 else distortion.exp
            grid_blip = vexp(add_identity=True, alpha=vscl)
        elif opt.distortion.te_scaling:
            grid_blip = spatial.add_identity_grid_(vscl * grid_blip)

        # compute residuals
        dat = echo.fdata(**backend, rand=True, cache=False)  # observed
        fit = recon_fit(inter, slope, te)  # fitted
        if do_grad and isinstance(distortion, DenseDeformation):
            # D(fit) o phi
            gfit = smart_grad(fit, grid_blip, bound='dft', extrapolate=True)
        fit = smart_pull(fit, grid_blip, bound='dft', extrapolate=True)
        msk = get_mask_missing(dat, fit)  # mask of missing values
        if do_grad and isinstance(distortion, SVFDeformation):
            # D(fit o phi)
            gfit = spatial.diff(fit, bound='dft', dim=[-3, -2, -1])
            gfit.masked_fill_(msk.unsqueeze(-1), 0)
        dat.masked_fill_(msk, 0)
        fit.masked_fill_(msk, 0)
        msk = msk.bitwise_not_()  # in-place: now a mask of observed values

        if chi:
            crit1, res = nll_chi(dat, fit, msk, lam, df)
        else:
            crit1, res = nll_gauss(dat, fit, msk, lam)
        del dat, fit, msk
        crit += crit1

        if do_grad:
            g1 = res.unsqueeze(-1).mul(gfit)
            h1 = torch.zeros_like(hess)
            if readout is None:
                h1[..., :3] = gfit.square()
                h1[..., 3] = gfit[..., 0] * gfit[..., 1]
                h1[..., 4] = gfit[..., 0] * gfit[..., 2]
                h1[..., 5] = gfit[..., 1] * gfit[..., 2]
            else:
                h1[..., readout] = gfit[..., readout].square()

            # propagate backward
            if isinstance(distortion, SVFDeformation):
                vel = distortion.volume
                if opt.distortion.te_scaling == 'pre':
                    vel = ((-vscl) * vel) if blip < 0 else (vscl * vel)
                elif blip < 0:
                    vel = -vel
                g1, h1 = spatial.exp_backward(vel,
                                              g1,
                                              h1,
                                              steps=distortion.steps)

            alpha_g = alpha_h = lam
            alpha_g = alpha_g * blip
            if opt.distortion.te_scaling == 'pre':
                alpha_g = alpha_g * vscl
                alpha_h = alpha_h * (vscl * vscl)
            grad.add_(g1, alpha=alpha_g)
            hess.add_(h1, alpha=alpha_h)

    if not do_grad:
        return crit

    if readout is None:
        mask_nan_(grad)
        # Hessian components live in the LAST dimension, so they must be
        # selected with an ellipsis (slicing the first axis would split
        # the volume spatially instead).
        mask_nan_(hess[..., :3], 1e-3)  # diagonal
        mask_nan_(hess[..., -3:])  # off-diagonal
    else:
        grad = grad[..., readout]
        hess = hess[..., readout]
        mask_nan_(grad)
        mask_nan_(hess)

    return crit, grad, hess
예제 #4
0
def derivatives_parameters(contrast,
                           distortion,
                           intercept,
                           decay,
                           opt,
                           do_grad=True):
    """Compute the gradient and Hessian of the parameter maps with
    respect to one contrast.

    Parameters
    ----------
    contrast : (nb_echo, *obs_shape) GradientEchoMulti
        A single echo series (with the same weighting)
    distortion : ParameterizedDeformation
        A model of distortions caused by B0 inhomogeneities.
        If falsy, no distortion is applied.
    intercept : (*recon_shape) ParameterMap
        Log-intercept of the contrast
    decay : (*recon_shape) ParameterMap
        Exponential decay
    opt : Options
    do_grad : bool, default=True
        If False, only the criterion is computed and returned.

    Returns
    -------
    crit : () tensor
        Criterion summed over echoes (values returned by `nll_chi` or
        `nll_gauss`).
    grad : (2, *recon_shape) tensor, if `do_grad`
        Gradient with respect to:
            [0] intercept
            [1] decay
    hess : (3, *recon_shape) tensor, if `do_grad`
        Hessian with respect to:
            [0] intercept ** 2
            [1] decay ** 2
            [2] intercept * decay

    """

    dtype = opt.backend.dtype
    device = opt.backend.device
    backend = dict(dtype=dtype, device=device)

    obs_shape = contrast.volume.shape[1:]
    recon_shape = intercept.volume.shape
    aff = core.linalg.lmdiv(intercept.affine, contrast.affine)
    aff = aff.to(**backend)
    lam = 1 / contrast.noise
    df = contrast.dof
    chi = opt.likelihood[0].lower() == 'c'

    # pull parameter maps to observed space
    grid = smart_grid(aff, obs_shape, recon_shape)
    inter = smart_pull(intercept.fdata(**backend), grid)
    slope = smart_pull(decay.fdata(**backend), grid)
    if distortion and opt.distortion.te_scaling != 'pre':
        # when TE scaling is enabled, the identity is added per echo
        # (after scaling), so only displacements are requested here
        grid_up, grid_down = distortion.exp2(
            add_identity=not opt.distortion.te_scaling)
    else:
        grid_up = grid_down = None

    crit = 0
    grad = torch.zeros((2, ) + obs_shape, **backend) if do_grad else None
    hess = torch.zeros((3, ) + obs_shape, **backend) if do_grad else None

    te0 = 0
    for e, echo in enumerate(contrast):

        te = echo.te
        te0 = te0 or te  # first echo time is the reference for scaling
        blip = echo.blip or (2 * (e % 2) - 1)  # default: alternate -1/+1
        grid_blip = grid_up if blip > 0 else grid_down
        if distortion:
            vscl = te / te0
            if opt.distortion.te_scaling == 'pre':
                vexp = distortion.iexp if blip < 0 else distortion.exp
                grid_blip = vexp(add_identity=True, alpha=vscl)
            elif opt.distortion.te_scaling:
                # Any truthy non-'pre' value made `exp2` skip the
                # identity, so it must be added here (same test as in
                # `derivatives_distortion`).
                grid_blip = spatial.add_identity_grid_(vscl * grid_blip)

        # compute residuals
        dat = echo.fdata(**backend, rand=True, cache=False)
        fit = recon_fit(inter, slope, te)
        pull_fit = smart_pull(fit, grid_blip, bound='dft', extrapolate=True)
        msk = get_mask_missing(dat, pull_fit)
        dat.masked_fill_(msk, 0)
        pull_fit.masked_fill_(msk, 0)
        msk = msk.bitwise_not_()  # in-place: now a mask of observed values

        if chi:
            crit1, res = nll_chi(dat, pull_fit, msk, lam, df)
        else:
            crit1, res = nll_gauss(dat, pull_fit, msk, lam)
        del dat, pull_fit
        crit += crit1

        if do_grad:
            msk = msk.to(fit.dtype)
            if grid_blip is not None:
                # push residuals (and their magnitude, and the mask)
                # back through the distortion
                res0 = res
                res = smart_push(res0,
                                 grid_blip,
                                 bound='dft',
                                 extrapolate=True)
                # `abs_` overwrites res0, which is no longer needed as-is
                abs_res = smart_push(res0.abs_(),
                                     grid_blip,
                                     bound='dft',
                                     extrapolate=True)
                abs_res.mul_(fit)
                msk = smart_push(msk, grid_blip, bound='dft', extrapolate=True)
                del res0

            # ----------------------------------------------------------
            # compute gradient and (majorised) Hessian in observed space
            #
            #   grad[inter]       =           lam * fit * res
            #   grad[decay]       =     -te * lam * fit * res
            #   hess[inter**2]    =           lam * fit * (fit + abs(res))
            #   hess[decay**2]    = (te*te) * lam * fit * (fit + abs(res))
            #   hess[inter*decay] =     -te * lam * fit * fit
            #
            # I tried to put that into an "accumulation" function but it
            # does super weird stuff, so I keep it in the main loop. I am
            # saving a few allocations here so I think it's faster than
            # torchscript.
            # ----------------------------------------------------------

            res.mul_(fit)
            grad[0].add_(res, alpha=lam)
            grad[1].add_(res, alpha=-te * lam)
            if grid_blip is None:
                # the gradient is already accumulated, so `res` can be
                # reused in-place to hold abs(fit * res)
                abs_res = res.abs_()
            fit2 = fit.mul_(fit).mul_(msk)
            del msk
            hess[2].add_(fit2, alpha=-te * lam)
            fit2.add_(abs_res)
            hess[0].add_(fit2, alpha=lam)
            hess[1].add_(fit2, alpha=lam * (te * te))

            del res, fit, abs_res, fit2

    if not do_grad:
        return crit

    mask_nan_(grad)
    mask_nan_(hess[:-1], 1e-3)  # diagonal
    mask_nan_(hess[-1])  # off-diagonal

    # push gradient and Hessian to recon space
    grad = smart_push(grad, grid, recon_shape)
    hess = smart_push(hess, grid, recon_shape)
    return crit, grad, hess
예제 #5
0
파일: param.py 프로젝트: balbasty/nitorch
 def add_identity_(self, disp):
     """Add the identity to a displacement stored along one axis.

     The axis `self.displacement_dim` is moved to the last position,
     a trailing singleton axis is appended before calling
     `spatial.add_identity_grid_`, and the original layout is then
     restored.
     """
     axis = self.displacement_dim
     # (..., N) -> (..., N, 1): singleton axis expected by the grid helper
     expanded = utils.movedim(disp, axis, -1).unsqueeze(-1)
     shifted = spatial.add_identity_grid_(expanded).squeeze(-1)
     return utils.movedim(shifted, -1, axis)
예제 #6
0
파일: main.py 프로젝트: balbasty/nitorch
def main_apply(options):
    """
    Unwarp distorted images using a pre-computed 1d displacement field.

    The field is loaded from `options.dist_file`. Files listed in
    `options.file_pos` are corrected with the field as-is, and files
    listed in `options.file_neg` with the negated field. Each corrected
    volume is written to the path built from the `options.output`
    template (fields: `dir`, `sep`, `base`, `ext`).

    If `options.diffeo` is set, the field is exponentiated with
    `spatial.exp1d_forward`; otherwise it is used directly as a
    displacement. If `options.modulation` is set, the corrected data is
    additionally multiplied by the Jacobian of the transformation.
    """
    device = get_device(options.gpu)

    # detect readout direction
    if options.file_pos:
        f0 = io.map(options.file_pos[0])
    else:
        f0 = io.map(options.file_neg[0])
    dim = f0.affine.shape[-1] - 1
    readout = get_readout(options.readout, f0.affine, f0.shape[-dim:])

    def do_apply(fnames, phi, jac):
        """Correct files with a given polarity"""
        for fname in fnames:
            # `dirname` rather than `dir`, to avoid shadowing the builtin
            dirname, base, ext = py.fileparts(fname)
            ofname = options.output
            ofname = ofname.format(dir=dirname or '.', sep=os.sep, base=base,
                                   ext=ext)
            if options.verbose:
                print(f'unwarp {fname} \n'
                      f'    -> {ofname}')

            f = io.map(fname)
            d = f.fdata(device=device)
            # work with the readout direction last, then move it back
            d = utils.movedim(d, readout, -1)
            d = _deform1d(d, phi)
            if jac is not None:
                d *= jac
            d = utils.movedim(d, -1, readout)

            io.savef(d, ofname, like=fname)

    def make_transform(negate):
        """Build the transformation (and Jacobian or None) for a polarity"""
        field = -vel if negate else vel
        if options.diffeo:
            phi, *jac = spatial.exp1d_forward(field, bound='dct2',
                                              jacobian=options.modulation)
            jac = jac[0] if jac else None
        else:
            # `-vel` is already a fresh tensor; `vel` itself must be
            # cloned because the identity is added in-place below
            phi = field if negate else field.clone()
            jac = None
            if options.modulation:
                # the readout axis was moved to the last dimension, so
                # the derivative must be taken along -1 (not `readout`)
                jac = spatial.diff1d(phi, dim=-1, bound='dct2', side='c')
                jac += 1
        phi = spatial.add_identity_grid_(phi.unsqueeze(-1)).squeeze(-1)
        return phi, jac

    # load and apply
    vel = io.loadf(options.dist_file, device=device)
    vel = utils.movedim(vel, readout, -1)

    if options.file_pos:
        phi, jac = make_transform(negate=False)
        do_apply(options.file_pos, phi, jac)

    if options.file_neg:
        phi, jac = make_transform(negate=True)
        do_apply(options.file_neg, phi, jac)
예제 #7
0
파일: cli.py 프로젝트: balbasty/nitorch
def write_outputs(z, prm, options):
    """Write the requested segmentation outputs to disk.

    Outputs are written in up to three spaces: native (`*_nat`),
    affinely resliced to the template (`*_mni`) and nonlinearly warped
    to the template (`*_wrp`). Each `options.<name>_<space>` is either
    falsy (do not write, unless the matching `options.all_<space>` is
    set) or a filename template with fields `{dir}`, `{sep}`, `{base}`
    and `{ext}`.

    Parameters
    ----------
    z : tensor
        Class probabilities, with classes along the first dimension
        (labels are obtained with `z.argmax(0)`).
    prm : dict
        Fitted parameters. Keys read here: 'bias', 'warp', 'affine'.
    options : object
        Parsed options. Attributes read here: `input`, `output`,
        `mask`, `tpm`, `mni_vx`, `verbose`, `bias`, `warp`, and the
        `{prob,labels,bias,nobias,warp}_{nat,mni,wrp}` / `all_*` flags.

    """

    # prepare filenames
    ref_native = options.input[0]
    format_dict = get_format_dict(ref_native, options.output)

    def report(label, fname):
        """Print `label -> fname` (arrow at column 13) if verbose."""
        if options.verbose > 0:
            print(label.ljust(13) + '->', fname)

    def save_per_input(vol, template, tag, **save_opt):
        """Save a channel-first volume: one file, or one file per input."""
        if len(options.input) == 1:
            fname = template.format(**format_dict)
            report(tag, fname)
            io.savef(torch.movedim(vol, 0, -1),
                     fname,
                     like=ref_native,
                     **save_opt)
            return
        pad = ' ' * (11 - len(tag))  # keeps the arrow aligned
        for c, (vol1, ref1) in enumerate(zip(vol, options.input)):
            # Always format the pristine template: reusing the already
            # formatted name would write every channel to the same file.
            fname = template.format(**get_format_dict(ref1, options.output))
            report(f'{tag}.{c+1}{pad}', fname)
            io.savef(vol1, fname, like=ref1, **save_opt)

    # load input data (only needed for bias-corrected outputs)
    backend = utils.backend(z)
    if (options.nobias_nat or options.nobias_mni or options.nobias_wrp
            or options.all_nat or options.all_mni or options.all_wrp):
        dat, _, _ = get_data(options.input, options.mask, None, 3,
                             **backend)

    # --- native space -------------------------------------------------

    if options.prob_nat or options.all_nat:
        fname = options.prob_nat or '{dir}{sep}{base}.prob.nat{ext}'
        fname = fname.format(**format_dict)
        report('prob.nat', fname)
        io.savef(torch.movedim(z, 0, -1),
                 fname,
                 like=ref_native,
                 dtype='float32')

    if options.labels_nat or options.all_nat:
        fname = options.labels_nat or '{dir}{sep}{base}.labels.nat{ext}'
        fname = fname.format(**format_dict)
        report('labels.nat', fname)
        io.save(z.argmax(0), fname, like=ref_native, dtype='int16')

    if (options.bias_nat or options.all_nat) and options.bias:
        template = options.bias_nat or '{dir}{sep}{base}.bias.nat{ext}'
        save_per_input(prm['bias'], template, 'bias.nat', dtype='float32')

    if (options.nobias_nat or options.all_nat) and options.bias:
        nobias = dat * prm['bias']
        template = options.nobias_nat or '{dir}{sep}{base}.nobias.nat{ext}'
        save_per_input(nobias, template, 'nobias.nat')
        del nobias

    if (options.warp_nat or options.all_nat) and options.warp:
        warp = prm['warp']
        fname = options.warp_nat or '{dir}{sep}{base}.warp.nat{ext}'
        fname = fname.format(**format_dict)
        report('warp.nat', fname)
        io.savef(warp, fname, like=ref_native, dtype='float32')

    # --- MNI space ----------------------------------------------------
    if options.tpm is False:
        # No template -> no MNI space
        return

    # resolved only once a template space is actually needed
    ref_mni = options.tpm[0] if options.tpm else path_spm_prior()
    fref = io.map(ref_mni)
    mni_affine, mni_shape = fref.affine, fref.shape[:3]
    dat_affine = io.map(ref_native).affine
    mni_affine = mni_affine.to(**backend)
    dat_affine = dat_affine.to(**backend)
    prm_affine = prm['affine'].to(**backend)
    dat_affine = prm_affine @ dat_affine
    if options.mni_vx:
        vx = spatial.voxel_size(mni_affine)
        scl = vx / options.mni_vx
        mni_affine, mni_shape = spatial.affine_resize(mni_affine,
                                                      mni_shape,
                                                      scl,
                                                      anchor='f')

    if options.prob_mni or options.labels_mni or options.all_mni:
        z_mni = spatial.reslice(z, dat_affine, mni_affine, mni_shape)
        if options.prob_mni or options.all_mni:
            fname = options.prob_mni or '{dir}{sep}{base}.prob.mni{ext}'
            fname = fname.format(**format_dict)
            report('prob.mni', fname)
            io.savef(torch.movedim(z_mni, 0, -1),
                     fname,
                     like=ref_native,
                     affine=mni_affine,
                     dtype='float32')
        if options.labels_mni or options.all_mni:
            fname = options.labels_mni or '{dir}{sep}{base}.labels.mni{ext}'
            fname = fname.format(**format_dict)
            report('labels.mni', fname)
            io.save(z_mni.argmax(0),
                    fname,
                    like=ref_native,
                    affine=mni_affine,
                    dtype='int16')
        del z_mni

    if options.bias and (options.bias_mni or options.nobias_mni
                         or options.all_mni):
        bias = spatial.reslice(prm['bias'],
                               dat_affine,
                               mni_affine,
                               mni_shape,
                               interpolation=3,
                               prefilter=False,
                               bound='dct2')

        if options.bias_mni or options.all_mni:
            template = options.bias_mni or '{dir}{sep}{base}.bias.mni{ext}'
            save_per_input(bias, template, 'bias.mni',
                           affine=mni_affine, dtype='float32')

        if options.nobias_mni or options.all_mni:
            nobias = spatial.reslice(dat, dat_affine, mni_affine, mni_shape)
            nobias *= bias
            template = (options.nobias_mni
                        or '{dir}{sep}{base}.nobias.mni{ext}')
            save_per_input(nobias, template, 'nobias.mni',
                           affine=mni_affine)
            del nobias

        del bias

    need_iwarp = (options.warp_mni or options.prob_wrp or options.labels_wrp
                  or options.bias_wrp or options.nobias_wrp or options.all_mni
                  or options.all_wrp)
    need_iwarp = need_iwarp and options.warp
    if not need_iwarp:
        return

    # invert the displacement and reslice it (component axis first,
    # as expected by `reslice`) into the template space
    iwarp = spatial.grid_inv(prm['warp'], type='disp')
    iwarp = iwarp.movedim(-1, 0)
    iwarp = spatial.reslice(iwarp,
                            dat_affine,
                            mni_affine,
                            mni_shape,
                            interpolation=2,
                            bound='dft',
                            extrapolate=True)
    iwarp = iwarp.movedim(0, -1)
    iaff = mni_affine.inverse() @ dat_affine
    iwarp = linalg.matvec(iaff[:3, :3], iwarp)

    if (options.warp_mni or options.all_mni) and options.warp:
        fname = options.warp_mni or '{dir}{sep}{base}.warp.mni{ext}'
        fname = fname.format(**format_dict)
        report('warp.mni', fname)
        io.savef(iwarp,
                 fname,
                 like=ref_native,
                 affine=mni_affine,
                 dtype='float32')

    # --- Warped space -------------------------------------------------
    iwarp = spatial.add_identity_grid_(iwarp)
    iwarp = spatial.affine_matvec(dat_affine.inverse() @ mni_affine, iwarp)

    if options.prob_wrp or options.labels_wrp or options.all_wrp:
        z_mni = spatial.grid_pull(z, iwarp)
        if options.prob_wrp or options.all_wrp:
            fname = options.prob_wrp or '{dir}{sep}{base}.prob.wrp{ext}'
            fname = fname.format(**format_dict)
            report('prob.wrp', fname)
            io.savef(torch.movedim(z_mni, 0, -1),
                     fname,
                     like=ref_native,
                     affine=mni_affine,
                     dtype='float32')
        if options.labels_wrp or options.all_wrp:
            fname = options.labels_wrp or '{dir}{sep}{base}.labels.wrp{ext}'
            fname = fname.format(**format_dict)
            report('labels.wrp', fname)
            io.save(z_mni.argmax(0),
                    fname,
                    like=ref_native,
                    affine=mni_affine,
                    dtype='int16')
        del z_mni

    if options.bias and (options.bias_wrp or options.nobias_wrp
                         or options.all_wrp):
        bias = spatial.grid_pull(prm['bias'],
                                 iwarp,
                                 interpolation=3,
                                 prefilter=False,
                                 bound='dct2')
        if options.bias_wrp or options.all_wrp:
            template = options.bias_wrp or '{dir}{sep}{base}.bias.wrp{ext}'
            save_per_input(bias, template, 'bias.wrp',
                           affine=mni_affine, dtype='float32')

        if options.nobias_wrp or options.all_wrp:
            nobias = spatial.grid_pull(dat, iwarp)
            nobias *= bias
            template = (options.nobias_wrp
                        or '{dir}{sep}{base}.nobias.wrp{ext}')
            save_per_input(nobias, template, 'nobias.wrp',
                           affine=mni_affine)
            del nobias

        del bias