Example #1
0
def init_rodent_brain_extraction_wf(
    ants_affine_init=False,
    factor=20,
    arc=0.12,
    step=4,
    grid=(0, 4, 4),
    debug=False,
    interim_checkpoints=True,
    mem_gb=3.0,
    mri_scheme="T2w",
    name="rodent_brain_extraction_wf",
    omp_nthreads=None,
    output_dir=None,
    template_id="Fischer344",
    template_specs=None,
    use_float=True,
):
    """
    Build an atlas-based brain extraction pipeline for rodent T1w and T2w MRI data.

    The target image is denoised, intensity-clipped and INU-corrected (N4), and
    then registered, together with a normalized Laplacian channel, to a
    TemplateFlow template. The template's brain mask is projected back into the
    target space, binarized at 0.5, and used to drive a second N4 pass.

    Parameters
    ----------
    ants_affine_init : :obj:`bool`, optional
        Set-up a pre-initialization step with ``antsAI`` to account for mis-oriented images.
        When ``False``, the registration node is configured with
        ``initial_moving_transform_com = 1`` instead.
    factor : :obj:`float`, optional
        First element of the ``search_factor`` passed to ``antsAI``
        (only used when ``ants_affine_init`` is set).
    arc : :obj:`float`, optional
        Second element of the ``search_factor`` passed to ``antsAI``
        (only used when ``ants_affine_init`` is set).
    step : :obj:`float`, optional
        First element of the ``search_grid`` passed to ``antsAI``
        (only used when ``ants_affine_init`` is set).
    grid : :obj:`tuple`, optional
        Second element of the ``search_grid`` passed to ``antsAI``
        (only used when ``ants_affine_init`` is set).
    debug : :obj:`bool`, optional
        Use the ``testing`` registration settings file instead of ``precise``,
        and run one fewer N4 fitting level in the initial INU correction.
    interim_checkpoints : :obj:`bool`, optional
        Build before/after visual reports (and, with ``output_dir``, save them).
    mem_gb : :obj:`float`, optional
        Memory estimate (GB) assigned to the registration node.
    mri_scheme : :obj:`str`, optional
        TemplateFlow ``suffix`` of the target template image (e.g. ``"T2w"``);
        also selects the registration settings file.
    name : :obj:`str`, optional
        Name of the returned workflow.
    omp_nthreads : :obj:`int`, optional
        Number of processes assigned to the multithreaded nodes.
    output_dir : :obj:`str` or path-like, optional
        If set, the corrected image, the brain mask and (with
        ``interim_checkpoints``) the reports are written there.
    template_id : :obj:`str`, optional
        TemplateFlow identifier of the target template.
    template_specs : :obj:`dict`, optional
        Additional TemplateFlow query entities. The caller's dictionary is
        never modified.
    use_float : :obj:`bool`, optional
        Value assigned to the registration node's ``float`` input.

    Inputs
    ------
    in_files
        Target image(s); only the first element (via ``_pop``) is processed.
    in_mask
        Declared on the input node but not consumed by this workflow.

    Outputs
    -------
    out_corrected
        Target image after the final N4 INU correction.
    out_brain
        ``out_corrected`` masked by ``out_mask``.
    out_mask
        Template brain mask projected into target space and binarized.

    Returns
    -------
    wf
        The assembled workflow.

    Raises
    ------
    RuntimeError
        If the target template image or its brain mask cannot be found.

    """
    inputnode = pe.Node(niu.IdentityInterface(fields=["in_files", "in_mask"]),
                        name="inputnode")
    outputnode = pe.Node(
        niu.IdentityInterface(
            fields=["out_corrected", "out_brain", "out_mask"]),
        name="outputnode",
    )

    # Work on a copy so the WHS default below never leaks into the caller's dict.
    template_specs = dict(template_specs or {})
    if template_id == "WHS" and "resolution" not in template_specs:
        template_specs["resolution"] = 2

    # Find a suitable target template in TemplateFlow
    tpl_target_path = get_template(
        template_id,
        suffix=mri_scheme,
        **template_specs,
    )
    if not tpl_target_path:
        raise RuntimeError(
            f"An instance of template <tpl-{template_id}> with MR scheme '{mri_scheme}'"
            " could not be found.")

    # Prefer a probabilistic brain mask; fall back to a binary one.
    tpl_brainmask_path = get_template(
        template_id,
        atlas=None,
        hemi=None,
        desc="brain",
        suffix="probseg",
        **template_specs,
    ) or get_template(
        template_id,
        atlas=None,
        hemi=None,
        desc="brain",
        suffix="mask",
        **template_specs,
    )
    if not tpl_brainmask_path:
        # Fail early with a clear message instead of feeding a bogus path
        # (e.g. the string "[]") to ApplyTransforms.
        raise RuntimeError(
            f"A brain mask (probseg or mask) for template <tpl-{template_id}>"
            " could not be found.")
    # Reduce a possible multi-file match to a single path, exactly as done for
    # the target image and the registration mask.
    tpl_brainmask = str(_pop(tpl_brainmask_path))

    tpl_regmask_path = get_template(
        template_id,
        atlas=None,
        desc="BrainCerebellumExtraction",
        suffix="mask",
        **template_specs,
    )

    denoise = pe.Node(DenoiseImage(dimension=3, copy_header=True),
                      name="denoise",
                      n_procs=omp_nthreads)

    # Resample template to a controlled, isotropic resolution
    res_tmpl = pe.Node(RegridToZooms(zooms=HIRES_ZOOMS, smooth=True),
                       name="res_tmpl")

    # Create Laplacian images
    lap_tmpl = pe.Node(ImageMath(operation="Laplacian", copy_header=True),
                       name="lap_tmpl")
    tmpl_sigma = pe.Node(niu.Function(function=_lap_sigma),
                         name="tmpl_sigma",
                         run_without_submitting=True)
    norm_lap_tmpl = pe.Node(niu.Function(function=_norm_lap),
                            name="norm_lap_tmpl")

    lap_target = pe.Node(ImageMath(operation="Laplacian", copy_header=True),
                         name="lap_target")
    target_sigma = pe.Node(niu.Function(function=_lap_sigma),
                           name="target_sigma",
                           run_without_submitting=True)
    norm_lap_target = pe.Node(niu.Function(function=_norm_lap),
                              name="norm_lap_target")

    # Set up initial spatial normalization
    ants_params = "testing" if debug else "precise"
    norm = pe.Node(
        Registration(from_file=pkgr_fn(
            "nirodents",
            f"data/artsBrainExtraction_{ants_params}_{mri_scheme}.json")),
        name="norm",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )
    norm.inputs.float = use_float

    # main workflow
    wf = pe.Workflow(name)

    # truncate target intensity for N4 correction
    clip_target = pe.Node(IntensityClip(p_min=15, p_max=99.9),
                          name="clip_target")

    # truncate template intensity to match target
    clip_tmpl = pe.Node(IntensityClip(p_min=5, p_max=98), name="clip_tmpl")
    clip_tmpl.inputs.in_file = _pop(tpl_target_path)

    # set INU bspline grid based on voxel size
    bspline_grid = pe.Node(niu.Function(function=_bspline_grid),
                           name="bspline_grid")

    # INU correction of the target image
    init_n4 = pe.Node(
        N4BiasFieldCorrection(
            dimension=3,
            save_bias=False,
            copy_header=True,
            # bool arithmetic: debug=True drops one fitting level (4 -> 3)
            n_iterations=[50] * (4 - debug),
            convergence_threshold=1e-7,
            shrink_factor=4,
            rescale_intensities=True,
        ),
        n_procs=omp_nthreads,
        name="init_n4",
    )
    clip_inu = pe.Node(IntensityClip(p_min=1, p_max=99.8), name="clip_inu")

    # Create a buffer interface as a cache for the actual inputs to registration
    buffernode = pe.Node(niu.IdentityInterface(fields=["hires_target"]),
                         name="buffernode")

    # Merge image nodes (intensity + normalized Laplacian channels)
    mrg_target = pe.Node(niu.Merge(2), name="mrg_target")
    mrg_tmpl = pe.Node(niu.Merge(2), name="mrg_tmpl")

    # fmt: off
    wf.connect([
        # Target image massaging
        (inputnode, denoise, [(("in_files", _pop), "input_image")]),
        (inputnode, bspline_grid, [(("in_files", _pop), "in_file")]),
        (bspline_grid, init_n4, [("out", "args")]),
        (denoise, clip_target, [("output_image", "in_file")]),
        (clip_target, init_n4, [("out_file", "input_image")]),
        (init_n4, clip_inu, [("output_image", "in_file")]),
        (clip_inu, target_sigma, [("out_file", "in_file")]),
        (clip_inu, buffernode, [("out_file", "hires_target")]),
        (buffernode, lap_target, [("hires_target", "op1")]),
        (target_sigma, lap_target, [("out", "op2")]),
        (lap_target, norm_lap_target, [("output_image", "in_file")]),
        (buffernode, mrg_target, [("hires_target", "in1")]),
        (norm_lap_target, mrg_target, [("out", "in2")]),
        # Template massaging
        (clip_tmpl, res_tmpl, [("out_file", "in_file")]),
        (res_tmpl, tmpl_sigma, [("out_file", "in_file")]),
        (res_tmpl, lap_tmpl, [("out_file", "op1")]),
        (tmpl_sigma, lap_tmpl, [("out", "op2")]),
        (lap_tmpl, norm_lap_tmpl, [("output_image", "in_file")]),
        (res_tmpl, mrg_tmpl, [("out_file", "in1")]),
        (norm_lap_tmpl, mrg_tmpl, [("out", "in2")]),
        # Setup inputs to spatial normalization
        (mrg_target, norm, [("out", "moving_image")]),
        (mrg_tmpl, norm, [("out", "fixed_image")]),
    ])
    # fmt: on

    # Graft a template registration-mask if present
    if tpl_regmask_path:
        hires_mask = pe.Node(
            ApplyTransforms(
                input_image=_pop(tpl_regmask_path),
                transforms="identity",
                interpolation="Gaussian",
                float=True,
            ),
            name="hires_mask",
            mem_gb=1,
        )

        # fmt: off
        wf.connect([
            (res_tmpl, hires_mask, [("out_file", "reference_image")]),
            (hires_mask, norm, [("output_image", "fixed_image_masks")]),
        ])
        # fmt: on

    # Finally project brain mask and refine INU correction
    map_brainmask = pe.Node(
        ApplyTransforms(interpolation="Gaussian", float=True),
        name="map_brainmask",
        mem_gb=1,
    )
    map_brainmask.inputs.input_image = tpl_brainmask

    thr_brainmask = pe.Node(Binarize(thresh_low=0.50), name="thr_brainmask")

    final_n4 = pe.Node(
        N4BiasFieldCorrection(
            dimension=3,
            save_bias=True,
            copy_header=True,
            n_iterations=[50] * 4,
            convergence_threshold=1e-7,
            rescale_intensities=True,
            shrink_factor=4,
        ),
        n_procs=omp_nthreads,
        name="final_n4",
    )
    final_mask = pe.Node(ApplyMask(), name="final_mask")

    # fmt: off
    wf.connect([
        (inputnode, map_brainmask, [(("in_files", _pop), "reference_image")]),
        (bspline_grid, final_n4, [("out", "args")]),
        (denoise, final_n4, [("output_image", "input_image")]),
        # Project template's brainmask into subject space
        (norm, map_brainmask, [("reverse_transforms", "transforms"),
                               ("reverse_invert_flags",
                                "invert_transform_flags")]),
        (map_brainmask, thr_brainmask, [("output_image", "in_file")]),
        # take a second pass of N4
        (map_brainmask, final_n4, [("output_image", "mask_image")]),
        (final_n4, final_mask, [("output_image", "in_file")]),
        (thr_brainmask, final_mask, [("out_mask", "in_mask")]),
        (final_n4, outputnode, [("output_image", "out_corrected")]),
        (thr_brainmask, outputnode, [("out_mask", "out_mask")]),
        (final_mask, outputnode, [("out_file", "out_brain")]),
    ])
    # fmt: on

    if interim_checkpoints:
        # Template resampled into target space vs. the corrected target
        final_apply = pe.Node(
            ApplyTransforms(interpolation="BSpline", float=True),
            name="final_apply",
            mem_gb=1,
        )
        final_report = pe.Node(
            SimpleBeforeAfter(after_label="target",
                              before_label=f"tpl-{template_id}"),
            name="final_report",
        )
        # fmt: off
        wf.connect([
            (inputnode, final_apply, [(("in_files", _pop), "reference_image")]),
            (res_tmpl, final_apply, [("out_file", "input_image")]),
            (norm, final_apply, [("reverse_transforms", "transforms"),
                                 ("reverse_invert_flags",
                                  "invert_transform_flags")]),
            (final_apply, final_report, [("output_image", "before")]),
            (outputnode, final_report, [("out_corrected", "after"),
                                        ("out_mask", "wm_seg")]),
        ])
        # fmt: on

    if ants_affine_init:
        # Initialize transforms with antsAI
        lowres_tmpl = pe.Node(RegridToZooms(zooms=LOWRES_ZOOMS, smooth=True),
                              name="lowres_tmpl")
        lowres_trgt = pe.Node(RegridToZooms(zooms=LOWRES_ZOOMS, smooth=True),
                              name="lowres_trgt")

        init_aff = pe.Node(
            AI(
                convergence=(100, 1e-6, 10),
                metric=("Mattes", 32, "Random", 0.25),
                principal_axes=False,
                search_factor=(factor, arc),
                search_grid=(step, grid),
                transform=("Affine", 0.1),
                verbose=True,
            ),
            name="init_aff",
            n_procs=omp_nthreads,
        )
        # fmt: off
        wf.connect([
            (clip_inu, lowres_trgt, [("out_file", "in_file")]),
            (lowres_trgt, init_aff, [("out_file", "moving_image")]),
            (clip_tmpl, lowres_tmpl, [("out_file", "in_file")]),
            (lowres_tmpl, init_aff, [("out_file", "fixed_image")]),
            (init_aff, norm, [("output_transform", "initial_moving_transform")]),
        ])
        # fmt: on

        if tpl_regmask_path:
            lowres_mask = pe.Node(
                ApplyTransforms(
                    input_image=_pop(tpl_regmask_path),
                    transforms="identity",
                    interpolation="MultiLabel",
                ),
                name="lowres_mask",
                mem_gb=1,
            )
            # fmt: off
            wf.connect([
                (lowres_tmpl, lowres_mask, [("out_file", "reference_image")]),
                (lowres_mask, init_aff, [("output_image", "fixed_image_mask")]),
            ])
            # fmt: on

        if interim_checkpoints:
            # The antsAI transform is applied inverted to bring the template
            # (and its brain mask) onto the low-resolution target grid.
            init_apply = pe.Node(
                ApplyTransforms(interpolation="BSpline",
                                invert_transform_flags=[True]),
                name="init_apply",
                mem_gb=1,
            )
            init_mask = pe.Node(
                ApplyTransforms(interpolation="Gaussian",
                                invert_transform_flags=[True]),
                name="init_mask",
                mem_gb=1,
            )
            init_mask.inputs.input_image = tpl_brainmask
            init_report = pe.Node(
                SimpleBeforeAfter(
                    out_report="init_report.svg",
                    before_label="target",
                    after_label="template",
                ),
                name="init_report",
            )
            # fmt: off
            wf.connect([
                (lowres_trgt, init_apply, [("out_file", "reference_image")]),
                (lowres_tmpl, init_apply, [("out_file", "input_image")]),
                (init_aff, init_apply, [("output_transform", "transforms")]),
                (lowres_trgt, init_report, [("out_file", "before")]),
                (init_apply, init_report, [("output_image", "after")]),
                (lowres_trgt, init_mask, [("out_file", "reference_image")]),
                (init_aff, init_mask, [("output_transform", "transforms")]),
                (init_mask, init_report, [("output_image", "wm_seg")]),
            ])
            # fmt: on
    else:
        norm.inputs.initial_moving_transform_com = 1

    if output_dir:
        ds_final_inu = pe.Node(DerivativesDataSink(
            base_directory=str(output_dir),
            desc="preproc",
            compress=True,
        ),
                               name="ds_final_inu",
                               run_without_submitting=True)
        ds_final_msk = pe.Node(DerivativesDataSink(
            base_directory=str(output_dir),
            desc="brain",
            suffix="mask",
            compress=True,
        ),
                               name="ds_final_msk",
                               run_without_submitting=True)

        # fmt: off
        wf.connect([
            (inputnode, ds_final_inu, [("in_files", "source_file")]),
            (inputnode, ds_final_msk, [("in_files", "source_file")]),
            (outputnode, ds_final_inu, [("out_corrected", "in_file")]),
            (outputnode, ds_final_msk, [("out_mask", "in_file")]),
        ])
        # fmt: on

        if interim_checkpoints:
            ds_report = pe.Node(DerivativesDataSink(
                base_directory=str(output_dir),
                desc="brain",
                suffix="mask",
                datatype="figures"),
                                name="ds_report",
                                run_without_submitting=True)
            # fmt: off
            wf.connect([
                (inputnode, ds_report, [("in_files", "source_file")]),
                (final_report, ds_report, [("out_report", "in_file")]),
            ])
            # fmt: on

        if ants_affine_init and interim_checkpoints:
            ds_report_init = pe.Node(DerivativesDataSink(
                base_directory=str(output_dir),
                desc="init",
                suffix="mask",
                datatype="figures"),
                                     name="ds_report_init",
                                     run_without_submitting=True)
            # fmt: off
            wf.connect([
                (inputnode, ds_report_init, [("in_files", "source_file")]),
                (init_report, ds_report_init, [("out_report", "in_file")]),
            ])
            # fmt: on

    return wf
Example #2
0
def init_syn_sdc_wf(
    *,
    atlas_threshold=3,
    sloppy=False,
    debug=False,
    name="syn_sdc_wf",
    omp_nthreads=1,
):
    """
    Build the *fieldmap-less* susceptibility-distortion estimation workflow.

    SyN deformation is restricted to the phase-encoding (PE) direction.
    If no PE direction is specified, anterior-posterior PE is assumed.

    SyN deformation is also restricted to regions that are expected to have a
    >3mm (approximately 1 voxel) warp, based on the fieldmap atlas.


    Workflow Graph
        .. workflow ::
            :graph2use: orig
            :simple_form: yes

            from sdcflows.workflows.fit.syn import init_syn_sdc_wf
            wf = init_syn_sdc_wf(omp_nthreads=8)

    Parameters
    ----------
    atlas_threshold : :obj:`float`
        Exclude from the registration metric computation areas with average distortions
        below this threshold (in mm).
    sloppy : :obj:`bool`
        Whether a fast (less accurate) configuration of the workflow should be applied.
    debug : :obj:`bool`
        Run in debug mode
    name : :obj:`str`
        Name for this workflow
    omp_nthreads : :obj:`int`
        Parallelize internal tasks across the number of CPUs given by this option.

    Inputs
    ------
    epi_ref : :obj:`tuple` (:obj:`str`, :obj:`dict`)
        A tuple, where the first element is the path of the distorted EPI
        reference map (e.g., an average of *b=0* volumes), and the second
        element is a dictionary of associated metadata.
    epi_mask : :obj:`str`
        A path to a brain mask corresponding to ``epi_ref``.
    anat_ref : :obj:`str`
        A preprocessed, skull-stripped anatomical (T1w or T2w) image resampled in EPI space.
    anat_mask : :obj:`str`
        Path to the brain mask corresponding to ``anat_ref`` in EPI space.
    sd_prior : :obj:`str`
        A template map of areas with strong susceptibility distortions (SD) to regularize
        the cost function of SyN

    Outputs
    -------
    fmap : :obj:`str`
        The path of the estimated fieldmap.
    fmap_ref : :obj:`str`
        The path of an unwarped conversion of files in ``epi_ref``.
    fmap_coeff : :obj:`str` or :obj:`list` of :obj:`str`
        The path(s) of the B-Spline coefficients supporting the fieldmap.
    fmap_mask : :obj:`str`
        The path of ``anat_mask`` resampled (nearest-label) onto the grid of
        the EPI reference image.
    out_warp : :obj:`str`
        The path of the corresponding displacements field transform to unwarp
        susceptibility distortions.
    method: :obj:`str`
        Short description of the estimation method that was run.

    Raises
    ------
    RuntimeError
        If the detected ANTs version is older than 2.2.0.

    """
    from pkg_resources import resource_filename as pkgrf
    from packaging.version import parse as parseversion, Version
    from nipype.interfaces.ants import ImageMath
    from niworkflows.interfaces.fixes import (
        FixHeaderApplyTransforms as ApplyTransforms,
        FixHeaderRegistration as Registration,
    )
    from niworkflows.interfaces.nibabel import (
        Binarize,
        IntensityClip,
        RegridToZooms,
    )
    from ...utils.misc import front as _pop, last as _pull
    from ...interfaces.epi import GetReadoutTime
    from ...interfaces.fmap import DisplacementsField2Fieldmap
    from ...interfaces.bspline import (
        ApplyCoeffsField,
        BSplineApprox,
        DEFAULT_LF_ZOOMS_MM,
        DEFAULT_HF_ZOOMS_MM,
        DEFAULT_ZOOMS_MM,
    )
    from ...interfaces.brainmask import BinaryDilation, Union

    # Versions *below* 2.2.0 are rejected, so the user must upgrade to 2.2 or newer.
    ants_version = Registration().version
    if ants_version and parseversion(ants_version) < Version("2.2.0"):
        raise RuntimeError(
            f"Please upgrade ANTs to 2.2 or newer ({ants_version} found)."
        )

    workflow = Workflow(name=name)
    workflow.__desc__ = f"""\
A deformation field to correct for susceptibility distortions was estimated
based on *fMRIPrep*'s *fieldmap-less* approach.
The deformation field is that resulting from co-registering the EPI reference
to the same-subject T1w-reference with its intensity inverted [@fieldmapless1;
@fieldmapless2].
Registration is performed with `antsRegistration`
(ANTs {ants_version or "-- version unknown"}), and
the process regularized by constraining deformation to be nonzero only
along the phase-encoding direction, and modulated with an average fieldmap
template [@fieldmapless3].
"""
    inputnode = pe.Node(niu.IdentityInterface(INPUT_FIELDS), name="inputnode")
    outputnode = pe.Node(
        niu.IdentityInterface(
            ["fmap", "fmap_ref", "fmap_coeff", "fmap_mask", "out_warp", "method"]
        ),
        name="outputnode",
    )
    outputnode.inputs.method = 'FLB ("fieldmap-less", SyN-based)'

    readout_time = pe.Node(
        GetReadoutTime(),
        name="readout_time",
        run_without_submitting=True,
    )

    warp_dir = pe.Node(
        niu.Function(function=_warp_dir),
        run_without_submitting=True,
        name="warp_dir",
    )
    warp_dir.inputs.nlevels = 2
    atlas_msk = pe.Node(Binarize(thresh_low=atlas_threshold), name="atlas_msk")
    anat_dilmsk = pe.Node(BinaryDilation(), name="anat_dilmsk")
    amask2epi = pe.Node(
        ApplyTransforms(interpolation="MultiLabel", transforms="identity"),
        name="amask2epi",
    )

    # Calculate laplacian maps
    lap_anat = pe.Node(
        ImageMath(operation="Laplacian", op2="1.5 1", copy_header=True), name="lap_anat"
    )
    lap_anat_norm = pe.Node(niu.Function(function=_norm_lap), name="lap_anat_norm")
    anat_merge = pe.Node(
        niu.Merge(2),
        name="anat_merge",
        run_without_submitting=True,
    )

    clip_epi = pe.Node(IntensityClip(p_min=35.0, p_max=99.9), name="clip_epi")
    lap_epi = pe.Node(
        ImageMath(operation="Laplacian", op2="1.5 1", copy_header=True), name="lap_epi"
    )
    lap_epi_norm = pe.Node(niu.Function(function=_norm_lap), name="lap_epi_norm")
    epi_merge = pe.Node(
        niu.Merge(2),
        name="epi_merge",
        run_without_submitting=True,
    )

    epi_umask = pe.Node(Union(), name="epi_umask")
    moving_masks = pe.Node(
        niu.Merge(3),
        name="moving_masks",
        run_without_submitting=True,
    )

    fixed_masks = pe.Node(
        niu.Merge(3),
        name="fixed_masks",
        mem_gb=DEFAULT_MEMORY_MIN_GB,
        run_without_submitting=True,
    )

    # Set a manageable size for the epi reference
    # NOTE(review): zooms_epi receives inputs below, but its output is not
    # connected to any downstream node — confirm whether SyN should consume it.
    find_zooms = pe.Node(niu.Function(function=_adjust_zooms), name="find_zooms")
    zooms_epi = pe.Node(RegridToZooms(), name="zooms_epi")

    # SyN Registration Core
    syn = pe.Node(
        Registration(
            from_file=pkgrf("sdcflows", f"data/sd_syn{'_sloppy' * sloppy}.json")
        ),
        name="syn",
        n_procs=omp_nthreads,
    )
    syn.inputs.output_warped_image = debug
    syn.inputs.output_inverse_warped_image = debug

    if debug:
        syn.inputs.args = "--write-interval-volumes 2"

    # Extract the corresponding fieldmap in Hz
    extract_field = pe.Node(
        DisplacementsField2Fieldmap(demean=True), name="extract_field"
    )

    unwarp = pe.Node(ApplyCoeffsField(), name="unwarp")

    # Check zooms (avoid very expensive B-Splines fitting)
    zooms_field = pe.Node(
        ApplyTransforms(
            interpolation="BSpline", transforms="identity", args="-u float"
        ),
        name="zooms_field",
    )
    zooms_bmask = pe.Node(
        ApplyTransforms(
            interpolation="MultiLabel", transforms="identity", args="-u uchar"
        ),
        name="zooms_bmask",
    )

    # Regularize with B-Splines
    bs_filter = pe.Node(BSplineApprox(), n_procs=omp_nthreads, name="bs_filter")
    bs_filter.interface._always_run = debug
    bs_filter.inputs.bs_spacing = (
        [DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM] if not sloppy else [DEFAULT_ZOOMS_MM]
    )
    bs_filter.inputs.extrapolate = not debug

    # fmt: off
    workflow.connect([
        (inputnode, readout_time, [(("epi_ref", _pop), "in_file"),
                                   (("epi_ref", _pull), "metadata")]),
        (inputnode, atlas_msk, [("sd_prior", "in_file")]),
        (inputnode, clip_epi, [(("epi_ref", _pop), "in_file")]),
        (inputnode, unwarp, [(("epi_ref", _pop), "in_data")]),
        (inputnode, amask2epi, [("epi_mask", "reference_image")]),
        (inputnode, zooms_bmask, [("anat_mask", "input_image")]),
        (inputnode, fixed_masks, [("anat_mask", "in1"),
                                  ("anat_mask", "in2")]),
        (inputnode, anat_dilmsk, [("anat_mask", "in_file")]),
        (inputnode, warp_dir, [("anat_ref", "fixed_image")]),
        (inputnode, anat_merge, [("anat_ref", "in1")]),
        (inputnode, lap_anat, [("anat_ref", "op1")]),
        (inputnode, find_zooms, [("anat_ref", "in_anat"),
                                 (("epi_ref", _pop), "in_epi")]),
        (inputnode, zooms_field, [(("epi_ref", _pop), "reference_image")]),
        (inputnode, epi_umask, [("epi_mask", "in1")]),
        (lap_anat, lap_anat_norm, [("output_image", "in_file")]),
        (lap_anat_norm, anat_merge, [("out", "in2")]),
        (epi_umask, moving_masks, [("out_file", "in1"),
                                   ("out_file", "in2"),
                                   ("out_file", "in3")]),
        (clip_epi, epi_merge, [("out_file", "in1")]),
        (clip_epi, lap_epi, [("out_file", "op1")]),
        (clip_epi, zooms_epi, [("out_file", "in_file")]),
        (lap_epi, lap_epi_norm, [("output_image", "in_file")]),
        (lap_epi_norm, epi_merge, [("out", "in2")]),
        (find_zooms, zooms_epi, [("out", "zooms")]),
        (atlas_msk, fixed_masks, [("out_mask", "in3")]),
        (anat_dilmsk, amask2epi, [("out_file", "input_image")]),
        (amask2epi, epi_umask, [("output_image", "in2")]),
        (readout_time, warp_dir, [("pe_direction", "pe_dir")]),
        (warp_dir, syn, [("out", "restrict_deformation")]),
        (anat_merge, syn, [("out", "fixed_image")]),
        (fixed_masks, syn, [("out", "fixed_image_masks")]),
        (epi_merge, syn, [("out", "moving_image")]),
        (moving_masks, syn, [("out", "moving_image_masks")]),
        (syn, extract_field, [(("forward_transforms", _pop), "transform")]),
        (readout_time, extract_field, [("readout_time", "ro_time"),
                                       ("pe_direction", "pe_dir")]),
        (extract_field, zooms_field, [("out_file", "input_image")]),
        (zooms_field, zooms_bmask, [("output_image", "reference_image")]),
        (zooms_field, bs_filter, [("output_image", "in_data")]),
        # Setting a mask ends up over-fitting the field
        # - it's better to have all those ~zero around.
        # (zooms_bmask, bs_filter, [("output_image", "in_mask")]),
        (bs_filter, unwarp, [("out_coeff", "in_coeff")]),
        (readout_time, unwarp, [("readout_time", "ro_time"),
                                ("pe_direction", "pe_dir")]),
        (zooms_bmask, outputnode, [("output_image", "fmap_mask")]),
        (bs_filter, outputnode, [("out_coeff", "fmap_coeff")]),
        (unwarp, outputnode, [("out_corrected", "fmap_ref"),
                              ("out_field", "fmap"),
                              ("out_warp", "out_warp")]),
    ])
    # fmt: on

    return workflow
Example #3
0
def init_bold_confs_wf(
    mem_gb,
    metadata,
    regressors_all_comps,
    regressors_dvars_th,
    regressors_fd_th,
    freesurfer=False,
    name="bold_confs_wf",
):
    """
    Build a workflow to generate and write out confounding signals.

    This workflow calculates confounds for a BOLD series, and aggregates them
    into a :abbr:`TSV (tab-separated value)` file, for use as nuisance
    regressors in a :abbr:`GLM (general linear model)`.
    The following confounds are calculated, with column headings in parentheses:

    #. Region-wise average signal (``csf``, ``white_matter``, ``global_signal``)
    #. DVARS - original and standardized variants (``dvars``, ``std_dvars``)
    #. Framewise displacement, based on head-motion parameters
       (``framewise_displacement``)
    #. Temporal CompCor (``t_comp_cor_XX``)
    #. Anatomical CompCor (``a_comp_cor_XX``)
    #. Cosine basis set for high-pass filtering w/ 0.008 Hz cut-off
       (``cosine_XX``)
    #. Non-steady-state volumes (``non_steady_state_XX``)
    #. Estimated head-motion parameters, in mm and rad
       (``trans_x``, ``trans_y``, ``trans_z``, ``rot_x``, ``rot_y``, ``rot_z``)


    Prior to estimating aCompCor and tCompCor, non-steady-state volumes are
    censored and high-pass filtered using a :abbr:`DCT (discrete cosine
    transform)` basis.
    The cosine basis, as well as one regressor per censored volume, are included
    for convenience.

    Workflow Graph
        .. workflow::
            :graph2use: orig
            :simple_form: yes

            from fmriprep.workflows.bold.confounds import init_bold_confs_wf
            wf = init_bold_confs_wf(
                mem_gb=1,
                metadata={},
                regressors_all_comps=False,
                regressors_dvars_th=1.5,
                regressors_fd_th=0.5,
            )

    Parameters
    ----------
    mem_gb : :obj:`float`
        Size of BOLD file in GB - please note that this size
        should be calculated after resamplings that may extend
        the FoV
    metadata : :obj:`dict`
        BIDS metadata for BOLD file
    name : :obj:`str`
        Name of workflow (default: ``bold_confs_wf``)
    regressors_all_comps : :obj:`bool`
        Indicates whether CompCor decompositions should return all
        components instead of the minimal number of components necessary
        to explain 50 percent of the variance in the decomposition mask.
    regressors_dvars_th : :obj:`float`
        Criterion for flagging DVARS outliers
    regressors_fd_th : :obj:`float`
        Criterion for flagging framewise displacement outliers

    Inputs
    ------
    bold
        BOLD image, after the prescribed corrections (STC, HMC and SDC)
        when available.
    bold_mask
        BOLD series mask
    movpar_file
        SPM-formatted motion parameters file
    rmsd_file
        Framewise displacement as measured by ``fsl_motion_outliers``.
    skip_vols
        number of non steady state volumes
    t1w_mask
        Mask of the skull-stripped template image
    t1w_tpms
        List of tissue probability maps in T1w space
    t1_bold_xform
        Affine matrix that maps the T1w space into alignment with
        the native BOLD space

    Outputs
    -------
    confounds_file
        TSV of all aggregated confounds
    rois_report
        Reportlet visualizing white-matter/CSF mask used for aCompCor,
        the ROI for tCompCor and the BOLD brain mask.
    confounds_metadata
        Confounds metadata dictionary.

    """
    from niworkflows.engine.workflows import LiterateWorkflow as Workflow
    from niworkflows.interfaces.confounds import ExpandModel, SpikeRegressors
    from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms
    from niworkflows.interfaces.images import SignalExtraction
    from niworkflows.interfaces.masks import ROIsPlot
    from niworkflows.interfaces.nibabel import ApplyMask, Binarize
    from niworkflows.interfaces.patches import (
        RobustACompCor as ACompCor,
        RobustTCompCor as TCompCor,
    )
    from niworkflows.interfaces.plotting import (CompCorVariancePlot,
                                                 ConfoundsCorrelationPlot)
    from niworkflows.interfaces.utils import (AddTSVHeader, TSV2JSON,
                                              DictMerge)
    from ...interfaces.confounds import aCompCorMasks

    gm_desc = (
        "dilating a GM mask extracted from the FreeSurfer's *aseg* segmentation"
        if freesurfer else
        "thresholding the corresponding partial volume map at 0.05")

    workflow = Workflow(name=name)
    workflow.__desc__ = f"""\
Several confounding time-series were calculated based on the
*preprocessed BOLD*: framewise displacement (FD), DVARS and
three region-wise global signals.
FD was computed using two formulations following Power (absolute sum of
relative motions, @power_fd_dvars) and Jenkinson (relative root mean square
displacement between affines, @mcflirt).
FD and DVARS are calculated for each functional run, both using their
implementations in *Nipype* [following the definitions by @power_fd_dvars].
The three global signals are extracted within the CSF, the WM, and
the whole-brain masks.
Additionally, a set of physiological regressors were extracted to
allow for component-based noise correction [*CompCor*, @compcor].
Principal components are estimated after high-pass filtering the
*preprocessed BOLD* time-series (using a discrete cosine filter with
128s cut-off) for the two *CompCor* variants: temporal (tCompCor)
and anatomical (aCompCor).
tCompCor components are then calculated from the top 2% variable
voxels within the brain mask.
For aCompCor, three probabilistic masks (CSF, WM and combined CSF+WM)
are generated in anatomical space.
The implementation differs from that of Behzadi et al. in that instead
of eroding the masks by 2 pixels on BOLD space, the aCompCor masks are
subtracted a mask of pixels that likely contain a volume fraction of GM.
This mask is obtained by {gm_desc}, and it ensures components are not extracted
from voxels containing a minimal fraction of GM.
Finally, these masks are resampled into BOLD space and binarized by
thresholding at 0.99 (as in the original implementation).
Components are also calculated separately within the WM and CSF masks.
For each CompCor decomposition, the *k* components with the largest singular
values are retained, such that the retained components' time series are
sufficient to explain 50 percent of variance across the nuisance mask (CSF,
WM, combined, or temporal). The remaining components are dropped from
consideration.
The head-motion estimates calculated in the correction step were also
placed within the corresponding confounds file.
The confound time series derived from head motion estimates and global
signals were expanded with the inclusion of temporal derivatives and
quadratic terms for each [@confounds_satterthwaite_2013].
Frames that exceeded a threshold of {regressors_fd_th} mm FD or
{regressors_dvars_th} standardised DVARS were annotated as motion outliers.
"""
    inputnode = pe.Node(niu.IdentityInterface(fields=[
        'bold', 'bold_mask', 'movpar_file', 'rmsd_file', 'skip_vols',
        't1w_mask', 't1w_tpms', 't1_bold_xform'
    ]),
                        name='inputnode')
    outputnode = pe.Node(niu.IdentityInterface(fields=[
        'confounds_file', 'confounds_metadata', 'acompcor_masks',
        'tcompcor_mask'
    ]),
                         name='outputnode')

    # DVARS
    dvars = pe.Node(nac.ComputeDVARS(save_nstd=True,
                                     save_std=True,
                                     remove_zerovariance=True),
                    name="dvars",
                    mem_gb=mem_gb)

    # Frame displacement
    fdisp = pe.Node(nac.FramewiseDisplacement(parameter_source="SPM"),
                    name="fdisp",
                    mem_gb=mem_gb)

    # Generate aCompCor probseg maps
    acc_masks = pe.Node(aCompCorMasks(is_aseg=freesurfer), name="acc_masks")

    # Resample probseg maps in BOLD space via T1w-to-BOLD transform
    acc_msk_tfm = pe.MapNode(ApplyTransforms(interpolation='Gaussian',
                                             float=False),
                             iterfield=["input_image"],
                             name='acc_msk_tfm',
                             mem_gb=0.1)
    acc_msk_brain = pe.MapNode(ApplyMask(),
                               name="acc_msk_brain",
                               iterfield=["in_file"])
    acc_msk_bin = pe.MapNode(Binarize(thresh_low=0.99),
                             name='acc_msk_bin',
                             iterfield=["in_file"])
    acompcor = pe.Node(ACompCor(components_file='acompcor.tsv',
                                header_prefix='a_comp_cor_',
                                pre_filter='cosine',
                                save_pre_filter=True,
                                save_metadata=True,
                                mask_names=['CSF', 'WM', 'combined'],
                                merge_method='none',
                                failure_mode='NaN'),
                       name="acompcor",
                       mem_gb=mem_gb)

    tcompcor = pe.Node(TCompCor(components_file='tcompcor.tsv',
                                header_prefix='t_comp_cor_',
                                pre_filter='cosine',
                                save_pre_filter=True,
                                save_metadata=True,
                                percentile_threshold=.02,
                                failure_mode='NaN'),
                       name="tcompcor",
                       mem_gb=mem_gb)

    # Set number of components
    if regressors_all_comps:
        acompcor.inputs.num_components = 'all'
        tcompcor.inputs.num_components = 'all'
    else:
        acompcor.inputs.variance_threshold = 0.5
        tcompcor.inputs.variance_threshold = 0.5

    # Set TR if present
    if 'RepetitionTime' in metadata:
        tcompcor.inputs.repetition_time = metadata['RepetitionTime']
        acompcor.inputs.repetition_time = metadata['RepetitionTime']

    # Global and segment regressors
    signals_class_labels = [
        "global_signal",
        "csf",
        "white_matter",
        "csf_wm",
        "tcompcor",
    ]
    merge_rois = pe.Node(niu.Merge(3, ravel_inputs=True),
                         name='merge_rois',
                         run_without_submitting=True)
    signals = pe.Node(SignalExtraction(class_labels=signals_class_labels),
                      name="signals",
                      mem_gb=mem_gb)

    # Arrange confounds
    add_dvars_header = pe.Node(AddTSVHeader(columns=["dvars"]),
                               name="add_dvars_header",
                               mem_gb=0.01,
                               run_without_submitting=True)
    add_std_dvars_header = pe.Node(AddTSVHeader(columns=["std_dvars"]),
                                   name="add_std_dvars_header",
                                   mem_gb=0.01,
                                   run_without_submitting=True)
    add_motion_headers = pe.Node(AddTSVHeader(
        columns=["trans_x", "trans_y", "trans_z", "rot_x", "rot_y", "rot_z"]),
                                 name="add_motion_headers",
                                 mem_gb=0.01,
                                 run_without_submitting=True)
    add_rmsd_header = pe.Node(AddTSVHeader(columns=["rmsd"]),
                              name="add_rmsd_header",
                              mem_gb=0.01,
                              run_without_submitting=True)
    concat = pe.Node(GatherConfounds(),
                     name="concat",
                     mem_gb=0.01,
                     run_without_submitting=True)

    # CompCor metadata
    tcc_metadata_fmt = pe.Node(TSV2JSON(
        index_column='component',
        drop_columns=['mask'],
        output=None,
        additional_metadata={'Method': 'tCompCor'},
        enforce_case=True),
                               name='tcc_metadata_fmt')
    acc_metadata_fmt = pe.Node(TSV2JSON(
        index_column='component',
        output=None,
        additional_metadata={'Method': 'aCompCor'},
        enforce_case=True),
                               name='acc_metadata_fmt')
    mrg_conf_metadata = pe.Node(niu.Merge(3),
                                name='merge_confound_metadata',
                                run_without_submitting=True)
    mrg_conf_metadata.inputs.in3 = {
        label: {
            'Method': 'Mean'
        }
        for label in signals_class_labels
    }
    mrg_conf_metadata2 = pe.Node(DictMerge(),
                                 name='merge_confound_metadata2',
                                 run_without_submitting=True)

    # Expand model to include derivatives and quadratics
    model_expand = pe.Node(
        ExpandModel(model_formula='(dd1(rps + wm + csf + gsr))^^2 + others'),
        name='model_expansion')

    # Add spike regressors
    spike_regress = pe.Node(SpikeRegressors(fd_thresh=regressors_fd_th,
                                            dvars_thresh=regressors_dvars_th),
                            name='spike_regressors')

    # Generate reportlet (ROIs)
    mrg_compcor = pe.Node(niu.Merge(2, ravel_inputs=True),
                          name='mrg_compcor',
                          run_without_submitting=True)
    rois_plot = pe.Node(ROIsPlot(colors=['b', 'magenta'],
                                 generate_report=True),
                        name='rois_plot',
                        mem_gb=mem_gb)

    ds_report_bold_rois = pe.Node(DerivativesDataSink(
        desc='rois', datatype="figures", dismiss_entities=("echo", )),
                                  name='ds_report_bold_rois',
                                  run_without_submitting=True,
                                  mem_gb=DEFAULT_MEMORY_MIN_GB)

    # Generate reportlet (CompCor)
    mrg_cc_metadata = pe.Node(niu.Merge(2),
                              name='merge_compcor_metadata',
                              run_without_submitting=True)
    compcor_plot = pe.Node(CompCorVariancePlot(
        variance_thresholds=(0.5, 0.7, 0.9),
        metadata_sources=['tCompCor', 'aCompCor']),
                           name='compcor_plot')
    ds_report_compcor = pe.Node(DerivativesDataSink(
        desc='compcorvar', datatype="figures", dismiss_entities=("echo", )),
                                name='ds_report_compcor',
                                run_without_submitting=True,
                                mem_gb=DEFAULT_MEMORY_MIN_GB)

    # Generate reportlet (Confound correlation)
    conf_corr_plot = pe.Node(ConfoundsCorrelationPlot(
        reference_column='global_signal', max_dim=20),
                             name='conf_corr_plot')
    ds_report_conf_corr = pe.Node(DerivativesDataSink(
        desc='confoundcorr', datatype="figures", dismiss_entities=("echo", )),
                                  name='ds_report_conf_corr',
                                  run_without_submitting=True,
                                  mem_gb=DEFAULT_MEMORY_MIN_GB)

    def _last(inlist):
        return inlist[-1]

    def _select_cols(table):
        import pandas as pd
        return [
            col for col in pd.read_table(table, nrows=2).columns
            if not col.startswith(("a_comp_cor_", "t_comp_cor_", "std_dvars"))
        ]

    workflow.connect([
        # connect inputnode to each non-anatomical confound node
        (inputnode, dvars, [('bold', 'in_file'), ('bold_mask', 'in_mask')]),
        (inputnode, fdisp, [('movpar_file', 'in_file')]),

        # aCompCor
        (inputnode, acompcor, [("bold", "realigned_file"),
                               ("skip_vols", "ignore_initial_volumes")]),
        (inputnode, acc_masks, [("t1w_tpms", "in_vfs"),
                                (("bold", _get_zooms), "bold_zooms")]),
        (inputnode, acc_msk_tfm, [("t1_bold_xform", "transforms"),
                                  ("bold_mask", "reference_image")]),
        (inputnode, acc_msk_brain, [("bold_mask", "in_mask")]),
        (acc_masks, acc_msk_tfm, [("out_masks", "input_image")]),
        (acc_msk_tfm, acc_msk_brain, [("output_image", "in_file")]),
        (acc_msk_brain, acc_msk_bin, [("out_file", "in_file")]),
        (acc_msk_bin, acompcor, [("out_file", "mask_files")]),

        # tCompCor
        (inputnode, tcompcor, [("bold", "realigned_file"),
                               ("skip_vols", "ignore_initial_volumes"),
                               ("bold_mask", "mask_files")]),
        # Global signals extraction (constrained by anatomy)
        (inputnode, signals, [('bold', 'in_file')]),
        (inputnode, merge_rois, [('bold_mask', 'in1')]),
        (acc_msk_bin, merge_rois, [('out_file', 'in2')]),
        (tcompcor, merge_rois, [('high_variance_masks', 'in3')]),
        (merge_rois, signals, [('out', 'label_files')]),

        # Collate computed confounds together
        (inputnode, add_motion_headers, [('movpar_file', 'in_file')]),
        (inputnode, add_rmsd_header, [('rmsd_file', 'in_file')]),
        (dvars, add_dvars_header, [('out_nstd', 'in_file')]),
        (dvars, add_std_dvars_header, [('out_std', 'in_file')]),
        (signals, concat, [('out_file', 'signals')]),
        (fdisp, concat, [('out_file', 'fd')]),
        (tcompcor, concat, [('components_file', 'tcompcor'),
                            ('pre_filter_file', 'cos_basis')]),
        (acompcor, concat, [('components_file', 'acompcor')]),
        (add_motion_headers, concat, [('out_file', 'motion')]),
        (add_rmsd_header, concat, [('out_file', 'rmsd')]),
        (add_dvars_header, concat, [('out_file', 'dvars')]),
        (add_std_dvars_header, concat, [('out_file', 'std_dvars')]),

        # Confounds metadata
        (tcompcor, tcc_metadata_fmt, [('metadata_file', 'in_file')]),
        (acompcor, acc_metadata_fmt, [('metadata_file', 'in_file')]),
        (tcc_metadata_fmt, mrg_conf_metadata, [('output', 'in1')]),
        (acc_metadata_fmt, mrg_conf_metadata, [('output', 'in2')]),
        (mrg_conf_metadata, mrg_conf_metadata2, [('out', 'in_dicts')]),

        # Expand the model with derivatives, quadratics, and spikes
        (concat, model_expand, [('confounds_file', 'confounds_file')]),
        (model_expand, spike_regress, [('confounds_file', 'confounds_file')]),

        # Set outputs
        (spike_regress, outputnode, [('confounds_file', 'confounds_file')]),
        (mrg_conf_metadata2, outputnode, [('out_dict', 'confounds_metadata')]),
        (tcompcor, outputnode, [("high_variance_masks", "tcompcor_mask")]),
        (acc_msk_bin, outputnode, [("out_file", "acompcor_masks")]),
        (inputnode, rois_plot, [('bold', 'in_file'),
                                ('bold_mask', 'in_mask')]),
        (tcompcor, mrg_compcor, [('high_variance_masks', 'in1')]),
        (acc_msk_bin, mrg_compcor, [(('out_file', _last), 'in2')]),
        (mrg_compcor, rois_plot, [('out', 'in_rois')]),
        (rois_plot, ds_report_bold_rois, [('out_report', 'in_file')]),
        (tcompcor, mrg_cc_metadata, [('metadata_file', 'in1')]),
        (acompcor, mrg_cc_metadata, [('metadata_file', 'in2')]),
        (mrg_cc_metadata, compcor_plot, [('out', 'metadata_files')]),
        (compcor_plot, ds_report_compcor, [('out_file', 'in_file')]),
        (concat, conf_corr_plot, [('confounds_file', 'confounds_file'),
                                  (('confounds_file', _select_cols), 'columns')
                                  ]),
        (conf_corr_plot, ds_report_conf_corr, [('out_file', 'in_file')]),
    ])

    return workflow
Example #4
0
def init_infant_brain_extraction_wf(
    age_months=None,
    ants_affine_init=False,
    bspline_fitting_distance=200,
    sloppy=False,
    skull_strip_template="UNCInfant",
    template_specs=None,
    mem_gb=3.0,
    debug=False,
    name="infant_brain_extraction_wf",
    omp_nthreads=None,
):
    """
    Build an atlas-based brain extraction pipeline for infant T2w MRI data.

    Pros/Cons of available templates
    --------------------------------
    * MNIInfant
     + More cohorts available for finer-grain control
     + T1w/T2w images available
     - Template masks are poor

    * UNCInfant
     + Accurate masks
     - No T2w image available


    Parameters
    ----------
    age_months : :obj:`int`
        Age of this participant, in months.
    ants_affine_init : :obj:`bool` or :obj:`str`, optional
        Set-up a pre-initialization step with ``antsAI`` to account for mis-oriented images.
        ``"random"`` switches the ``antsAI`` metric sampling to random, and ``"search"``
        uses a wider rotation search grid; any other truthy value uses the defaults.
    bspline_fitting_distance : :obj:`float`
        Distance in mm between B-Spline control points for N4 INU estimation.
    sloppy : :obj:`bool`
        Run in *sloppy* mode.
    skull_strip_template : :obj:`str`
        A TemplateFlow ID indicating which template will be used as target for atlas-based
        segmentation.
    template_specs : :obj:`dict`
        Additional template specifications (e.g., resolution or cohort) to correctly select
        the adequate template instance. The dictionary is copied, not modified in place.
    mem_gb : :obj:`float`
        Base memory fingerprint unit.
    name : :obj:`str`
        This particular workflow's unique name (Nipype requirement).
    omp_nthreads : :obj:`int`
        The number of threads for individual processes in this workflow.
    debug : :obj:`bool`
        Produce intermediate registration files

    Inputs
    ------
    in_t2w : :obj:`str`
        The unprocessed input T2w image.

    Outputs
    -------
    t2w_preproc : :obj:`str`
        The preprocessed T2w image (INU and clipping).
    t2w_brain : :obj:`str`
        The preprocessed, brain-extracted T2w image.
    out_mask : :obj:`str`
        The brainmask projected from the template into the T2w, after
        binarization.
    out_probmap : :obj:`str`
        The same as above, before binarization.

    Raises
    ------
    KeyError
        If neither ``age_months`` nor a ``cohort`` entry in ``template_specs`` is given.
    RuntimeError
        If the T1w target image or the brain mask of the template cannot be found.

    """
    from nipype.interfaces.ants import N4BiasFieldCorrection, ImageMath

    # niworkflows
    from niworkflows.interfaces.nibabel import ApplyMask, Binarize, IntensityClip
    from niworkflows.interfaces.fixes import (
        FixHeaderRegistration as Registration,
        FixHeaderApplyTransforms as ApplyTransforms,
    )
    from templateflow.api import get as get_template

    from ...interfaces.nibabel import BinaryDilation
    from ...utils.misc import cohort_by_months

    # handle template specifics
    # Work on a copy: the defaults filled in below must not leak into the
    # caller's dictionary (which may be reused for other subjects).
    template_specs = dict(template_specs or {})
    if skull_strip_template == "MNIInfant":
        template_specs["resolution"] = 2 if sloppy else 1

    if not template_specs.get("cohort"):
        if age_months is None:
            raise KeyError(
                f"Age or cohort for {skull_strip_template} must be provided!")
        template_specs["cohort"] = cohort_by_months(skull_strip_template,
                                                    age_months)

    tpl_target_path = get_template(
        skull_strip_template,
        suffix="T1w",  # no T2w template
        desc=None,
        **template_specs,
    )
    if not tpl_target_path:
        raise RuntimeError(
            f"An instance of template <tpl-{skull_strip_template}> with T1w suffix "
            "could not be found.")

    # Prefer a probabilistic brain map, fall back to a binary brain mask.
    tpl_brainmask_path = get_template(skull_strip_template,
                                      label="brain",
                                      suffix="probseg",
                                      **template_specs) or get_template(
                                          skull_strip_template,
                                          desc="brain",
                                          suffix="mask",
                                          **template_specs)
    if not tpl_brainmask_path:
        raise RuntimeError(
            f"A brain mask for template <tpl-{skull_strip_template}> "
            "could not be found.")
    # The query may match several files; keep the first one, exactly as is
    # done for the target image and the registration mask.
    tpl_brainmask_path = str(_pop(tpl_brainmask_path))

    tpl_regmask_path = get_template(
        skull_strip_template,
        label="BrainCerebellumExtraction",
        suffix="mask",
        **template_specs,
    )

    # main workflow
    workflow = pe.Workflow(name)

    inputnode = pe.Node(niu.IdentityInterface(fields=["in_t2w"]),
                        name="inputnode")
    outputnode = pe.Node(
        niu.IdentityInterface(
            fields=["t2w_preproc", "t2w_brain", "out_mask", "out_probmap"]),
        name="outputnode",
    )

    # Ensure template comes with a range of intensities ANTs will like
    clip_tmpl = pe.Node(IntensityClip(p_max=99), name="clip_tmpl")
    clip_tmpl.inputs.in_file = _pop(tpl_target_path)

    # Generate laplacian registration targets
    lap_tmpl = pe.Node(ImageMath(operation="Laplacian", op2="0.4 1"),
                       name="lap_tmpl")
    lap_t2w = pe.Node(ImageMath(operation="Laplacian", op2="0.4 1"),
                      name="lap_t2w")
    norm_lap_tmpl = pe.Node(niu.Function(function=_norm_lap),
                            name="norm_lap_tmpl")
    norm_lap_t2w = pe.Node(niu.Function(function=_norm_lap),
                           name="norm_lap_t2w")

    # Merge image nodes
    mrg_tmpl = pe.Node(niu.Merge(2),
                       name="mrg_tmpl",
                       run_without_submitting=True)
    mrg_t2w = pe.Node(niu.Merge(2),
                      name="mrg_t2w",
                      run_without_submitting=True)
    bin_regmask = pe.Node(Binarize(thresh_low=0.20), name="bin_regmask")
    bin_regmask.inputs.in_file = tpl_brainmask_path
    refine_mask = pe.Node(BinaryDilation(radius=3, iterations=2),
                          name="refine_mask")

    # One fixed-image mask per registration stage; "NULL" disables the mask
    # for the stages where it is set.
    fixed_masks = pe.Node(niu.Merge(4),
                          name="fixed_masks",
                          run_without_submitting=True)
    fixed_masks.inputs.in1 = "NULL"
    fixed_masks.inputs.in2 = "NULL"
    fixed_masks.inputs.in3 = "NULL" if not tpl_regmask_path else _pop(
        tpl_regmask_path)

    # Set up initial spatial normalization
    ants_params = "testing" if sloppy else "precise"
    norm = pe.Node(
        Registration(from_file=pkgr_fn(
            "nibabies.data", f"antsBrainExtraction_{ants_params}.json")),
        name="norm",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )
    norm.inputs.float = sloppy
    if debug:
        norm.inputs.args = "--write-interval-volumes 5"

    map_mask_t2w = pe.Node(
        ApplyTransforms(interpolation="Gaussian", float=True),
        name="map_mask_t2w",
        mem_gb=1,
    )

    # map template brainmask to t2w space
    map_mask_t2w.inputs.input_image = tpl_brainmask_path

    thr_t2w_mask = pe.Node(Binarize(thresh_low=0.80), name="thr_t2w_mask")

    # Refine INU correction
    final_n4 = pe.Node(
        N4BiasFieldCorrection(
            dimension=3,
            bspline_fitting_distance=bspline_fitting_distance,
            save_bias=True,
            copy_header=True,
            n_iterations=[50] * 5,
            convergence_threshold=1e-7,
            rescale_intensities=True,
            shrink_factor=4,
        ),
        n_procs=omp_nthreads,
        name="final_n4",
    )
    final_clip = pe.Node(IntensityClip(p_min=5.0, p_max=99.5),
                         name="final_clip")
    apply_mask = pe.Node(ApplyMask(), name="apply_mask")

    # fmt:off
    workflow.connect([
        (inputnode, final_n4, [("in_t2w", "input_image")]),
        # 1. Massage T2w
        (inputnode, mrg_t2w, [("in_t2w", "in1")]),
        (inputnode, lap_t2w, [("in_t2w", "op1")]),
        (inputnode, map_mask_t2w, [("in_t2w", "reference_image")]),
        (bin_regmask, refine_mask, [("out_file", "in_file")]),
        (refine_mask, fixed_masks, [("out_file", "in4")]),
        (lap_t2w, norm_lap_t2w, [("output_image", "in_file")]),
        (norm_lap_t2w, mrg_t2w, [("out", "in2")]),
        # 2. Prepare template
        (clip_tmpl, lap_tmpl, [("out_file", "op1")]),
        (lap_tmpl, norm_lap_tmpl, [("output_image", "in_file")]),
        (clip_tmpl, mrg_tmpl, [("out_file", "in1")]),
        (norm_lap_tmpl, mrg_tmpl, [("out", "in2")]),
        # 3. Set normalization node inputs
        (mrg_tmpl, norm, [("out", "fixed_image")]),
        (mrg_t2w, norm, [("out", "moving_image")]),
        (fixed_masks, norm, [("out", "fixed_image_masks")]),
        # 4. Map template brainmask into T2w space
        (norm, map_mask_t2w, [("reverse_transforms", "transforms"),
                              ("reverse_invert_flags",
                               "invert_transform_flags")]),
        (map_mask_t2w, thr_t2w_mask, [("output_image", "in_file")]),
        (thr_t2w_mask, apply_mask, [("out_mask", "in_mask")]),
        (final_n4, apply_mask, [("output_image", "in_file")]),
        # 5. Refine T2w INU correction with brain mask
        (map_mask_t2w, final_n4, [("output_image", "weight_image")]),
        (final_n4, final_clip, [("output_image", "in_file")]),
        # 9. Outputs
        (final_clip, outputnode, [("out_file", "t2w_preproc")]),
        (map_mask_t2w, outputnode, [("output_image", "out_probmap")]),
        (thr_t2w_mask, outputnode, [("out_mask", "out_mask")]),
        (apply_mask, outputnode, [("out_file", "t2w_brain")]),
    ])
    # fmt:on

    if ants_affine_init:
        from nipype.interfaces.ants.utils import AI

        ants_kwargs = dict(
            metric=("Mattes", 32, "Regular", 0.2),
            transform=("Affine", 0.1),
            search_factor=(20, 0.12),
            principal_axes=False,
            convergence=(10, 1e-6, 10),
            search_grid=(40, (0, 40, 40)),
            verbose=True,
        )

        # String values of ``ants_affine_init`` tweak the default antsAI setup.
        if ants_affine_init == "random":
            ants_kwargs["metric"] = ("Mattes", 32, "Random", 0.2)
        if ants_affine_init == "search":
            ants_kwargs["search_grid"] = (20, (20, 40, 40))

        init_aff = pe.Node(
            AI(**ants_kwargs),
            name="init_aff",
            n_procs=omp_nthreads,
        )
        if tpl_regmask_path:
            init_aff.inputs.fixed_image_mask = _pop(tpl_regmask_path)

        # fmt:off
        workflow.connect([
            (clip_tmpl, init_aff, [("out_file", "fixed_image")]),
            (inputnode, init_aff, [("in_t2w", "moving_image")]),
            (init_aff, norm, [("output_transform", "initial_moving_transform")
                              ]),
        ])
        # fmt:on

    return workflow
Example #5
0
def init_syn_sdc_wf(
    *,
    atlas_threshold=3,
    debug=False,
    name="syn_sdc_wf",
    omp_nthreads=1,
):
    """
    Build the *fieldmap-less* susceptibility-distortion estimation workflow.

    SyN deformation is restricted to the phase-encoding (PE) direction.
    If no PE direction is specified, anterior-posterior PE is assumed.

    SyN deformation is also restricted to regions that are expected to have a
    >3mm (approximately 1 voxel) warp, based on the fieldmap atlas
    (the cut-off is configurable through ``atlas_threshold``).


    Workflow Graph
        .. workflow ::
            :graph2use: orig
            :simple_form: yes

            from sdcflows.workflows.fit.syn import init_syn_sdc_wf
            wf = init_syn_sdc_wf(omp_nthreads=8)

    Parameters
    ----------
    atlas_threshold : :obj:`float`
        Exclude from the registration metric computation areas with average distortions
        below this threshold (in mm).
    debug : :obj:`bool`
        Whether a fast (less accurate) configuration of the workflow should be applied.
    name : :obj:`str`
        Name for this workflow
    omp_nthreads : :obj:`int`
        Parallelize internal tasks across the number of CPUs given by this option.

    Inputs
    ------
    epi_ref : :obj:`tuple` (:obj:`str`, :obj:`dict`)
        A tuple, where the first element is the path of the distorted EPI
        reference map (e.g., an average of *b=0* volumes), and the second
        element is a dictionary of associated metadata.
    epi_mask : :obj:`str`
        A path to a brain mask corresponding to ``epi_ref``.
    anat_brain : :obj:`str`
        A preprocessed, skull-stripped anatomical (T1w or T2w) image.
    std2anat_xfm : :obj:`str`
        inverse registration transform of T1w image to MNI template
    anat2epi_xfm : :obj:`str`
        transform mapping coordinates from the EPI space to the anatomical
        space (i.e., the transform to resample anatomical info into EPI space.)

    Outputs
    -------
    fmap : :obj:`str`
        The path of the estimated fieldmap.
    fmap_ref : :obj:`str`
        The path of an unwarped conversion of files in ``epi_ref``.
    fmap_coeff : :obj:`str` or :obj:`list` of :obj:`str`
        The path(s) of the B-Spline coefficients supporting the fieldmap.
    fmap_mask : :obj:`str`
        The path of a brain mask computed on the unwarped EPI reference (``fmap_ref``).

    Raises
    ------
    RuntimeError
        If the detected ANTs version is older than 2.2.0.

    """
    from pkg_resources import resource_filename as pkgrf
    from packaging.version import parse as parseversion, Version
    from nipype.interfaces.image import Rescale
    from niworkflows.interfaces.fixes import (
        FixHeaderApplyTransforms as ApplyTransforms,
        FixHeaderRegistration as Registration,
    )
    from niworkflows.interfaces.nibabel import Binarize
    from ...utils.misc import front as _pop
    from ...interfaces.utils import Deoblique, Reoblique
    from ...interfaces.bspline import (
        BSplineApprox,
        DEFAULT_LF_ZOOMS_MM,
        DEFAULT_HF_ZOOMS_MM,
        DEFAULT_ZOOMS_MM,
    )
    from ..ancillary import init_brainextraction_wf

    # The version check is skipped when ANTs' version cannot be determined.
    ants_version = Registration().version
    if ants_version and parseversion(ants_version) < Version("2.2.0"):
        raise RuntimeError(
            f"Please upgrade ANTs to 2.2 or newer ({ants_version} found).")

    workflow = Workflow(name=name)
    workflow.__desc__ = f"""\
A deformation field to correct for susceptibility distortions was estimated
based on *fMRIPrep*'s *fieldmap-less* approach.
The deformation field is that resulting from co-registering the EPI reference
to the same-subject T1w-reference with its intensity inverted [@fieldmapless1;
@fieldmapless2].
Registration is performed with `antsRegistration`
(ANTs {ants_version or "-- version unknown"}), and
the process regularized by constraining deformation to be nonzero only
along the phase-encoding direction, and modulated with an average fieldmap
template [@fieldmapless3].
"""
    inputnode = pe.Node(
        niu.IdentityInterface([
            "epi_ref", "epi_mask", "anat_brain", "std2anat_xfm", "anat2epi_xfm"
        ]),
        name="inputnode",
    )
    outputnode = pe.Node(
        niu.IdentityInterface(["fmap", "fmap_ref", "fmap_coeff", "fmap_mask"]),
        name="outputnode",
    )

    # Intensity-inverted anatomical image, resampled onto the EPI grid
    invert_t1w = pe.Node(Rescale(invert=True), name="invert_t1w", mem_gb=0.3)
    anat2epi = pe.Node(ApplyTransforms(interpolation="BSpline"),
                       name="anat2epi",
                       n_procs=omp_nthreads)

    # Mapping & preparing prior knowledge
    # Concatenate transform files:
    # 1) anat -> EPI; 2) MNI -> anat; 3) ATLAS -> MNI
    transform_list = pe.Node(niu.Merge(3),
                             name="transform_list",
                             mem_gb=DEFAULT_MEMORY_MIN_GB)
    transform_list.inputs.in3 = pkgrf(
        "sdcflows", "data/fmap_atlas_2_MNI152NLin2009cAsym_affine.mat")
    prior2epi = pe.Node(
        ApplyTransforms(
            input_image=pkgrf("sdcflows", "data/fmap_atlas.nii.gz")),
        name="prior2epi",
        n_procs=omp_nthreads,
        mem_gb=0.3,
    )
    atlas_msk = pe.Node(Binarize(thresh_low=atlas_threshold), name="atlas_msk")

    deoblique = pe.Node(Deoblique(), name="deoblique")
    reoblique = pe.Node(Reoblique(), name="reoblique")

    # SyN Registration Core
    syn = pe.Node(
        Registration(
            from_file=pkgrf("sdcflows", "data/susceptibility_syn.json")),
        name="syn",
        n_procs=omp_nthreads,
    )

    unwarp_ref = pe.Node(
        ApplyTransforms(interpolation="BSpline"),
        name="unwarp_ref",
    )

    brainextraction_wf = init_brainextraction_wf()

    # Extract nonzero component
    extract_field = pe.Node(niu.Function(function=_extract_field),
                            name="extract_field")

    # Regularize with B-Splines
    bs_filter = pe.Node(BSplineApprox(),
                        n_procs=omp_nthreads,
                        name="bs_filter")
    # In debug mode, force re-running the approximation instead of reusing
    # any cached result (relies on a private nipype attribute).
    bs_filter.interface._always_run = debug
    bs_filter.inputs.bs_spacing = ([DEFAULT_LF_ZOOMS_MM, DEFAULT_HF_ZOOMS_MM]
                                   if not debug else [DEFAULT_ZOOMS_MM])
    bs_filter.inputs.extrapolate = not debug

    # fmt: off
    workflow.connect([
        (inputnode, transform_list, [("anat2epi_xfm", "in1"),
                                     ("std2anat_xfm", "in2")]),
        (inputnode, invert_t1w, [("anat_brain", "in_file"),
                                 (("epi_ref", _pop), "ref_file")]),
        (inputnode, anat2epi, [(("epi_ref", _pop), "reference_image"),
                               ("anat2epi_xfm", "transforms")]),
        (inputnode, deoblique, [(("epi_ref", _pop), "in_epi"),
                                ("epi_mask", "mask_epi")]),
        (inputnode, reoblique, [(("epi_ref", _pop), "in_epi")]),
        (inputnode, syn, [(("epi_ref", _warp_dir), "restrict_deformation")]),
        (inputnode, unwarp_ref, [(("epi_ref", _pop), "reference_image"),
                                 (("epi_ref", _pop), "input_image")]),
        (inputnode, prior2epi, [(("epi_ref", _pop), "reference_image")]),
        (inputnode, extract_field, [("epi_ref", "epi_meta")]),
        (invert_t1w, anat2epi, [("out_file", "input_image")]),
        (transform_list, prior2epi, [("out", "transforms")]),
        (prior2epi, atlas_msk, [("output_image", "in_file")]),
        (anat2epi, deoblique, [("output_image", "in_anat")]),
        (atlas_msk, deoblique, [("out_mask", "mask_anat")]),
        (deoblique, syn, [("out_epi", "moving_image"),
                          ("out_anat", "fixed_image"),
                          ("mask_epi", "moving_image_masks"),
                          (("mask_anat", _fixed_masks_arg),
                           "fixed_image_masks")]),
        (syn, extract_field, [("forward_transforms", "in_file")]),
        (syn, unwarp_ref, [("forward_transforms", "transforms")]),
        (unwarp_ref, reoblique, [("output_image", "in_plumb")]),
        (reoblique, brainextraction_wf, [("out_epi", "inputnode.in_file")]),
        (extract_field, reoblique, [("out", "in_field")]),
        (reoblique, bs_filter, [("out_field", "in_data")]),
        (brainextraction_wf, bs_filter, [("outputnode.out_mask", "in_mask")]),
        (reoblique, outputnode, [("out_epi", "fmap_ref")]),
        (brainextraction_wf, outputnode, [("outputnode.out_mask", "fmap_mask")
                                          ]),
        (bs_filter, outputnode,
         [("out_extrapolated" if not debug else "out_field", "fmap"),
          ("out_coeff", "fmap_coeff")]),
    ])
    # fmt: on

    return workflow
Example #6
0
def init_coregistration_wf(
    *,
    bspline_fitting_distance=200,
    mem_gb=3.0,
    name="coregistration_wf",
    omp_nthreads=None,
    sloppy=False,
    debug=False,
):
    """
    Set-up a T2w-to-T1w within-baby co-registration framework.

    The registration is estimated with the T1w as the *moving* image and the
    preprocessed T2w as the *fixed* image; its reverse transforms are then used
    to bring the T2w image and the T2w-space brain probability map into the
    T1w's space.

    See the ANTs' registration config file (under ``nibabies/data``) for further
    details.
    The main surprise in it is that, for some participants, accurate registration
    requires extra degrees of freedom (one affine level and one SyN level) to ensure
    that the T1w and T2w images align well.
    I attribute this requirement to the following potential reasons:

      * The T1w image and the T2w image were acquired in different sessions, far
        enough apart in time for growth to happen.
        Although this is theoretically possible, the images we have tested on do
        not appear to have been acquired in different sessions.
      * The skull is still so malleable that a change of position of the baby inside the
        coil made an actual change on the overall shape of their head.
      * Nonlinear distortions of the T1w and T2w images are, for some reason, more noticeable
        for babies than they are for adults.
        We would need to look into each sequence's details to confirm this.

    Parameters
    ----------
    bspline_fitting_distance : :obj:`float`
        Distance in mm between B-Spline control points for N4 INU estimation.
    mem_gb : :obj:`float`
        Base memory fingerprint unit.
    name : :obj:`str`
        This particular workflow's unique name (Nipype requirement).
    omp_nthreads : :obj:`int`
        The number of threads for individual processes in this workflow.
    sloppy : :obj:`bool`
        Run in *sloppy* mode (the co-registration runs in single precision).
    debug : :obj:`bool`
        Produce intermediate registration files: interval volumes every 5
        iterations, plus the warped and inverse-warped images.

    Inputs
    ------
    in_t1w : :obj:`str`
        The unprocessed input T1w image.
    in_t2w_preproc : :obj:`str`
        The preprocessed input T2w image, from the brain extraction workflow.
    in_mask : :obj:`str`
        The brainmask, as obtained in T2w space.
    in_probmap : :obj:`str`
        The probabilistic brainmask, as obtained in T2w space.

    Outputs
    -------
    t1w_preproc : :obj:`str`
        The preprocessed T1w image (INU and clipping).
    t2w_preproc : :obj:`str`
        The preprocessed T2w image (INU and clipping), aligned into the T1w's space.
    t1w_brain : :obj:`str`
        The preprocessed, brain-extracted T1w image.
    t1w_mask : :obj:`str`
        The binary brainmask projected from the T2w (probability map mapped
        into T1w space and thresholded at 0.80).
    t1w2t2w_xfm : :obj:`str`
        The T1w-to-T2w mapping (forward transforms of the registration).

    """
    from nipype.interfaces.ants import N4BiasFieldCorrection
    from niworkflows.interfaces.fixes import (
        FixHeaderRegistration as Registration,
        FixHeaderApplyTransforms as ApplyTransforms,
    )
    from niworkflows.interfaces.nibabel import ApplyMask, Binarize
    from ...interfaces.nibabel import BinaryDilation

    workflow = pe.Workflow(name)

    inputnode = pe.Node(
        niu.IdentityInterface(
            fields=["in_t1w", "in_t2w_preproc", "in_mask", "in_probmap"]),
        name="inputnode",
    )
    outputnode = pe.Node(
        niu.IdentityInterface(fields=[
            "t1w_preproc",
            "t1w_brain",
            "t1w_mask",
            "t1w2t2w_xfm",
            "t2w_preproc",
        ]),
        name="outputnode",
    )

    # One fixed-image mask per registration stage (3 in total): two stages use
    # the generously dilated mask, the last one a tighter mask.
    fixed_masks_arg = pe.Node(niu.Merge(3),
                              name="fixed_masks_arg",
                              run_without_submitting=True)

    # Dilate t2w mask for easier t1->t2 registration
    reg_mask = pe.Node(BinaryDilation(radius=8, iterations=3), name="reg_mask")
    refine_mask = pe.Node(BinaryDilation(radius=8, iterations=1),
                          name="refine_mask")

    # Within-subject registration: T1w (moving) onto T2w (fixed)
    coreg = pe.Node(
        Registration(
            from_file=pkgr_fn("nibabies.data", "within_subject_t1t2.json")),
        name="coreg",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )
    coreg.inputs.float = sloppy
    if debug:
        # Keep intermediate products so the registration can be inspected.
        # These used to be tied to ``sloppy``, which silently disabled the
        # warped outputs for non-sloppy debug runs.
        coreg.inputs.args = "--write-interval-volumes 5"
        coreg.inputs.output_inverse_warped_image = True
        coreg.inputs.output_warped_image = True

    # Bring the T2w-space probability map and T2w image into T1w space
    map_mask = pe.Node(ApplyTransforms(interpolation="Gaussian"),
                       name="map_mask",
                       mem_gb=1)
    map_t2w = pe.Node(ApplyTransforms(interpolation="BSpline"),
                      name="map_t2w",
                      mem_gb=1)
    thr_mask = pe.Node(Binarize(thresh_low=0.80), name="thr_mask")

    # Final INU correction of the T1w, weighted by the mapped probability map
    final_n4 = pe.Node(
        N4BiasFieldCorrection(
            dimension=3,
            bspline_fitting_distance=bspline_fitting_distance,
            save_bias=True,
            copy_header=True,
            n_iterations=[50] * 5,
            convergence_threshold=1e-7,
            rescale_intensities=True,
            shrink_factor=4,
        ),
        n_procs=omp_nthreads,
        name="final_n4",
    )
    apply_mask = pe.Node(ApplyMask(), name="apply_mask")

    # fmt:off
    workflow.connect([
        (inputnode, map_mask, [("in_t1w", "reference_image")]),
        (inputnode, final_n4, [("in_t1w", "input_image")]),
        (inputnode, coreg, [("in_t1w", "moving_image"),
                            ("in_t2w_preproc", "fixed_image")]),
        (inputnode, map_mask, [("in_probmap", "input_image")]),
        (inputnode, reg_mask, [("in_mask", "in_file")]),
        (inputnode, refine_mask, [("in_mask", "in_file")]),
        (reg_mask, fixed_masks_arg, [("out_file", "in1")]),
        (reg_mask, fixed_masks_arg, [("out_file", "in2")]),
        (refine_mask, fixed_masks_arg, [("out_file", "in3")]),
        (inputnode, map_t2w, [("in_t1w", "reference_image")]),
        (inputnode, map_t2w, [("in_t2w_preproc", "input_image")]),
        (fixed_masks_arg, coreg, [("out", "fixed_image_masks")]),
        (coreg, map_mask, [
            ("reverse_transforms", "transforms"),
            ("reverse_invert_flags", "invert_transform_flags"),
        ]),
        (coreg, map_t2w, [
            ("reverse_transforms", "transforms"),
            ("reverse_invert_flags", "invert_transform_flags"),
        ]),
        (map_mask, thr_mask, [("output_image", "in_file")]),
        (map_mask, final_n4, [("output_image", "weight_image")]),
        (final_n4, apply_mask, [("output_image", "in_file")]),
        (thr_mask, apply_mask, [("out_mask", "in_mask")]),
        (final_n4, outputnode, [("output_image", "t1w_preproc")]),
        (map_t2w, outputnode, [("output_image", "t2w_preproc")]),
        (thr_mask, outputnode, [("out_mask", "t1w_mask")]),
        (apply_mask, outputnode, [("out_file", "t1w_brain")]),
        (coreg, outputnode, [("forward_transforms", "t1w2t2w_xfm")]),
    ])
    # fmt:on
    return workflow
Example #7
0
def init_infant_brain_extraction_wf(
    age_months=None,
    ants_affine_init=False,
    bspline_fitting_distance=200,
    sloppy=False,
    skull_strip_template="UNCInfant",
    template_specs=None,
    interim_checkpoints=True,
    mem_gb=3.0,
    mri_scheme="T1w",
    name="infant_brain_extraction_wf",
    atropos_model=None,
    omp_nthreads=None,
    output_dir=None,
    use_float=True,
    use_t2w=False,
):
    """
    Build an atlas-based brain extraction pipeline for infant T1w/T2w MRI data.

    The T2w image is registered to the (T1w) template to project the template
    brainmask into T2w space; the T1w image is then co-registered to the T2w,
    the mask is carried over into T1w space, and it weights a final N4 INU
    correction of the T1w.

    Pros/Cons of available templates
    --------------------------------
    * MNIInfant
     + More cohorts available for finer-grain control
     + T1w/T2w images available
     - Template masks are poor

    * UNCInfant
     + Accurate masks
     - No T2w image available


    Parameters
    ----------
    age_months : :obj:`int`, optional
        Participant age, used to select the template cohort when
        ``template_specs`` does not already provide one.
    ants_affine_init : :obj:`bool` or :obj:`str`, optional
        Set-up a pre-initialization step with ``antsAI`` to account for mis-oriented images.
        ``"random"`` switches the metric to random sampling; ``"search"`` uses a
        different search grid.
    bspline_fitting_distance : :obj:`float`
        Distance in mm between B-Spline control points for N4 INU estimation.
    sloppy : :obj:`bool`
        Use a coarser MNIInfant resolution, the "testing" registration settings
        and fewer initial N4 iterations.
    skull_strip_template : :obj:`str`
        TemplateFlow identifier of the target template.
    template_specs : :obj:`dict`, optional
        Additional TemplateFlow query entities. The caller's dictionary is not
        modified.
    interim_checkpoints : :obj:`bool`
        Currently unused by this workflow.
    mem_gb : :obj:`float`
        Base memory fingerprint unit.
    mri_scheme : :obj:`str`
        Key into ``ATROPOS_MODELS`` used when ``atropos_model`` is not given.
    name : :obj:`str`
        This particular workflow's unique name (Nipype requirement).
    atropos_model : :obj:`tuple`, optional
        Atropos segmentation model.
    omp_nthreads : :obj:`int`
        The number of threads for individual processes in this workflow.
    output_dir : optional
        Currently unused by this workflow.
    use_float : :obj:`bool`
        Run the registrations in single precision.
    use_t2w : :obj:`bool`
        Must be ``True``: a T2w image is currently required.

    Raises
    ------
    KeyError
        If neither a template cohort nor ``age_months`` is provided.
    RuntimeError
        If ``use_t2w`` is false or the target template cannot be found.

    """
    # handle template specifics -- work on a copy so that filling in the
    # resolution/cohort below never mutates the caller's dictionary
    template_specs = dict(template_specs or {})
    if skull_strip_template == 'MNIInfant':
        template_specs['resolution'] = 2 if sloppy else 1

    if not template_specs.get('cohort'):
        if age_months is None:
            raise KeyError(
                f"Age or cohort for {skull_strip_template} must be provided!")
        template_specs['cohort'] = cohort_by_months(skull_strip_template,
                                                    age_months)

    inputnode = pe.Node(
        niu.IdentityInterface(fields=["t1w", "t2w", "in_mask"]),
        name="inputnode")
    outputnode = pe.Node(niu.IdentityInterface(
        fields=["t1w_corrected", "t1w_corrected_brain", "t1w_mask"]),
                         name="outputnode")

    if not use_t2w:
        raise RuntimeError("A T2w image is currently required.")

    tpl_target_path = get_template(
        skull_strip_template,
        suffix='T1w',  # no T2w template
        desc=None,
        **template_specs,
    )
    if not tpl_target_path:
        raise RuntimeError(
            f"An instance of template <tpl-{skull_strip_template}> with MR scheme "
            "'T1w' could not be found.")

    # Prefer the probabilistic brainmask, fall back to the binary one
    tpl_brainmask_path = get_template(skull_strip_template,
                                      label="brain",
                                      suffix="probseg",
                                      **template_specs) or get_template(
                                          skull_strip_template,
                                          desc="brain",
                                          suffix="mask",
                                          **template_specs)

    tpl_regmask_path = get_template(skull_strip_template,
                                    label="BrainCerebellumExtraction",
                                    suffix="mask",
                                    **template_specs)

    # validate images -- clone before assigning in_file so the subject
    # validators are not given the template path
    val_tmpl = pe.Node(ValidateImage(), name='val_tmpl')
    val_t1w = val_tmpl.clone("val_t1w")
    val_t2w = val_tmpl.clone("val_t2w")
    val_tmpl.inputs.in_file = _pop(tpl_target_path)

    # Spatial normalization step
    lap_tmpl = pe.Node(ImageMath(operation="Laplacian", op2="0.4 1"),
                       name="lap_tmpl")
    lap_t1w = lap_tmpl.clone("lap_t1w")
    lap_t2w = lap_tmpl.clone("lap_t2w")

    # Merge image nodes (intensity image + normalized Laplacian)
    mrg_tmpl = pe.Node(niu.Merge(2), name="mrg_tmpl")
    mrg_t2w = mrg_tmpl.clone("mrg_t2w")
    mrg_t1w = mrg_tmpl.clone("mrg_t1w")

    norm_lap_tmpl = pe.Node(niu.Function(function=_trunc),
                            name="norm_lap_tmpl")
    norm_lap_tmpl.inputs.dtype = "float32"
    norm_lap_tmpl.inputs.out_max = 1.0
    norm_lap_tmpl.inputs.percentile = (0.01, 99.99)
    norm_lap_tmpl.inputs.clip_max = None

    # Clones are taken after the inputs above are set, so they share them
    norm_lap_t1w = norm_lap_tmpl.clone('norm_lap_t1w')
    norm_lap_t2w = norm_lap_t1w.clone('norm_lap_t2w')

    # Set up initial spatial normalization
    ants_params = "testing" if sloppy else "precise"
    norm = pe.Node(
        Registration(from_file=pkgr_fn(
            "niworkflows.data", f"antsBrainExtraction_{ants_params}.json")),
        name="norm",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )
    norm.inputs.float = use_float
    if tpl_regmask_path:
        norm.inputs.fixed_image_masks = tpl_regmask_path

    # Set up T1w -> T2w within-subject registration (T1w moving, T2w fixed)
    norm_subj = pe.Node(
        Registration(
            from_file=pkgr_fn("nibabies.data", "within_subject_t1t2.json")),
        name="norm_subj",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )
    norm_subj.inputs.float = use_float

    # main workflow
    wf = pe.Workflow(name)

    # truncate target intensity for N4 correction
    clip_tmpl = pe.Node(niu.Function(function=_trunc), name="clip_tmpl")
    clip_t2w = clip_tmpl.clone('clip_t2w')
    clip_t1w = clip_tmpl.clone('clip_t1w')

    # Initial INU correction (T2w node, cloned for the T1w)
    init_t2w_n4 = pe.Node(
        N4BiasFieldCorrection(
            dimension=3,
            save_bias=False,
            copy_header=True,
            n_iterations=[50] * (4 - sloppy),
            convergence_threshold=1e-7,
            shrink_factor=4,
            bspline_fitting_distance=bspline_fitting_distance,
        ),
        n_procs=omp_nthreads,
        name="init_t2w_n4",
    )
    init_t1w_n4 = init_t2w_n4.clone("init_t1w_n4")

    clip_t2w_inu = pe.Node(niu.Function(function=_trunc), name="clip_t2w_inu")
    clip_t1w_inu = clip_t2w_inu.clone("clip_t1w_inu")

    map_mask_t2w = pe.Node(ApplyTransforms(interpolation="Gaussian",
                                           float=True),
                           name="map_mask_t2w",
                           mem_gb=1)
    map_mask_t1w = map_mask_t2w.clone("map_mask_t1w")

    # map template brainmask to t2w space (set after cloning: the T1w mapper
    # receives the thresholded T2w mask through a connection instead)
    map_mask_t2w.inputs.input_image = str(tpl_brainmask_path)

    thr_t2w_mask = pe.Node(Binarize(thresh_low=0.80), name="thr_t2w_mask")
    thr_t1w_mask = thr_t2w_mask.clone('thr_t1w_mask')

    bspline_grid = pe.Node(niu.Function(function=_bspline_distance),
                           name="bspline_grid")

    # Refine INU correction
    final_n4 = pe.Node(
        N4BiasFieldCorrection(
            dimension=3,
            bspline_fitting_distance=bspline_fitting_distance,
            save_bias=True,
            copy_header=True,
            n_iterations=[50] * 5,
            convergence_threshold=1e-7,
            rescale_intensities=True,
            shrink_factor=4,
        ),
        n_procs=omp_nthreads,
        name="final_n4",
    )
    final_mask = pe.Node(ApplyMask(), name="final_mask")

    if atropos_model is None:
        atropos_model = tuple(ATROPOS_MODELS[mri_scheme].values())

    # NOTE: the Atropos refinement below is built but not yet connected.
    atropos_wf = init_atropos_wf(
        use_random_seed=False,
        omp_nthreads=omp_nthreads,
        mem_gb=mem_gb,
        in_segmentation_model=atropos_model,
    )

    sel_wm = pe.Node(niu.Select(index=atropos_model[-1] - 1),
                     name='sel_wm',
                     run_without_submitting=True)

    wf.connect([
        # 1. massage template
        (val_tmpl, clip_tmpl, [("out_file", "in_file")]),
        (clip_tmpl, lap_tmpl, [("out", "op1")]),
        (clip_tmpl, mrg_tmpl, [("out", "in1")]),
        (lap_tmpl, norm_lap_tmpl, [("output_image", "in_file")]),
        (norm_lap_tmpl, mrg_tmpl, [("out", "in2")]),
        # 2. massage T2w
        (inputnode, val_t2w, [('t2w', 'in_file')]),
        (val_t2w, clip_t2w, [('out_file', 'in_file')]),
        (clip_t2w, init_t2w_n4, [('out', 'input_image')]),
        (init_t2w_n4, clip_t2w_inu, [("output_image", "in_file")]),
        (clip_t2w_inu, lap_t2w, [('out', 'op1')]),
        (clip_t2w_inu, mrg_t2w, [('out', 'in1')]),
        (lap_t2w, norm_lap_t2w, [("output_image", "in_file")]),
        (norm_lap_t2w, mrg_t2w, [("out", "in2")]),
        # 3. normalize T2w to target template (UNC)
        (mrg_t2w, norm, [("out", "moving_image")]),
        (mrg_tmpl, norm, [("out", "fixed_image")]),
        # 4. map template brainmask to T2w space
        (inputnode, map_mask_t2w, [('t2w', 'reference_image')]),
        (norm, map_mask_t2w, [("reverse_transforms", "transforms"),
                              ("reverse_invert_flags",
                               "invert_transform_flags")]),
        (map_mask_t2w, thr_t2w_mask, [("output_image", "in_file")]),
        # 5. massage T1w
        (inputnode, val_t1w, [("t1w", "in_file")]),
        (val_t1w, clip_t1w, [("out_file", "in_file")]),
        (clip_t1w, init_t1w_n4, [("out", "input_image")]),
        (init_t1w_n4, clip_t1w_inu, [("output_image", "in_file")]),
        (clip_t1w_inu, lap_t1w, [('out', 'op1')]),
        (clip_t1w_inu, mrg_t1w, [('out', 'in1')]),
        (lap_t1w, norm_lap_t1w, [("output_image", "in_file")]),
        (norm_lap_t1w, mrg_t1w, [("out", "in2")]),
        # 6. normalize within subject T1w to T2w
        (mrg_t1w, norm_subj, [("out", "moving_image")]),
        (mrg_t2w, norm_subj, [("out", "fixed_image")]),
        (thr_t2w_mask, norm_subj, [("out_mask", "fixed_image_mask")]),
        # 7. map mask to T1w space
        (thr_t2w_mask, map_mask_t1w, [("out_mask", "input_image")]),
        (inputnode, map_mask_t1w, [("t1w", "reference_image")]),
        (norm_subj, map_mask_t1w, [
            ("reverse_transforms", "transforms"),
            ("reverse_invert_flags", "invert_transform_flags"),
        ]),
        (map_mask_t1w, thr_t1w_mask, [("output_image", "in_file")]),
        # 8. T1w INU
        (inputnode, final_n4, [("t1w", "input_image")]),
        (inputnode, bspline_grid, [("t1w", "in_file")]),
        (bspline_grid, final_n4, [("out", "args")]),
        (map_mask_t1w, final_n4, [("output_image", "weight_image")]),
        (final_n4, final_mask, [("output_image", "in_file")]),
        (thr_t1w_mask, final_mask, [("out_mask", "in_mask")]),
        # 9. Outputs
        (final_n4, outputnode, [("output_image", "t1w_corrected")]),
        (thr_t1w_mask, outputnode, [("out_mask", "t1w_mask")]),
        (final_mask, outputnode, [("out_file", "t1w_corrected_brain")]),
    ])

    if ants_affine_init:
        ants_kwargs = dict(
            metric=("Mattes", 32, "Regular", 0.2),
            transform=("Affine", 0.1),
            search_factor=(20, 0.12),
            principal_axes=False,
            convergence=(10, 1e-6, 10),
            search_grid=(40, (0, 40, 40)),
            verbose=True,
        )

        if ants_affine_init == 'random':
            ants_kwargs['metric'] = ("Mattes", 32, "Random", 0.2)
        if ants_affine_init == 'search':
            ants_kwargs['search_grid'] = (20, (20, 40, 40))

        init_aff = pe.Node(
            AI(**ants_kwargs),
            name="init_aff",
            n_procs=omp_nthreads,
        )
        if tpl_regmask_path:
            init_aff.inputs.fixed_image_mask = _pop(tpl_regmask_path)

        wf.connect([
            (clip_tmpl, init_aff, [("out", "fixed_image")]),
            (clip_t2w_inu, init_aff, [("out", "moving_image")]),
            (init_aff, norm, [("output_transform", "initial_moving_transform")
                              ]),
        ])

    return wf
Example #8
0
def init_infant_brain_extraction_wf(
    ants_affine_init=False,
    bspline_fitting_distance=200,
    debug=False,
    in_template="MNIInfant",
    template_specs=None,
    interim_checkpoints=True,
    mem_gb=3.0,
    mri_scheme="T2w",
    name="infant_brain_extraction_wf",
    atropos_model=None,
    omp_nthreads=None,
    output_dir=None,
    use_float=True,
):
    """
    Build an atlas-based brain extraction pipeline for infant T2w MRI data.

    The target image is validated, resampled, clipped and INU-corrected, then
    registered to the template using two channels (intensity + normalized
    Laplacian).  The template brainmask is mapped back into target space,
    binarized, and used to weight a second N4 pass.

    Parameters
    ----------
    ants_affine_init : :obj:`bool`, optional
        Set-up a pre-initialization step with ``antsAI`` to account for mis-oriented images.
    bspline_fitting_distance : :obj:`float`
        Distance in mm between B-Spline control points for the initial N4 pass.
    debug : :obj:`bool`
        Use the "testing" registration settings and fewer initial N4 iterations.
    in_template : :obj:`str`
        TemplateFlow identifier of the target template.
    template_specs : :obj:`dict`, optional
        Additional TemplateFlow query entities.
    interim_checkpoints : :obj:`bool`
        Generate before/after reportlets of the registration step(s).
    mem_gb : :obj:`float`
        Base memory fingerprint unit.
    mri_scheme : :obj:`str`
        Suffix of the template image to register to; also the default key
        into ``ATROPOS_MODELS``.
    name : :obj:`str`
        This particular workflow's unique name (Nipype requirement).
    atropos_model : :obj:`tuple`, optional
        Atropos segmentation model.
    omp_nthreads : :obj:`int`
        The number of threads for individual processes in this workflow.
    output_dir : path-like, optional
        If given (must expose ``.parent`` and ``.name``), the corrected image
        and the mask -- plus the reportlets when ``interim_checkpoints`` is
        set -- are stored with ``DataSink``.
    use_float : :obj:`bool`
        Run the registration in single precision.

    Raises
    ------
    RuntimeError
        If the target template cannot be found in TemplateFlow.

    """
    inputnode = pe.Node(niu.IdentityInterface(fields=["in_files", "in_mask"]),
                        name="inputnode")
    outputnode = pe.Node(niu.IdentityInterface(
        fields=["out_corrected", "out_brain", "out_mask"]),
                         name="outputnode")

    template_specs = template_specs or {}
    # Find a suitable target template in TemplateFlow
    tpl_target_path = get_template(in_template,
                                   suffix=mri_scheme,
                                   **template_specs)
    if not tpl_target_path:
        raise RuntimeError(
            f"An instance of template <tpl-{in_template}> with MR scheme '{mri_scheme}'"
            " could not be found.")

    # NOTE: the probabilistic (probseg) brainmask is ignored for the time
    # being; only the binary mask is used.
    tpl_brainmask_path = get_template(in_template,
                                      desc="brain",
                                      suffix="mask",
                                      **template_specs)

    tpl_regmask_path = get_template(in_template,
                                    desc="BrainCerebellumExtraction",
                                    suffix="mask",
                                    **template_specs)

    # validate images
    val_tmpl = pe.Node(ValidateImage(), name='val_tmpl')
    val_tmpl.inputs.in_file = _pop(tpl_target_path)

    val_target = pe.Node(ValidateImage(), name='val_target')

    # Resample both target and template to a controlled, isotropic resolution
    res_tmpl = pe.Node(RegridToZooms(zooms=HIRES_ZOOMS),
                       name="res_tmpl")  # testing
    res_target = pe.Node(RegridToZooms(zooms=HIRES_ZOOMS),
                         name="res_target")  # testing
    gauss_tmpl = pe.Node(niu.Function(function=_gauss_filter),
                         name="gauss_tmpl")

    # Spatial normalization step
    lap_tmpl = pe.Node(ImageMath(operation="Laplacian", op2="0.4 1"),
                       name="lap_tmpl")
    lap_target = pe.Node(ImageMath(operation="Laplacian", op2="0.4 1"),
                         name="lap_target")

    # Merge image nodes (intensity image + normalized Laplacian)
    mrg_target = pe.Node(niu.Merge(2), name="mrg_target")
    mrg_tmpl = pe.Node(niu.Merge(2), name="mrg_tmpl")

    norm_lap_tmpl = pe.Node(niu.Function(function=_trunc),
                            name="norm_lap_tmpl")
    norm_lap_tmpl.inputs.dtype = "float32"
    norm_lap_tmpl.inputs.out_max = 1.0
    norm_lap_tmpl.inputs.percentile = (0.01, 99.99)
    norm_lap_tmpl.inputs.clip_max = None

    norm_lap_target = pe.Node(niu.Function(function=_trunc),
                              name="norm_lap_target")
    norm_lap_target.inputs.dtype = "float32"
    norm_lap_target.inputs.out_max = 1.0
    norm_lap_target.inputs.percentile = (0.01, 99.99)
    norm_lap_target.inputs.clip_max = None

    # Set up initial spatial normalization
    ants_params = "testing" if debug else "precise"
    norm = pe.Node(
        Registration(from_file=pkgr_fn(
            "niworkflows.data", f"antsBrainExtraction_{ants_params}.json")),
        name="norm",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )
    norm.inputs.float = use_float

    # main workflow
    wf = pe.Workflow(name)
    # Create a buffer interface as a cache for the actual inputs to registration
    buffernode = pe.Node(
        niu.IdentityInterface(fields=["hires_target", "smooth_target"]),
        name="buffernode")

    # truncate target intensity for N4 correction
    clip_target = pe.Node(
        niu.Function(function=_trunc),
        name="clip_target",
    )
    clip_tmpl = pe.Node(
        niu.Function(function=_trunc),
        name="clip_tmpl",
    )

    # INU correction of the target image
    init_n4 = pe.Node(
        N4BiasFieldCorrection(
            dimension=3,
            save_bias=False,
            copy_header=True,
            n_iterations=[50] * (4 - debug),
            convergence_threshold=1e-7,
            shrink_factor=4,
            bspline_fitting_distance=bspline_fitting_distance,
        ),
        n_procs=omp_nthreads,
        name="init_n4",
    )
    clip_inu = pe.Node(
        niu.Function(function=_trunc),
        name="clip_inu",
    )
    gauss_target = pe.Node(niu.Function(function=_gauss_filter),
                           name="gauss_target")
    wf.connect([
        # truncation, resampling, and initial N4
        (inputnode, val_target, [(("in_files", _pop), "in_file")]),
        (val_target, res_target, [("out_file", "in_file")]),
        (res_target, clip_target, [("out_file", "in_file")]),
        (val_tmpl, clip_tmpl, [("out_file", "in_file")]),
        (clip_tmpl, res_tmpl, [("out", "in_file")]),
        (clip_target, init_n4, [("out", "input_image")]),
        (init_n4, clip_inu, [("output_image", "in_file")]),
        (clip_inu, gauss_target, [("out", "in_file")]),
        (clip_inu, buffernode, [("out", "hires_target")]),
        (gauss_target, buffernode, [("out", "smooth_target")]),
        (res_tmpl, gauss_tmpl, [("out_file", "in_file")]),
    ])

    # Graft a template registration-mask if present
    if tpl_regmask_path:
        hires_mask = pe.Node(ApplyTransforms(
            input_image=_pop(tpl_regmask_path),
            transforms="identity",
            interpolation="NearestNeighbor",
            float=True),
                             name="hires_mask",
                             mem_gb=1)
        wf.connect([
            (res_tmpl, hires_mask, [("out_file", "reference_image")]),
        ])

    map_brainmask = pe.Node(ApplyTransforms(interpolation="Gaussian",
                                            float=True),
                            name="map_brainmask",
                            mem_gb=1)
    map_brainmask.inputs.input_image = str(tpl_brainmask_path)

    thr_brainmask = pe.Node(Binarize(thresh_low=0.80), name="thr_brainmask")
    bspline_grid = pe.Node(niu.Function(function=_bspline_distance),
                           name="bspline_grid")

    # Refine INU correction (B-Spline grid set through ``args`` by bspline_grid)
    final_n4 = pe.Node(
        N4BiasFieldCorrection(
            dimension=3,
            save_bias=True,
            copy_header=True,
            n_iterations=[50] * 5,
            convergence_threshold=1e-7,
            rescale_intensities=True,
            shrink_factor=4,
        ),
        n_procs=omp_nthreads,
        name="final_n4",
    )
    final_mask = pe.Node(ApplyMask(), name="final_mask")

    if atropos_model is None:
        atropos_model = tuple(ATROPOS_MODELS[mri_scheme].values())

    # NOTE: the Atropos refinement below is built but not yet connected.
    atropos_wf = init_atropos_wf(
        use_random_seed=False,
        omp_nthreads=omp_nthreads,
        mem_gb=mem_gb,
        in_segmentation_model=atropos_model,
    )

    sel_wm = pe.Node(niu.Select(index=atropos_model[-1] - 1),
                     name='sel_wm',
                     run_without_submitting=True)

    wf.connect([
        (inputnode, map_brainmask, [(("in_files", _pop), "reference_image")]),
        (inputnode, final_n4, [(("in_files", _pop), "input_image")]),
        (inputnode, bspline_grid, [(("in_files", _pop), "in_file")]),
        (bspline_grid, final_n4, [("out", "args")]),
        # merge laplacian and original images
        (buffernode, lap_target, [("smooth_target", "op1")]),
        (buffernode, mrg_target, [("hires_target", "in1")]),
        (lap_target, norm_lap_target, [("output_image", "in_file")]),
        (norm_lap_target, mrg_target, [("out", "in2")]),
        # Template massaging
        (res_tmpl, lap_tmpl, [("out_file", "op1")]),
        (res_tmpl, mrg_tmpl, [("out_file", "in1")]),
        (lap_tmpl, norm_lap_tmpl, [("output_image", "in_file")]),
        (norm_lap_tmpl, mrg_tmpl, [("out", "in2")]),
        # spatial normalization
        (mrg_target, norm, [("out", "moving_image")]),
        (mrg_tmpl, norm, [("out", "fixed_image")]),
        (norm, map_brainmask, [("reverse_transforms", "transforms"),
                               ("reverse_invert_flags",
                                "invert_transform_flags")]),
        (map_brainmask, thr_brainmask, [("output_image", "in_file")]),
        # take a second pass of N4
        (map_brainmask, final_n4, [("output_image", "weight_image")]),
        (final_n4, final_mask, [("output_image", "in_file")]),
        (thr_brainmask, final_mask, [("out_mask", "in_mask")]),
        (final_n4, outputnode, [("output_image", "out_corrected")]),
        (thr_brainmask, outputnode, [("out_mask", "out_mask")]),
        (final_mask, outputnode, [("out_file", "out_brain")]),
    ])

    if tpl_regmask_path:
        wf.connect([
            (hires_mask, norm, [("output_image", "fixed_image_masks")]),
        ])

    if interim_checkpoints:
        final_apply = pe.Node(ApplyTransforms(interpolation="BSpline",
                                              float=True),
                              name="final_apply",
                              mem_gb=1)
        final_report = pe.Node(SimpleBeforeAfter(
            before_label=f"tpl-{in_template}",
            after_label="target",
            out_report="final_report.svg"),
                               name="final_report")
        wf.connect([
            (inputnode, final_apply, [(("in_files", _pop), "reference_image")
                                      ]),
            (res_tmpl, final_apply, [("out_file", "input_image")]),
            (norm, final_apply, [("reverse_transforms", "transforms"),
                                 ("reverse_invert_flags",
                                  "invert_transform_flags")]),
            (final_apply, final_report, [("output_image", "before")]),
            (outputnode, final_report, [("out_corrected", "after"),
                                        ("out_mask", "wm_seg")]),
        ])

    if output_dir:
        # Imported here; also reused for the init reportlet sink further down.
        from nipype.interfaces.io import DataSink
        ds_final_inu = pe.Node(DataSink(base_directory=str(output_dir.parent)),
                               name="ds_final_inu")
        ds_final_msk = pe.Node(DataSink(base_directory=str(output_dir.parent)),
                               name="ds_final_msk")

        wf.connect([
            (outputnode, ds_final_inu,
             [("out_corrected", f"{output_dir.name}.@inu_corrected")]),
            (outputnode, ds_final_msk, [("out_mask",
                                         f"{output_dir.name}.@brainmask")]),
        ])

        # ``final_report`` only exists when interim checkpoints are enabled;
        # wiring it unconditionally raised UnboundLocalError otherwise.
        if interim_checkpoints:
            ds_report = pe.Node(DataSink(base_directory=str(output_dir.parent)),
                                name="ds_report")
            wf.connect([
                (final_report, ds_report, [("out_report",
                                            f"{output_dir.name}.@report")]),
            ])

    if not ants_affine_init:
        return wf

    # Initialize transforms with antsAI
    lowres_tmpl = pe.Node(RegridToZooms(zooms=LOWRES_ZOOMS),
                          name="lowres_tmpl")
    lowres_target = pe.Node(RegridToZooms(zooms=LOWRES_ZOOMS),
                            name="lowres_target")

    init_aff = pe.Node(
        AI(
            metric=("Mattes", 32, "Regular", 0.25),
            transform=("Affine", 0.1),
            search_factor=(15, 0.1),
            principal_axes=False,
            convergence=(10, 1e-6, 10),
            search_grid=(40, (0, 40, 40)),
            verbose=True,
        ),
        name="init_aff",
        n_procs=omp_nthreads,
    )
    wf.connect([
        (gauss_tmpl, lowres_tmpl, [("out", "in_file")]),
        (lowres_tmpl, init_aff, [("out_file", "fixed_image")]),
        (gauss_target, lowres_target, [("out", "in_file")]),
        (lowres_target, init_aff, [("out_file", "moving_image")]),
        (init_aff, norm, [("output_transform", "initial_moving_transform")]),
    ])

    if tpl_regmask_path:
        lowres_mask = pe.Node(ApplyTransforms(
            input_image=_pop(tpl_regmask_path),
            transforms="identity",
            interpolation="MultiLabel",
            float=True),
                              name="lowres_mask",
                              mem_gb=1)
        wf.connect([
            (lowres_tmpl, lowres_mask, [("out_file", "reference_image")]),
            (lowres_mask, init_aff, [("output_image", "fixed_image_mask")]),
        ])

    if interim_checkpoints:
        init_apply = pe.Node(ApplyTransforms(interpolation="BSpline",
                                             float=True),
                             name="init_apply",
                             mem_gb=1)
        init_report = pe.Node(SimpleBeforeAfter(
            before_label=f"tpl-{in_template}",
            after_label="target",
            out_report="init_report.svg"),
                              name="init_report")
        wf.connect([
            (lowres_target, init_apply, [("out_file", "input_image")]),
            (res_tmpl, init_apply, [("out_file", "reference_image")]),
            (init_aff, init_apply, [("output_transform", "transforms")]),
            (init_apply, init_report, [("output_image", "after")]),
            (res_tmpl, init_report, [("out_file", "before")]),
        ])

        if output_dir:
            ds_init_report = pe.Node(
                DataSink(base_directory=str(output_dir.parent)),
                name="ds_init_report")
            wf.connect(init_report, "out_report", ds_init_report,
                       f"{output_dir.name}.@init_report")
    return wf