Beispiel #1
0
def test_displacements_field(tmpdir, testdata_dir, outdir, pe_dir, rotation,
                             flip):
    """
    Check the generated displacements fields against ANTs.

    A phantom and its field coefficients are generated (test oracle); the
    field is fitted and applied with SDCFlows, and the same field is exported
    as an ITK displacements file and applied with ``antsApplyTransforms``.
    Both resampled images must agree within tolerance.

    When ``outdir`` is set, a before/after SVG reportlet comparing both
    resamplings is written there.
    """
    tmpdir.chdir()

    # Generate test oracle
    phantom_nii, coeff_nii = generate_oracle(
        testdata_dir / "topup-coeff-fixed.nii.gz",
        rotation=rotation,
    )

    b0 = tf.B0FieldTransform(coeffs=coeff_nii)
    # The first fit on this reference returns True; fitting again on the same
    # reference returns False (presumably nothing left to recompute).
    assert b0.fit(phantom_nii) is True
    assert b0.fit(phantom_nii) is False

    b0.apply(
        phantom_nii,
        pe_dir=pe_dir,
        ro_time=0.2,
        output_dtype="float32",
    ).to_filename("warped-sdcflows.nii.gz")
    b0.to_displacements(
        ro_time=0.2,
        pe_dir=pe_dir,
    ).to_filename("itk-displacements.nii.gz")

    phantom_nii.to_filename("phantom.nii.gz")
    # Run antsApplyTransforms. An argument vector (no shell) avoids shell
    # parsing of the command line; check_call raises on a non-zero exit.
    exit_code = check_call(
        [
            "antsApplyTransforms",
            "-d", "3",
            "-r", "phantom.nii.gz",
            "-i", "phantom.nii.gz",
            "-o", "warped-ants.nii.gz",
            "-n", "BSpline",
            "-t", "itk-displacements.nii.gz",
        ]
    )
    assert exit_code == 0

    ours = np.asanyarray(nb.load("warped-sdcflows.nii.gz").dataobj)
    theirs = np.asanyarray(nb.load("warped-ants.nii.gz").dataobj)
    # L2 norm of the difference divided by the number of voxels. This is a
    # scalar, so no np.all() reduction is needed.
    assert np.sqrt(((ours - theirs) ** 2).sum()) / ours.size < 1e-1

    if outdir:
        from niworkflows.interfaces.reportlets.registration import (
            SimpleBeforeAfterRPT as SimpleBeforeAfter, )

        # One letter per axis, chosen by the corresponding flip flag.
        orientation = "".join(
            [ax[bool(f)] for ax, f in zip(("RL", "AP", "SI"), flip)])

        SimpleBeforeAfter(
            after_label="Theirs (ANTs)",
            before_label="Ours (SDCFlows)",
            after="warped-ants.nii.gz",
            before="warped-sdcflows.nii.gz",
            out_report=str(
                outdir /
                f"xfm_pe-{pe_dir}_flip-{orientation}_x-{rotation[0] or 0}"
                f"_y-{rotation[1] or 0}_z-{rotation[2] or 0}.svg"),
        ).run()
Beispiel #2
0
def test_unwarp_wf(tmpdir, datadir, workdir, outdir):
    """
    Run the unwarping workflow end-to-end on the HCP101006 test dataset.

    Magnitude references are built for the fieldmap and for the distorted
    SBRef. The coefficient-registration workflow aligns ``fieldcoeff.nii.gz``
    (stored next to this module) to the SBRef, and its output feeds the
    unwarping workflow. When ``outdir`` is given, a distorted/corrected
    reportlet and a fieldmap reportlet are also written there.
    """
    subject_dir = datadir / "HCP101006" / "sub-101006"
    distorted = subject_dir / "func" / "sub-101006_task-rest_dir-LR_sbref.nii.gz"
    magnitude = subject_dir / "fmap" / "sub-101006_magnitude1.nii.gz"

    # Reference + mask generation for the fieldmap and for the EPI target
    fmap_ref_wf = init_magnitude_wf(2, name="fmap_ref_wf")
    fmap_ref_wf.inputs.inputnode.magnitude = magnitude
    epi_ref_wf = init_magnitude_wf(2, name="epi_ref_wf")
    epi_ref_wf.inputs.inputnode.magnitude = distorted

    coeff_file = Path(__file__).parent / "fieldcoeff.nii.gz"
    reg_wf = init_coeff2epi_wf(2, debug=True, write_coeff=True)
    reg_wf.inputs.inputnode.fmap_coeff = [coeff_file]

    unwarp_wf = init_unwarp_wf(omp_nthreads=2, debug=True)
    unwarp_wf.inputs.inputnode.metadata = dict(
        EffectiveEchoSpacing=0.00058,
        PhaseEncodingDirection="i",
    )

    # EPI reference/mask are the registration target; the fieldmap
    # reference/mask are the moving side.
    target_links = [
        ("outputnode.fmap_ref", "inputnode.target_ref"),
        ("outputnode.fmap_mask", "inputnode.target_mask"),
    ]
    moving_links = [
        ("outputnode.fmap_ref", "inputnode.fmap_ref"),
        ("outputnode.fmap_mask", "inputnode.fmap_mask"),
    ]

    workflow = pe.Workflow(name="test_unwarp_wf")
    # fmt: off
    workflow.connect([
        (epi_ref_wf, reg_wf, target_links),
        (fmap_ref_wf, reg_wf, moving_links),
        (epi_ref_wf, unwarp_wf, [("outputnode.fmap_ref", "inputnode.distorted")]),
        (reg_wf, unwarp_wf, [("outputnode.fmap_coeff", "inputnode.fmap_coeff")]),
    ])
    # fmt: on

    if outdir:
        from niworkflows.interfaces.reportlets.registration import (
            SimpleBeforeAfterRPT as SimpleBeforeAfter,
        )
        from ...outputs import DerivativesDataSink
        from ....interfaces.reportlets import FieldmapReportlet

        # Settings shared by both reportlet data sinks
        sink_kwargs = dict(
            base_directory=str(outdir),
            datatype="figures",
            suffix="bold",
            dismiss_entities=("fmap",),
            source_file=distorted,
        )

        report = pe.Node(
            SimpleBeforeAfter(before_label="Distorted", after_label="Corrected"),
            name="report",
            mem_gb=0.1,
        )
        ds_report = pe.Node(
            DerivativesDataSink(desc="corrected", **sink_kwargs),
            name="ds_report",
            run_without_submitting=True,
        )

        rep = pe.Node(FieldmapReportlet(apply_mask=True), "simple_report")
        rep.interface._always_run = True
        ds_fmap_report = pe.Node(
            DerivativesDataSink(desc="fieldmap", **sink_kwargs),
            name="ds_fmap_report",
        )

        # fmt: off
        workflow.connect([
            # distorted vs. corrected reportlet
            (epi_ref_wf, report, [("outputnode.fmap_ref", "before")]),
            (unwarp_wf, report, [("outputnode.corrected", "after"),
                                 ("outputnode.corrected_mask", "wm_seg")]),
            (report, ds_report, [("out_report", "in_file")]),
            # fieldmap reportlet
            (epi_ref_wf, rep, [("outputnode.fmap_ref", "reference"),
                               ("outputnode.fmap_mask", "mask")]),
            (unwarp_wf, rep, [("outputnode.fieldmap", "fieldmap")]),
            (rep, ds_fmap_report, [("out_report", "in_file")]),
        ])
        # fmt: on

    if workdir:
        workflow.base_dir = str(workdir)
    workflow.run(plugin="Linear")
Beispiel #3
0
def init_rodent_brain_extraction_wf(
    ants_affine_init=False,
    factor=20,
    arc=0.12,
    step=4,
    grid=(0, 4, 4),
    debug=False,
    interim_checkpoints=True,
    mem_gb=3.0,
    mri_scheme="T2w",
    name="rodent_brain_extraction_wf",
    omp_nthreads=None,
    output_dir=None,
    template_id="Fischer344",
    template_specs=None,
    use_float=True,
):
    """
    Build an atlas-based brain extraction pipeline for rodent T1w and T2w MRI data.

    The target image is denoised, intensity-clipped and INU-corrected (N4).
    Then, together with its normalized Laplacian, it is registered to a
    TemplateFlow template. The template's brain mask is projected back into
    the target space, binarized, and also used to drive a second N4 pass.

    Parameters
    ----------
    ants_affine_init : :obj:`bool`, optional
        Set-up a pre-initialization step with ``antsAI`` to account for
        mis-oriented images. When disabled, ``antsRegistration`` is
        initialized with ``initial_moving_transform_com = 1``.
    factor : :obj:`float`, optional
        First element of ``antsAI``'s ``search_factor``
        (only used with ``ants_affine_init``).
    arc : :obj:`float`, optional
        Second element of ``antsAI``'s ``search_factor``
        (only used with ``ants_affine_init``).
    step : :obj:`int`, optional
        First element of ``antsAI``'s ``search_grid``
        (only used with ``ants_affine_init``).
    grid : :obj:`tuple`, optional
        Second element of ``antsAI``'s ``search_grid``
        (only used with ``ants_affine_init``).
    debug : :obj:`bool`, optional
        Use the ``testing`` registration settings instead of ``precise``,
        and run one fewer level of the initial N4 correction.
    interim_checkpoints : :obj:`bool`, optional
        Generate before/after reportlets of the registration steps.
    mem_gb : :obj:`float`, optional
        Memory estimate (GB) for the spatial normalization node.
    mri_scheme : :obj:`str`, optional
        Template suffix to query (e.g., ``T2w``); it also selects the
        registration settings file.
    name : :obj:`str`, optional
        Workflow name.
    omp_nthreads : :obj:`int`, optional
        Number of processors assigned to multithreaded nodes.
    output_dir : :obj:`str`, optional
        When set, the corrected image, mask and reportlets are stored there.
    template_id : :obj:`str`, optional
        TemplateFlow template identifier.
    template_specs : :obj:`dict`, optional
        Additional TemplateFlow query entities (e.g., ``resolution``).
        The caller's dictionary is not modified.
    use_float : :obj:`bool`, optional
        Value for the ``float`` input of the registration node.

    Inputs
    ------
    in_files
        Target image(s); the first element is taken through ``_pop``.
    in_mask
        Declared, but not connected to any node in this workflow.

    Outputs
    -------
    out_corrected
        Denoised target after the second (mask-weighted) N4 pass.
    out_brain
        ``out_corrected`` with the binarized brain mask applied.
    out_mask
        Projected template brain mask, binarized at 0.5.

    Raises
    ------
    RuntimeError
        If the target template or its brain mask cannot be found.

    """
    inputnode = pe.Node(niu.IdentityInterface(fields=["in_files", "in_mask"]),
                        name="inputnode")
    outputnode = pe.Node(
        niu.IdentityInterface(
            fields=["out_corrected", "out_brain", "out_mask"]),
        name="outputnode",
    )

    # Work on a copy so the caller's dictionary is never mutated
    template_specs = dict(template_specs or {})
    if template_id == "WHS" and "resolution" not in template_specs:
        template_specs["resolution"] = 2

    # Find a suitable target template in TemplateFlow
    tpl_target_path = get_template(
        template_id,
        suffix=mri_scheme,
        **template_specs,
    )
    if not tpl_target_path:
        raise RuntimeError(
            f"An instance of template <tpl-{template_id}> with MR scheme '{mri_scheme}'"
            " could not be found.")

    # Prefer a probabilistic brain mask, fall back to a binary one
    tpl_brainmask_path = get_template(
        template_id,
        atlas=None,
        hemi=None,
        desc="brain",
        suffix="probseg",
        **template_specs,
    ) or get_template(
        template_id,
        atlas=None,
        hemi=None,
        desc="brain",
        suffix="mask",
        **template_specs,
    )
    if not tpl_brainmask_path:
        raise RuntimeError(
            f"A brain mask (probseg or mask) of template <tpl-{template_id}>"
            " could not be found.")
    # get_template may return several matches; use the first one, as done
    # for the other template files in this workflow.
    tpl_brainmask = str(_pop(tpl_brainmask_path))

    # Optional mask restricting the registration metric
    tpl_regmask_path = get_template(
        template_id,
        atlas=None,
        desc="BrainCerebellumExtraction",
        suffix="mask",
        **template_specs,
    )

    denoise = pe.Node(DenoiseImage(dimension=3, copy_header=True),
                      name="denoise",
                      n_procs=omp_nthreads)

    # Resample template to a controlled, isotropic resolution
    res_tmpl = pe.Node(RegridToZooms(zooms=HIRES_ZOOMS, smooth=True),
                       name="res_tmpl")

    # Create Laplacian images
    lap_tmpl = pe.Node(ImageMath(operation="Laplacian", copy_header=True),
                       name="lap_tmpl")
    tmpl_sigma = pe.Node(niu.Function(function=_lap_sigma),
                         name="tmpl_sigma",
                         run_without_submitting=True)
    norm_lap_tmpl = pe.Node(niu.Function(function=_norm_lap),
                            name="norm_lap_tmpl")

    lap_target = pe.Node(ImageMath(operation="Laplacian", copy_header=True),
                         name="lap_target")
    target_sigma = pe.Node(niu.Function(function=_lap_sigma),
                           name="target_sigma",
                           run_without_submitting=True)
    norm_lap_target = pe.Node(niu.Function(function=_norm_lap),
                              name="norm_lap_target")

    # Set up initial spatial normalization
    ants_params = "testing" if debug else "precise"
    norm = pe.Node(
        Registration(from_file=pkgr_fn(
            "nirodents",
            f"data/artsBrainExtraction_{ants_params}_{mri_scheme}.json")),
        name="norm",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )
    norm.inputs.float = use_float

    # main workflow
    wf = pe.Workflow(name)

    # truncate target intensity for N4 correction
    clip_target = pe.Node(IntensityClip(p_min=15, p_max=99.9),
                          name="clip_target")

    # truncate template intensity to match target
    clip_tmpl = pe.Node(IntensityClip(p_min=5, p_max=98), name="clip_tmpl")
    clip_tmpl.inputs.in_file = _pop(tpl_target_path)

    # set INU bspline grid based on voxel size
    bspline_grid = pe.Node(niu.Function(function=_bspline_grid),
                           name="bspline_grid")

    # INU correction of the target image
    init_n4 = pe.Node(
        N4BiasFieldCorrection(
            dimension=3,
            save_bias=False,
            copy_header=True,
            # bool arithmetic: 4 levels normally, 3 when debug=True
            n_iterations=[50] * (4 - debug),
            convergence_threshold=1e-7,
            shrink_factor=4,
            rescale_intensities=True,
        ),
        n_procs=omp_nthreads,
        name="init_n4",
    )
    clip_inu = pe.Node(IntensityClip(p_min=1, p_max=99.8), name="clip_inu")

    # Create a buffer interface as a cache for the actual inputs to registration
    buffernode = pe.Node(niu.IdentityInterface(fields=["hires_target"]),
                         name="buffernode")

    # Merge image nodes (intensity image + normalized Laplacian)
    mrg_target = pe.Node(niu.Merge(2), name="mrg_target")
    mrg_tmpl = pe.Node(niu.Merge(2), name="mrg_tmpl")

    # fmt: off
    wf.connect([
        # Target image massaging
        (inputnode, denoise, [(("in_files", _pop), "input_image")]),
        (inputnode, bspline_grid, [(("in_files", _pop), "in_file")]),
        (bspline_grid, init_n4, [("out", "args")]),
        (denoise, clip_target, [("output_image", "in_file")]),
        (clip_target, init_n4, [("out_file", "input_image")]),
        (init_n4, clip_inu, [("output_image", "in_file")]),
        (clip_inu, target_sigma, [("out_file", "in_file")]),
        (clip_inu, buffernode, [("out_file", "hires_target")]),
        (buffernode, lap_target, [("hires_target", "op1")]),
        (target_sigma, lap_target, [("out", "op2")]),
        (lap_target, norm_lap_target, [("output_image", "in_file")]),
        (buffernode, mrg_target, [("hires_target", "in1")]),
        (norm_lap_target, mrg_target, [("out", "in2")]),
        # Template massaging
        (clip_tmpl, res_tmpl, [("out_file", "in_file")]),
        (res_tmpl, tmpl_sigma, [("out_file", "in_file")]),
        (res_tmpl, lap_tmpl, [("out_file", "op1")]),
        (tmpl_sigma, lap_tmpl, [("out", "op2")]),
        (lap_tmpl, norm_lap_tmpl, [("output_image", "in_file")]),
        (res_tmpl, mrg_tmpl, [("out_file", "in1")]),
        (norm_lap_tmpl, mrg_tmpl, [("out", "in2")]),
        # Setup inputs to spatial normalization
        (mrg_target, norm, [("out", "moving_image")]),
        (mrg_tmpl, norm, [("out", "fixed_image")]),
    ])
    # fmt: on

    # Graft a template registration-mask if present
    if tpl_regmask_path:
        hires_mask = pe.Node(
            ApplyTransforms(
                input_image=_pop(tpl_regmask_path),
                transforms="identity",
                interpolation="Gaussian",
                float=True,
            ),
            name="hires_mask",
            mem_gb=1,
        )

        # fmt: off
        wf.connect([
            (res_tmpl, hires_mask, [("out_file", "reference_image")]),
            (hires_mask, norm, [("output_image", "fixed_image_masks")]),
        ])
        # fmt: on

    # Finally project brain mask and refine INU correction
    map_brainmask = pe.Node(
        ApplyTransforms(interpolation="Gaussian", float=True),
        name="map_brainmask",
        mem_gb=1,
    )
    map_brainmask.inputs.input_image = tpl_brainmask

    thr_brainmask = pe.Node(Binarize(thresh_low=0.50), name="thr_brainmask")

    final_n4 = pe.Node(
        N4BiasFieldCorrection(
            dimension=3,
            save_bias=True,
            copy_header=True,
            n_iterations=[50] * 4,
            convergence_threshold=1e-7,
            rescale_intensities=True,
            shrink_factor=4,
        ),
        n_procs=omp_nthreads,
        name="final_n4",
    )
    final_mask = pe.Node(ApplyMask(), name="final_mask")

    # fmt: off
    wf.connect([
        (inputnode, map_brainmask, [(("in_files", _pop), "reference_image")]),
        (bspline_grid, final_n4, [("out", "args")]),
        (denoise, final_n4, [("output_image", "input_image")]),
        # Project template's brainmask into subject space
        (norm, map_brainmask, [("reverse_transforms", "transforms"),
                               ("reverse_invert_flags",
                                "invert_transform_flags")]),
        (map_brainmask, thr_brainmask, [("output_image", "in_file")]),
        # take a second pass of N4
        (map_brainmask, final_n4, [("output_image", "mask_image")]),
        (final_n4, final_mask, [("output_image", "in_file")]),
        (thr_brainmask, final_mask, [("out_mask", "in_mask")]),
        (final_n4, outputnode, [("output_image", "out_corrected")]),
        (thr_brainmask, outputnode, [("out_mask", "out_mask")]),
        (final_mask, outputnode, [("out_file", "out_brain")]),
    ])
    # fmt: on

    if interim_checkpoints:
        # Template resampled into target space vs. corrected target
        final_apply = pe.Node(
            ApplyTransforms(interpolation="BSpline", float=True),
            name="final_apply",
            mem_gb=1,
        )
        final_report = pe.Node(
            SimpleBeforeAfter(after_label="target",
                              before_label=f"tpl-{template_id}"),
            name="final_report",
        )
        # fmt: off
        wf.connect([
            (inputnode, final_apply, [(("in_files", _pop), "reference_image")
                                      ]),
            (res_tmpl, final_apply, [("out_file", "input_image")]),
            (norm, final_apply, [("reverse_transforms", "transforms"),
                                 ("reverse_invert_flags",
                                  "invert_transform_flags")]),
            (final_apply, final_report, [("output_image", "before")]),
            (outputnode, final_report, [("out_corrected", "after"),
                                        ("out_mask", "wm_seg")]),
        ])
        # fmt: on

    if ants_affine_init:
        # Initialize transforms with antsAI
        lowres_tmpl = pe.Node(RegridToZooms(zooms=LOWRES_ZOOMS, smooth=True),
                              name="lowres_tmpl")
        lowres_trgt = pe.Node(RegridToZooms(zooms=LOWRES_ZOOMS, smooth=True),
                              name="lowres_trgt")

        init_aff = pe.Node(
            AI(
                convergence=(100, 1e-6, 10),
                metric=("Mattes", 32, "Random", 0.25),
                principal_axes=False,
                search_factor=(factor, arc),
                search_grid=(step, grid),
                transform=("Affine", 0.1),
                verbose=True,
            ),
            name="init_aff",
            n_procs=omp_nthreads,
        )
        # fmt: off
        wf.connect([
            (clip_inu, lowres_trgt, [("out_file", "in_file")]),
            (lowres_trgt, init_aff, [("out_file", "moving_image")]),
            (clip_tmpl, lowres_tmpl, [("out_file", "in_file")]),
            (lowres_tmpl, init_aff, [("out_file", "fixed_image")]),
            (init_aff, norm, [("output_transform", "initial_moving_transform")
                              ]),
        ])
        # fmt: on

        if tpl_regmask_path:
            lowres_mask = pe.Node(
                ApplyTransforms(
                    input_image=_pop(tpl_regmask_path),
                    transforms="identity",
                    interpolation="MultiLabel",
                ),
                name="lowres_mask",
                mem_gb=1,
            )
            # fmt: off
            wf.connect([
                (lowres_tmpl, lowres_mask, [("out_file", "reference_image")]),
                (lowres_mask, init_aff, [("output_image", "fixed_image_mask")
                                         ]),
            ])
            # fmt: on

        if interim_checkpoints:
            # Low-resolution template (and brain mask) mapped through the
            # inverted antsAI transform, shown over the low-res target.
            init_apply = pe.Node(
                ApplyTransforms(interpolation="BSpline",
                                invert_transform_flags=[True]),
                name="init_apply",
                mem_gb=1,
            )
            init_mask = pe.Node(
                ApplyTransforms(interpolation="Gaussian",
                                invert_transform_flags=[True]),
                name="init_mask",
                mem_gb=1,
            )
            init_mask.inputs.input_image = tpl_brainmask
            init_report = pe.Node(
                SimpleBeforeAfter(
                    out_report="init_report.svg",
                    before_label="target",
                    after_label="template",
                ),
                name="init_report",
            )
            # fmt: off
            wf.connect([
                (lowres_trgt, init_apply, [("out_file", "reference_image")]),
                (lowres_tmpl, init_apply, [("out_file", "input_image")]),
                (init_aff, init_apply, [("output_transform", "transforms")]),
                (lowres_trgt, init_report, [("out_file", "before")]),
                (init_apply, init_report, [("output_image", "after")]),
                (lowres_trgt, init_mask, [("out_file", "reference_image")]),
                (init_aff, init_mask, [("output_transform", "transforms")]),
                (init_mask, init_report, [("output_image", "wm_seg")]),
            ])
            # fmt: on
    else:
        norm.inputs.initial_moving_transform_com = 1

    if output_dir:
        ds_final_inu = pe.Node(DerivativesDataSink(
            base_directory=str(output_dir),
            desc="preproc",
            compress=True,
        ),
                               name="ds_final_inu",
                               run_without_submitting=True)
        ds_final_msk = pe.Node(DerivativesDataSink(
            base_directory=str(output_dir),
            desc="brain",
            suffix="mask",
            compress=True,
        ),
                               name="ds_final_msk",
                               run_without_submitting=True)

        # fmt: off
        wf.connect([
            (inputnode, ds_final_inu, [("in_files", "source_file")]),
            (inputnode, ds_final_msk, [("in_files", "source_file")]),
            (outputnode, ds_final_inu, [("out_corrected", "in_file")]),
            (outputnode, ds_final_msk, [("out_mask", "in_file")]),
        ])
        # fmt: on

        if interim_checkpoints:
            ds_report = pe.Node(DerivativesDataSink(
                base_directory=str(output_dir),
                desc="brain",
                suffix="mask",
                datatype="figures"),
                                name="ds_report",
                                run_without_submitting=True)
            # fmt: off
            wf.connect([
                (inputnode, ds_report, [("in_files", "source_file")]),
                (final_report, ds_report, [("out_report", "in_file")]),
            ])
            # fmt: on

        # init_report only exists when both options are enabled
        if ants_affine_init and interim_checkpoints:
            ds_report_init = pe.Node(DerivativesDataSink(
                base_directory=str(output_dir),
                desc="init",
                suffix="mask",
                datatype="figures"),
                                     name="ds_report_init",
                                     run_without_submitting=True)
            # fmt: off
            wf.connect([
                (inputnode, ds_report_init, [("in_files", "source_file")]),
                (init_report, ds_report_init, [("out_report", "in_file")]),
            ])
            # fmt: on

    return wf
Beispiel #4
0
def init_anat_reports_wf(*, freesurfer, output_dir, sloppy, name="anat_reports_wf"):
    """
    Set up a battery of datasinks to store reports in the right location.

    Patched variant of sMRIPrep's anatomical reports workflow that allows
    templates without a resolution. The resolution handed to TemplateFlow is
    computed by ``set_tf_resolution`` from the template and the ``sloppy``
    flag.

    Parameters
    ----------
    freesurfer : :obj:`bool`
        FreeSurfer was enabled; adds the surface reconstruction reportlet.
    output_dir : :obj:`str`
        Directory in which to save derivatives
    sloppy : :obj:`bool`
        Passed to ``set_tf_resolution`` to choose the template resolution.
    name : :obj:`str`
        Workflow name (default: anat_reports_wf)

    Inputs
    ------
    source_file
        Input T1w image
    t1w_conform_report
        Conformation report
    t1w_preproc
        The T1w reference map, which is calculated as the average of bias-corrected
        and preprocessed T1w images, defining the anatomical space.
    t1w_dseg
        Segmentation in T1w space
    t1w_mask
        Brain (binary) mask estimated by brain extraction.
    template
        Template space and specifications
    std_t1w
        T1w image resampled to standard space
    std_mask
        Mask of skull-stripped template
    subject_id
        FreeSurfer subject ID
    subjects_dir
        FreeSurfer SUBJECTS_DIR

    """
    from niworkflows.interfaces.reportlets.registration import (
        SimpleBeforeAfterRPT as SimpleBeforeAfter,
    )
    from niworkflows.interfaces.reportlets.masks import ROIsPlot
    from smriprep.interfaces.templateflow import TemplateFlowSelect
    from smriprep.workflows.outputs import (
        _fmt, _empty_report, _rpt_masks, _drop_cohort, _pick_cohort,
    )
    from ...utils.patches import set_tf_resolution

    workflow = Workflow(name=name)

    inputfields = [
        "source_file",
        "t1w_conform_report",
        "t1w_preproc",
        "t1w_dseg",
        "t1w_mask",
        "template",
        "std_t1w",
        "std_mask",
        "subject_id",
        "subjects_dir",
    ]
    inputnode = pe.Node(niu.IdentityInterface(fields=inputfields), name="inputnode")

    seg_rpt = pe.Node(
        ROIsPlot(colors=["b", "magenta"], levels=[1.5, 2.5]), name="seg_rpt"
    )

    t1w_conform_check = pe.Node(
        niu.Function(function=_empty_report),
        name="t1w_conform_check",
        run_without_submitting=True,
    )

    ds_t1w_conform_report = pe.Node(
        DerivativesDataSink(
            base_directory=output_dir, desc="conform", datatype="figures"
        ),
        name="ds_t1w_conform_report",
        run_without_submitting=True,
    )

    ds_t1w_dseg_mask_report = pe.Node(
        DerivativesDataSink(
            base_directory=output_dir, suffix="dseg", datatype="figures"
        ),
        name="ds_t1w_dseg_mask_report",
        run_without_submitting=True,
    )

    # fmt:off
    workflow.connect([
        (inputnode, t1w_conform_check, [('t1w_conform_report', 'in_file')]),
        (t1w_conform_check, ds_t1w_conform_report, [('out', 'in_file')]),
        (inputnode, ds_t1w_conform_report, [('source_file', 'source_file')]),
        (inputnode, ds_t1w_dseg_mask_report, [('source_file', 'source_file')]),
        (inputnode, seg_rpt, [('t1w_preproc', 'in_file'),
                              ('t1w_mask', 'in_mask'),
                              ('t1w_dseg', 'in_rois')]),
        (seg_rpt, ds_t1w_dseg_mask_report, [('out_report', 'in_file')]),
    ])
    # fmt:on

    # Generate reportlets showing spatial normalization
    tf_select = pe.Node(
        TemplateFlowSelect(), name="tf_select", run_without_submitting=True
    )

    # Resolves the template resolution (the reason for this patched workflow)
    set_tf_res = pe.Node(niu.Function(function=set_tf_resolution), name='set_tf_res')
    set_tf_res.inputs.sloppy = sloppy

    norm_msk = pe.Node(
        niu.Function(
            function=_rpt_masks,
            output_names=["before", "after"],
            input_names=["mask_file", "before", "after", "after_mask"],
        ),
        name="norm_msk",
    )
    norm_rpt = pe.Node(SimpleBeforeAfter(), name="norm_rpt", mem_gb=0.1)
    norm_rpt.inputs.after_label = "Participant"  # after

    ds_std_t1w_report = pe.Node(
        DerivativesDataSink(
            base_directory=output_dir, suffix="T1w", datatype="figures"
        ),
        name="ds_std_t1w_report",
        run_without_submitting=True,
    )

    # fmt:off
    workflow.connect([
        (inputnode, set_tf_res, [(('template', _drop_cohort), 'template')]),
        (set_tf_res, tf_select, [('out', 'resolution')]),
        (inputnode, tf_select, [(('template', _drop_cohort), 'template'),
                                (('template', _pick_cohort), 'cohort')]),
        (inputnode, norm_rpt, [('template', 'before_label')]),
        (inputnode, norm_msk, [('std_t1w', 'after'),
                               ('std_mask', 'after_mask')]),
        (tf_select, norm_msk, [('t1w_file', 'before'),
                               ('brain_mask', 'mask_file')]),
        (norm_msk, norm_rpt, [('before', 'before'),
                              ('after', 'after')]),
        (inputnode, ds_std_t1w_report, [
            (('template', _fmt), 'space'),
            ('source_file', 'source_file')]),
        (norm_rpt, ds_std_t1w_report, [('out_report', 'in_file')]),
    ])
    # fmt:on

    if freesurfer:
        from smriprep.interfaces.reports import FSSurfaceReport

        recon_report = pe.Node(FSSurfaceReport(), name="recon_report")
        recon_report.interface._always_run = True

        ds_recon_report = pe.Node(
            DerivativesDataSink(
                base_directory=output_dir, desc="reconall", datatype="figures"
            ),
            name="ds_recon_report",
            run_without_submitting=True,
        )
        # fmt:off
        workflow.connect([
            (inputnode, recon_report, [('subjects_dir', 'subjects_dir'),
                                       ('subject_id', 'subject_id')]),
            (recon_report, ds_recon_report, [('out_report', 'in_file')]),
            (inputnode, ds_recon_report, [('source_file', 'source_file')])
        ])
        # fmt:on

    return workflow
Beispiel #5
0
def init_coreg_report_wf(*, output_dir, name="coreg_report_wf"):
    """
    Build a workflow that writes a T2w/T1w coregistration reportlet.

    The preprocessed T2w image is the "before" panel and the preprocessed
    T1w image is the "after" panel. The brain mask is fed to the reportlet's
    ``wm_seg`` input. The SVG is stored as a ``figures`` derivative.

    Parameters
    ----------
    output_dir : :obj:`str`
        Directory in which to save derivatives
    name : :obj:`str`
        Workflow name (default: coreg_report_wf)

    Inputs
    ------
    source_file
        Input reference T1w image
    t1w_preproc
        Preprocessed T1w image.
    t2w_preproc
        Preprocessed T2w image, aligned with the T1w image.
    in_mask
        Brain mask.

    """
    from niworkflows.interfaces.reportlets.registration import (
        SimpleBeforeAfterRPT as SimpleBeforeAfter,
    )

    workflow = Workflow(name=name)
    inputnode = pe.Node(
        niu.IdentityInterface(
            fields=["source_file", "t1w_preproc", "t2w_preproc", "in_mask"]
        ),
        name="inputnode",
    )

    # Before/after reportlet: T2w (before) vs. T1w (after)
    norm_rpt = pe.Node(
        SimpleBeforeAfter(before_label="T2w", after_label="T1w"),
        name="norm_rpt",
        mem_gb=0.1,
    )

    report_sink = DerivativesDataSink(
        base_directory=output_dir,
        space="T2w",
        suffix="T1w",
        datatype="figures",
    )
    ds_t1w_t2w_report = pe.Node(
        report_sink, name="ds_t1w_t2w_report", run_without_submitting=True
    )

    panel_links = [
        ("t2w_preproc", "before"),
        ("t1w_preproc", "after"),
        ("in_mask", "wm_seg"),
    ]
    # fmt:off
    workflow.connect([
        (inputnode, norm_rpt, panel_links),
        (inputnode, ds_t1w_t2w_report, [("source_file", "source_file")]),
        (norm_rpt, ds_t1w_t2w_report, [("out_report", "in_file")]),
    ])
    # fmt:on

    return workflow
def init_bold_preproc_report_wf(mem_gb,
                                reportlets_dir,
                                name="bold_preproc_report_wf"):
    """
    Build a reportlet comparing BOLD variability before and after resampling.

    The standard-deviation maps of the BOLD series before and after
    resampling are computed with ``TSNR``, compared in a before/after
    reportlet, and saved as a ``figures`` derivative.

    Workflow Graph
        .. workflow::
            :graph2use: orig
            :simple_form: yes

            from fprodents.workflows.bold.resampling import init_bold_preproc_report_wf
            wf = init_bold_preproc_report_wf(mem_gb=1, reportlets_dir='.')

    Parameters
    ----------
    mem_gb : :obj:`float`
        Size of BOLD file in GB
    reportlets_dir : :obj:`str`
        Directory in which to save reportlets
    name : :obj:`str`, optional
        Workflow name (default: bold_preproc_report_wf)

    Inputs
    ------
    in_pre
        BOLD time-series, before resampling
    in_post
        BOLD time-series, after resampling
    name_source
        BOLD series NIfTI file
        Used to recover original information lost during processing

    """
    from nipype.algorithms.confounds import TSNR
    from niworkflows.engine.workflows import LiterateWorkflow as Workflow
    from niworkflows.interfaces.reportlets.registration import (
        SimpleBeforeAfterRPT as SimpleBeforeAfter,
    )
    from ...interfaces import DerivativesDataSink

    # Memory estimate of each TSNR node, scaled from the BOLD file size
    tsnr_mem_gb = mem_gb * 4.5

    workflow = Workflow(name=name)
    inputnode = pe.Node(
        niu.IdentityInterface(fields=["in_pre", "in_post", "name_source"]),
        name="inputnode",
    )

    pre_tsnr, pos_tsnr = [
        pe.Node(TSNR(), name=node_name, mem_gb=tsnr_mem_gb)
        for node_name in ("pre_tsnr", "pos_tsnr")
    ]
    bold_rpt = pe.Node(SimpleBeforeAfter(), name="bold_rpt", mem_gb=0.1)

    report_sink = DerivativesDataSink(
        base_directory=reportlets_dir,
        desc="preproc",
        datatype="figures",
        dismiss_entities=("echo",),
    )
    ds_report_bold = pe.Node(
        report_sink,
        name="ds_report_bold",
        mem_gb=DEFAULT_MEMORY_MIN_GB,
        run_without_submitting=True,
    )

    # fmt:off
    workflow.connect([
        (inputnode, pre_tsnr, [('in_pre', 'in_file')]),
        (inputnode, pos_tsnr, [('in_post', 'in_file')]),
        (inputnode, ds_report_bold, [('name_source', 'source_file')]),
        # standard-deviation maps are the before/after panels
        (pre_tsnr, bold_rpt, [('stddev_file', 'before')]),
        (pos_tsnr, bold_rpt, [('stddev_file', 'after')]),
        (bold_rpt, ds_report_bold, [('out_report', 'in_file')]),
    ])
    # fmt:on

    return workflow
Beispiel #7
0
def init_func_preproc_wf(bold_file, has_fieldmap=False):
    """
    This workflow controls the functional preprocessing stages of *fMRIPrep*.

    Workflow Graph
        .. workflow::
            :graph2use: orig
            :simple_form: yes

            from fmriprep.workflows.tests import mock_config
            from fmriprep import config
            from fmriprep.workflows.bold.base import init_func_preproc_wf
            with mock_config():
                bold_file = config.execution.bids_dir / 'sub-01' / 'func' \
                    / 'sub-01_task-mixedgamblestask_run-01_bold.nii.gz'
                wf = init_func_preproc_wf(str(bold_file))

    Parameters
    ----------
    bold_file
        BOLD series NIfTI file
    has_fieldmap
        Signals the workflow to use inputnode fieldmap files

    Inputs
    ------
    bold_file
        BOLD series NIfTI file
    t1w_preproc
        Bias-corrected structural template image
    t1w_mask
        Mask of the skull-stripped template image
    t1w_dseg
        Segmentation of preprocessed structural image, including
        gray-matter (GM), white-matter (WM) and cerebrospinal fluid (CSF)
    t1w_asec
        Segmentation of structural image, done with FreeSurfer.
    t1w_aparc
        Parcellation of structural image, done with FreeSurfer.
    t1w_tpms
        List of tissue probability maps in T1w space
    template
        List of templates to target
    anat2std_xfm
        List of transform files, collated with templates
    std2anat_xfm
        List of inverse transform files, collated with templates
    subjects_dir
        FreeSurfer SUBJECTS_DIR
    subject_id
        FreeSurfer subject ID
    t1w2fsnative_xfm
        LTA-style affine matrix translating from T1w to FreeSurfer-conformed subject space
    fsnative2t1w_xfm
        LTA-style affine matrix translating from FreeSurfer-conformed subject space to T1w
    bold_ref
        BOLD reference file
    bold_ref_xfm
        Transform file in LTA format from bold to reference
    n_dummy_scans
        Number of nonsteady states at the beginning of the BOLD run

    Outputs
    -------
    bold_t1
        BOLD series, resampled to T1w space
    bold_mask_t1
        BOLD series mask in T1w space
    bold_std
        BOLD series, resampled to template space
    bold_mask_std
        BOLD series mask in template space
    confounds
        TSV of confounds
    surfaces
        BOLD series, resampled to FreeSurfer surfaces
    aroma_noise_ics
        Noise components identified by ICA-AROMA
    melodic_mix
        FSL MELODIC mixing matrix
    bold_cifti
        BOLD CIFTI image
    cifti_variant
        combination of target spaces for `bold_cifti`

    See Also
    --------

    * :py:func:`~niworkflows.func.util.init_bold_reference_wf`
    * :py:func:`~fmriprep.workflows.bold.stc.init_bold_stc_wf`
    * :py:func:`~fmriprep.workflows.bold.hmc.init_bold_hmc_wf`
    * :py:func:`~fmriprep.workflows.bold.t2s.init_bold_t2s_wf`
    * :py:func:`~fmriprep.workflows.bold.registration.init_bold_t1_trans_wf`
    * :py:func:`~fmriprep.workflows.bold.registration.init_bold_reg_wf`
    * :py:func:`~fmriprep.workflows.bold.confounds.init_bold_confs_wf`
    * :py:func:`~fmriprep.workflows.bold.confounds.init_ica_aroma_wf`
    * :py:func:`~fmriprep.workflows.bold.resampling.init_bold_std_trans_wf`
    * :py:func:`~fmriprep.workflows.bold.resampling.init_bold_preproc_trans_wf`
    * :py:func:`~fmriprep.workflows.bold.resampling.init_bold_surf_wf`
    * :py:func:`~sdcflows.workflows.fmap.init_fmap_wf`
    * :py:func:`~sdcflows.workflows.pepolar.init_pepolar_unwarp_wf`
    * :py:func:`~sdcflows.workflows.phdiff.init_phdiff_wf`
    * :py:func:`~sdcflows.workflows.syn.init_syn_sdc_wf`
    * :py:func:`~sdcflows.workflows.unwarp.init_sdc_unwarp_wf`

    """
    from niworkflows.engine.workflows import LiterateWorkflow as Workflow
    from niworkflows.interfaces.utility import DictMerge

    mem_gb = {'filesize': 1, 'resampled': 1, 'largemem': 1}
    bold_tlen = 10

    # Have some options handy
    omp_nthreads = config.nipype.omp_nthreads
    freesurfer = config.workflow.run_reconall
    spaces = config.workflow.spaces
    nibabies_dir = str(config.execution.nibabies_dir)

    # Extract BIDS entities and metadata from BOLD file(s)
    entities = extract_entities(bold_file)
    layout = config.execution.layout

    # Take first file as reference
    ref_file = pop_file(bold_file)
    metadata = layout.get_metadata(ref_file)
    # get original image orientation
    ref_orientation = get_img_orientation(ref_file)

    echo_idxs = listify(entities.get("echo", []))
    multiecho = len(echo_idxs) > 2
    if len(echo_idxs) == 1:
        config.loggers.workflow.warning(
            f"Running a single echo <{ref_file}> from a seemingly multi-echo dataset."
        )
        bold_file = ref_file  # Just in case - drop the list

    if len(echo_idxs) == 2:
        raise RuntimeError(
            "Multi-echo processing requires at least three different echos (found two)."
        )

    if multiecho:
        # Drop echo entity for future queries, have a boolean shorthand
        entities.pop("echo", None)
        # reorder echoes from shortest to largest
        tes, bold_file = zip(*sorted([(layout.get_metadata(bf)["EchoTime"], bf)
                                      for bf in bold_file]))
        ref_file = bold_file[0]  # Reset reference to be the shortest TE

    if os.path.isfile(ref_file):
        bold_tlen, mem_gb = _create_mem_gb(ref_file)

    wf_name = _get_wf_name(ref_file)
    config.loggers.workflow.debug(
        f'Creating bold processing workflow for <{ref_file}> ({mem_gb["filesize"]:.2f} GB '
        f'/ {bold_tlen} TRs). Memory resampled/largemem={mem_gb["resampled"]:.2f}'
        f'/{mem_gb["largemem"]:.2f} GB.')

    # Find associated sbref, if possible
    entities['suffix'] = 'sbref'
    entities['extension'] = ['.nii', '.nii.gz']  # Overwrite extensions
    sbref_files = layout.get(return_type='file', **entities)

    sbref_msg = f"No single-band-reference found for {os.path.basename(ref_file)}."
    if sbref_files and 'sbref' in config.workflow.ignore:
        sbref_msg = "Single-band reference file(s) found and ignored."
    elif sbref_files:
        sbref_msg = "Using single-band reference file(s) {}.".format(','.join(
            [os.path.basename(sbf) for sbf in sbref_files]))
    config.loggers.workflow.info(sbref_msg)

    if has_fieldmap:
        # Search for intended fieldmap
        from pathlib import Path
        import re
        from sdcflows.fieldmaps import get_identifier

        bold_rel = re.sub(r"^sub-[a-zA-Z0-9]*/", "",
                          str(Path(bold_file).relative_to(layout.root)))
        estimator_key = get_identifier(bold_rel)
        if not estimator_key:
            has_fieldmap = False
            config.loggers.workflow.critical(
                f"None of the available B0 fieldmaps are associated to <{bold_rel}>"
            )

    # Short circuits: (True and True and (False or 'TooShort')) == 'TooShort'
    run_stc = (bool(metadata.get("SliceTiming"))
               and 'slicetiming' not in config.workflow.ignore
               and (_get_series_len(ref_file) > 4 or "TooShort"))

    # Build workflow
    workflow = Workflow(name=wf_name)
    workflow.__postdesc__ = """\
All resamplings can be performed with *a single interpolation
step* by composing all the pertinent transformations (i.e. head-motion
transform matrices, susceptibility distortion correction when available,
and co-registrations to anatomical and output spaces).
Gridded (volumetric) resamplings were performed using `antsApplyTransforms` (ANTs),
configured with Lanczos interpolation to minimize the smoothing
effects of other kernels [@lanczos].
Non-gridded (surface) resamplings were performed using `mri_vol2surf`
(FreeSurfer).
"""

    inputnode = pe.Node(
        niu.IdentityInterface(fields=[
            'bold_file',
            # from smriprep
            'anat_preproc',
            'anat_brain',
            'anat_mask',
            'anat_dseg',
            'anat_tpms',
            'anat_aseg',
            'anat_aparc',
            'anat2std_xfm',
            'std2anat_xfm',
            'template',
            # from bold reference workflow
            'bold_ref',
            'bold_ref_xfm',
            'n_dummy_scans',
            # from sdcflows (optional)
            'fmap',
            'fmap_ref',
            'fmap_coeff',
            'fmap_mask',
            'fmap_id',
            # if reconstructing with FreeSurfer (optional)
            'anat2fsnative_xfm',
            'fsnative2anat_xfm',
            'subject_id',
            'subjects_dir',
        ]),
        name='inputnode')
    inputnode.inputs.bold_file = bold_file

    outputnode = pe.Node(niu.IdentityInterface(fields=[
        'bold_anat', 'bold_anat_ref', 'bold2anat_xfm', 'anat2bold_xfm',
        'bold_mask_anat', 'bold_aseg_anat', 'bold_aparc_anat', 'bold_std',
        'bold_std_ref', 'bold_mask_std', 'bold_aseg_std', 'bold_aparc_std',
        'bold_native', 'bold_cifti', 'cifti_variant', 'cifti_metadata',
        'cifti_density', 'surfaces', 'confounds', 'aroma_noise_ics',
        'melodic_mix', 'nonaggr_denoised_file', 'confounds_metadata'
    ]),
                         name='outputnode')

    # BOLD buffer: an identity used as a pointer to either the original BOLD
    # or the STC'ed one for further use.
    boldbuffer = pe.Node(niu.IdentityInterface(fields=['bold_file']),
                         name='boldbuffer')

    summary = pe.Node(FunctionalSummary(
        slice_timing=run_stc,
        registration=('FSL', 'FreeSurfer')[freesurfer],
        registration_dof=config.workflow.bold2t1w_dof,
        registration_init=config.workflow.bold2t1w_init,
        pe_direction=metadata.get("PhaseEncodingDirection"),
        echo_idx=echo_idxs,
        tr=metadata.get("RepetitionTime"),
        orientation=ref_orientation),
                      name='summary',
                      mem_gb=config.DEFAULT_MEMORY_MIN_GB,
                      run_without_submitting=True)
    summary.inputs.dummy_scans = config.workflow.dummy_scans
    # TODO: SDC: make dynamic
    summary.inputs.distortion_correction = 'None' if not has_fieldmap else 'TOPUP'

    func_derivatives_wf = init_func_derivatives_wf(
        bids_root=layout.root,
        cifti_output=config.workflow.cifti_output,
        freesurfer=freesurfer,
        metadata=metadata,
        output_dir=nibabies_dir,
        spaces=spaces,
        use_aroma=config.workflow.use_aroma,
        debug=config.execution.debug,
    )

    workflow.connect([
        (outputnode, func_derivatives_wf, [
            ('bold_anat', 'inputnode.bold_t1'),
            ('bold_anat_ref', 'inputnode.bold_t1_ref'),
            ('bold2anat_xfm', 'inputnode.bold2anat_xfm'),
            ('anat2bold_xfm', 'inputnode.anat2bold_xfm'),
            ('bold_aseg_anat', 'inputnode.bold_aseg_t1'),
            ('bold_aparc_anat', 'inputnode.bold_aparc_t1'),
            ('bold_mask_anat', 'inputnode.bold_mask_t1'),
            ('bold_native', 'inputnode.bold_native'),
            ('confounds', 'inputnode.confounds'),
            ('surfaces', 'inputnode.surf_files'),
            ('aroma_noise_ics', 'inputnode.aroma_noise_ics'),
            ('melodic_mix', 'inputnode.melodic_mix'),
            ('nonaggr_denoised_file', 'inputnode.nonaggr_denoised_file'),
            ('bold_cifti', 'inputnode.bold_cifti'),
            ('cifti_variant', 'inputnode.cifti_variant'),
            ('cifti_metadata', 'inputnode.cifti_metadata'),
            ('cifti_density', 'inputnode.cifti_density'),
            ('confounds_metadata', 'inputnode.confounds_metadata'),
            ('acompcor_masks', 'inputnode.acompcor_masks'),
            ('tcompcor_mask', 'inputnode.tcompcor_mask'),
        ]),
    ])

    # Extract BOLD validation from init_bold_reference_wf
    val_bold = pe.MapNode(
        ValidateImage(),
        name="val_bold",
        mem_gb=config.DEFAULT_MEMORY_MIN_GB,
        iterfield=["in_file"],
    )
    val_bold.inputs.in_file = listify(bold_file)

    # Top-level BOLD splitter
    bold_split = pe.Node(FSLSplit(dimension='t'),
                         name='bold_split',
                         mem_gb=mem_gb['filesize'] * 3)

    # HMC on the BOLD
    bold_hmc_wf = init_bold_hmc_wf(name='bold_hmc_wf',
                                   mem_gb=mem_gb['filesize'],
                                   omp_nthreads=omp_nthreads)

    # calculate BOLD registration to T1w
    bold_reg_wf = init_bold_reg_wf(
        bold2t1w_dof=config.workflow.bold2t1w_dof,
        bold2t1w_init=config.workflow.bold2t1w_init,
        freesurfer=freesurfer,
        mem_gb=mem_gb['resampled'],
        name='bold_reg_wf',
        omp_nthreads=omp_nthreads,
        sloppy=config.execution.sloppy,
        use_bbr=config.workflow.use_bbr,
    )

    # apply BOLD registration to T1w
    bold_t1_trans_wf = init_bold_t1_trans_wf(name='bold_t1_trans_wf',
                                             freesurfer=freesurfer,
                                             mem_gb=mem_gb['resampled'],
                                             omp_nthreads=omp_nthreads,
                                             use_compression=False)
    if not has_fieldmap:
        bold_t1_trans_wf.inputs.inputnode.fieldwarp = 'identity'

    # get confounds
    bold_confounds_wf = init_bold_confs_wf(
        mem_gb=mem_gb['largemem'],
        metadata=metadata,
        freesurfer=freesurfer,
        regressors_all_comps=config.workflow.regressors_all_comps,
        regressors_fd_th=config.workflow.regressors_fd_th,
        regressors_dvars_th=config.workflow.regressors_dvars_th,
        name='bold_confounds_wf')
    bold_confounds_wf.get_node('inputnode').inputs.t1_transform_flags = [False]

    # Apply transforms in 1 shot
    # Only use uncompressed output if AROMA is to be run
    bold_bold_trans_wf = init_bold_preproc_trans_wf(
        mem_gb=mem_gb['resampled'],
        omp_nthreads=omp_nthreads,
        use_compression=not config.execution.low_mem,
        use_fieldwarp=False,  # TODO: Fieldwarp is already applied in new sdcflow
        name='bold_bold_trans_wf')
    bold_bold_trans_wf.inputs.inputnode.name_source = ref_file

    # SLICE-TIME CORRECTION (or bypass) #############################################
    if run_stc is True:  # bool('TooShort') == True, so check True explicitly
        bold_stc_wf = init_bold_stc_wf(name='bold_stc_wf', metadata=metadata)
        workflow.connect([
            (inputnode, bold_stc_wf, [('n_dummy_scans', 'inputnode.skip_vols')
                                      ]),
            (bold_stc_wf, boldbuffer, [('outputnode.stc_file', 'bold_file')]),
        ])
        if not multiecho:
            workflow.connect([(val_bold, bold_stc_wf, [
                (("out_file", pop_file), 'inputnode.bold_file')
            ])])
        else:  # for meepi, iterate through stc_wf for all workflows
            meepi_echos = boldbuffer.clone(name='meepi_echos')
            meepi_echos.iterables = ('bold_file', bold_file)
            workflow.connect([(meepi_echos, bold_stc_wf,
                               [('bold_file', 'inputnode.bold_file')])])
    elif not multiecho:  # STC is too short or False
        # bypass STC from original BOLD to the splitter through boldbuffer
        workflow.connect([(val_bold, boldbuffer, [(("out_file", pop_file),
                                                   'bold_file')])])
    else:
        # for meepi, iterate over all meepi echos to boldbuffer
        boldbuffer.iterables = ('bold_file', bold_file)

    # MULTI-ECHO EPI DATA #############################################
    if multiecho:  # instantiate relevant interfaces, imports
        from niworkflows.func.util import init_skullstrip_bold_wf
        skullstrip_bold_wf = init_skullstrip_bold_wf(name='skullstrip_bold_wf')

        split_opt_comb = bold_split.clone(name='split_opt_comb')

        inputnode.inputs.bold_file = ref_file  # Replace reference w first echo

        join_echos = pe.JoinNode(
            niu.IdentityInterface(
                fields=['bold_files', 'skullstripped_bold_files']),
            joinsource=('meepi_echos' if run_stc is True else 'boldbuffer'),
            joinfield=['bold_files', 'skullstripped_bold_files'],
            name='join_echos')

        # create optimal combination, adaptive T2* map
        bold_t2s_wf = init_bold_t2s_wf(echo_times=tes,
                                       mem_gb=mem_gb['resampled'],
                                       omp_nthreads=omp_nthreads,
                                       name='bold_t2smap_wf')

    # Mask BOLD reference image
    final_boldref_masker = pe.Node(BrainExtraction(),
                                   name='final_boldref_masker')

    # MAIN WORKFLOW STRUCTURE #######################################################
    workflow.connect([
        # BOLD buffer has slice-time corrected if it was run, original otherwise
        (boldbuffer, bold_split, [('bold_file', 'in_file')]),
        # HMC
        (inputnode, bold_hmc_wf, [('bold_ref', 'inputnode.raw_ref_image')]),
        (inputnode, final_boldref_masker, [('bold_ref', 'in_file')]),
        (val_bold, bold_hmc_wf, [(("out_file", pop_file),
                                  'inputnode.bold_file')]),
        (inputnode, summary, [('n_dummy_scans', 'algo_dummy_scans')]),
        # EPI-T1 registration workflow
        (
            inputnode,
            bold_reg_wf,
            [
                ('anat_dseg', 'inputnode.t1w_dseg'),
                # Undefined if --fs-no-reconall, but this is safe
                ('subjects_dir', 'inputnode.subjects_dir'),
                ('subject_id', 'inputnode.subject_id'),
                ('fsnative2anat_xfm', 'inputnode.fsnative2t1w_xfm')
            ]),
        (inputnode, bold_reg_wf, [('anat_brain', 'inputnode.t1w_brain')]),
        (inputnode, bold_t1_trans_wf, [('bold_file', 'inputnode.name_source'),
                                       ('anat_mask', 'inputnode.t1w_mask'),
                                       ('anat_brain', 'inputnode.t1w_brain'),
                                       ('anat_aseg', 'inputnode.t1w_aseg'),
                                       ('anat_aparc', 'inputnode.t1w_aparc')]),
        (bold_reg_wf, outputnode,
         [('outputnode.itk_bold_to_t1', 'bold2anat_xfm'),
          ('outputnode.itk_t1_to_bold', 'anat2bold_xfm')]),
        (bold_reg_wf, bold_t1_trans_wf, [('outputnode.itk_bold_to_t1',
                                          'inputnode.itk_bold_to_t1')]),
        (bold_t1_trans_wf, outputnode,
         [('outputnode.bold_t1', 'bold_anat'),
          ('outputnode.bold_t1_ref', 'bold_anat_ref'),
          ('outputnode.bold_aseg_t1', 'bold_aseg_anat'),
          ('outputnode.bold_aparc_t1', 'bold_aparc_anat')]),
        (bold_reg_wf, summary, [('outputnode.fallback', 'fallback')]),
        # Connect bold_confounds_wf
        (inputnode, bold_confounds_wf, [('anat_tpms', 'inputnode.t1w_tpms'),
                                        ('anat_mask', 'inputnode.t1w_mask')]),
        (bold_hmc_wf, bold_confounds_wf,
         [('outputnode.movpar_file', 'inputnode.movpar_file'),
          ('outputnode.rmsd_file', 'inputnode.rmsd_file')]),
        (bold_reg_wf, bold_confounds_wf, [('outputnode.itk_t1_to_bold',
                                           'inputnode.t1_bold_xform')]),
        (inputnode, bold_confounds_wf, [('n_dummy_scans',
                                         'inputnode.skip_vols')]),
        (bold_confounds_wf, outputnode, [
            ('outputnode.confounds_file', 'confounds'),
            ('outputnode.confounds_metadata', 'confounds_metadata'),
            ('outputnode.acompcor_masks', 'acompcor_masks'),
            ('outputnode.tcompcor_mask', 'tcompcor_mask'),
        ]),
        # Connect bold_bold_trans_wf
        (bold_split, bold_bold_trans_wf, [('out_files', 'inputnode.bold_file')]
         ),
        (bold_hmc_wf, bold_bold_trans_wf, [('outputnode.xforms',
                                            'inputnode.hmc_xforms')]),
        # Summary
        (outputnode, summary, [('confounds', 'confounds_file')]),
    ])

    # for standard EPI data, pass along correct file
    if not multiecho:
        # TODO: Add SDC
        workflow.connect([
            (inputnode, func_derivatives_wf, [('bold_file',
                                               'inputnode.source_file')]),
            (bold_bold_trans_wf, bold_confounds_wf, [('outputnode.bold',
                                                      'inputnode.bold')]),
            # (bold_bold_trans_wf, final_boldref_wf, [
            #     ('outputnode.bold', 'inputnode.bold_file')]),
            (bold_split, bold_t1_trans_wf, [('out_files',
                                             'inputnode.bold_split')]),
            (bold_hmc_wf, bold_t1_trans_wf, [('outputnode.xforms',
                                              'inputnode.hmc_xforms')]),
            # (bold_sdc_wf, bold_t1_trans_wf, [
            #     ('outputnode.out_warp', 'inputnode.fieldwarp')])
        ])
    else:  # for meepi, use optimal combination
        workflow.connect([
            # update name source for optimal combination
            (inputnode, func_derivatives_wf,
             [(('bold_file', combine_meepi_source), 'inputnode.source_file')]),
            (bold_bold_trans_wf, join_echos, [('outputnode.bold', 'bold_files')
                                              ]),
            # (join_echos, final_boldref_wf, [
            #     ('bold_files', 'inputnode.bold_file')]),
            # TODO: Check with multi-echo data
            (bold_bold_trans_wf, skullstrip_bold_wf, [('outputnode.bold',
                                                       'inputnode.in_file')]),
            (skullstrip_bold_wf, join_echos,
             [('outputnode.skull_stripped_file', 'skullstripped_bold_files')]),
            (join_echos, bold_t2s_wf, [('skullstripped_bold_files',
                                        'inputnode.bold_file')]),
            (bold_t2s_wf, bold_confounds_wf, [('outputnode.bold',
                                               'inputnode.bold')]),
            (bold_t2s_wf, split_opt_comb, [('outputnode.bold', 'in_file')]),
            (split_opt_comb, bold_t1_trans_wf, [('out_files',
                                                 'inputnode.bold_split')]),
        ])

        # Already applied in bold_bold_trans_wf, which inputs to bold_t2s_wf
        bold_t1_trans_wf.inputs.inputnode.fieldwarp = 'identity'
        bold_t1_trans_wf.inputs.inputnode.hmc_xforms = 'identity'

    # Map final BOLD mask into T1w space (if required)
    nonstd_spaces = set(spaces.get_nonstandard())
    if nonstd_spaces.intersection(('T1w', 'anat')):
        from niworkflows.interfaces.fixes import (FixHeaderApplyTransforms as
                                                  ApplyTransforms)

        boldmask_to_t1w = pe.Node(ApplyTransforms(interpolation='MultiLabel'),
                                  name='boldmask_to_t1w',
                                  mem_gb=0.1)
        workflow.connect([
            (bold_reg_wf, boldmask_to_t1w, [('outputnode.itk_bold_to_t1',
                                             'transforms')]),
            (bold_t1_trans_wf, boldmask_to_t1w, [('outputnode.bold_mask_t1',
                                                  'reference_image')]),
            (boldmask_to_t1w, outputnode, [('output_image', 'bold_mask_anat')
                                           ]),
        ])

    if nonstd_spaces.intersection(('func', 'run', 'bold', 'boldref', 'sbref')):
        workflow.connect([
            (inputnode, func_derivatives_wf, [
                ('bold_ref', 'inputnode.bold_native_ref'),
            ]),
            (bold_bold_trans_wf if not multiecho else bold_t2s_wf, outputnode,
             [('outputnode.bold', 'bold_native')])
        ])

    if spaces.get_spaces(nonstandard=False, dim=(3, )):
        # Apply transforms in 1 shot
        # Only use uncompressed output if AROMA is to be run
        bold_std_trans_wf = init_bold_std_trans_wf(
            freesurfer=freesurfer,
            mem_gb=mem_gb['resampled'],
            omp_nthreads=omp_nthreads,
            spaces=spaces,
            name='bold_std_trans_wf',
            use_compression=not config.execution.low_mem,
        )
        if not has_fieldmap:
            bold_std_trans_wf.inputs.inputnode.fieldwarp = 'identity'

        workflow.connect([
            (inputnode, bold_std_trans_wf,
             [('template', 'inputnode.templates'),
              ('anat2std_xfm', 'inputnode.anat2std_xfm'),
              ('bold_file', 'inputnode.name_source'),
              ('anat_aseg', 'inputnode.bold_aseg'),
              ('anat_aparc', 'inputnode.bold_aparc')]),
            (bold_reg_wf, bold_std_trans_wf, [('outputnode.itk_bold_to_t1',
                                               'inputnode.itk_bold_to_t1')]),
            (bold_std_trans_wf, outputnode,
             [('outputnode.bold_std', 'bold_std'),
              ('outputnode.bold_std_ref', 'bold_std_ref'),
              ('outputnode.bold_mask_std', 'bold_mask_std')]),
        ])

        if freesurfer:
            workflow.connect([
                (bold_std_trans_wf, func_derivatives_wf, [
                    ('outputnode.bold_aseg_std', 'inputnode.bold_aseg_std'),
                    ('outputnode.bold_aparc_std', 'inputnode.bold_aparc_std'),
                ]),
                (bold_std_trans_wf, outputnode,
                 [('outputnode.bold_aseg_std', 'bold_aseg_std'),
                  ('outputnode.bold_aparc_std', 'bold_aparc_std')]),
            ])

        if not multiecho:
            # TODO: Add SDC
            workflow.connect([
                (bold_split, bold_std_trans_wf, [('out_files',
                                                  'inputnode.bold_split')]),
                # (bold_sdc_wf, bold_std_trans_wf, [
                #     ('outputnode.out_warp', 'inputnode.fieldwarp')]),
                (bold_hmc_wf, bold_std_trans_wf, [('outputnode.xforms',
                                                   'inputnode.hmc_xforms')]),
            ])
        else:
            workflow.connect([(split_opt_comb, bold_std_trans_wf,
                               [('out_files', 'inputnode.bold_split')])])

            # Already applied in bold_bold_trans_wf, which inputs to bold_t2s_wf
            bold_std_trans_wf.inputs.inputnode.fieldwarp = 'identity'
            bold_std_trans_wf.inputs.inputnode.hmc_xforms = 'identity'

        # func_derivatives_wf internally parametrizes over snapshotted spaces.
        workflow.connect([
            (bold_std_trans_wf, func_derivatives_wf, [
                ('outputnode.template', 'inputnode.template'),
                ('outputnode.spatial_reference',
                 'inputnode.spatial_reference'),
                ('outputnode.bold_std_ref', 'inputnode.bold_std_ref'),
                ('outputnode.bold_std', 'inputnode.bold_std'),
                ('outputnode.bold_mask_std', 'inputnode.bold_mask_std'),
            ]),
        ])

        if config.workflow.use_aroma:  # ICA-AROMA workflow
            from .confounds import init_ica_aroma_wf
            ica_aroma_wf = init_ica_aroma_wf(
                mem_gb=mem_gb['resampled'],
                metadata=metadata,
                omp_nthreads=omp_nthreads,
                err_on_aroma_warn=config.workflow.aroma_err_on_warn,
                aroma_melodic_dim=config.workflow.aroma_melodic_dim,
                name='ica_aroma_wf')

            join = pe.Node(niu.Function(output_names=["out_file"],
                                        function=_to_join),
                           name='aroma_confounds')

            mrg_conf_metadata = pe.Node(niu.Merge(2),
                                        name='merge_confound_metadata',
                                        run_without_submitting=True)
            mrg_conf_metadata2 = pe.Node(DictMerge(),
                                         name='merge_confound_metadata2',
                                         run_without_submitting=True)
            workflow.disconnect([
                (bold_confounds_wf, outputnode, [
                    ('outputnode.confounds_file', 'confounds'),
                ]),
                (bold_confounds_wf, outputnode, [
                    ('outputnode.confounds_metadata', 'confounds_metadata'),
                ]),
            ])
            workflow.connect([
                (inputnode, ica_aroma_wf, [('bold_file',
                                            'inputnode.name_source')]),
                (bold_hmc_wf, ica_aroma_wf, [('outputnode.movpar_file',
                                              'inputnode.movpar_file')]),
                (inputnode, ica_aroma_wf, [('n_dummy_scans',
                                            'inputnode.skip_vols')]),
                (bold_confounds_wf, join, [('outputnode.confounds_file',
                                            'in_file')]),
                (bold_confounds_wf, mrg_conf_metadata,
                 [('outputnode.confounds_metadata', 'in1')]),
                (ica_aroma_wf, join, [('outputnode.aroma_confounds',
                                       'join_file')]),
                (ica_aroma_wf, mrg_conf_metadata,
                 [('outputnode.aroma_metadata', 'in2')]),
                (mrg_conf_metadata, mrg_conf_metadata2, [('out', 'in_dicts')]),
                (ica_aroma_wf, outputnode,
                 [('outputnode.aroma_noise_ics', 'aroma_noise_ics'),
                  ('outputnode.melodic_mix', 'melodic_mix'),
                  ('outputnode.nonaggr_denoised_file', 'nonaggr_denoised_file')
                  ]),
                (join, outputnode, [('out_file', 'confounds')]),
                (mrg_conf_metadata2, outputnode, [('out_dict',
                                                   'confounds_metadata')]),
                (bold_std_trans_wf, ica_aroma_wf,
                 [('outputnode.bold_std', 'inputnode.bold_std'),
                  ('outputnode.bold_mask_std', 'inputnode.bold_mask_std'),
                  ('outputnode.spatial_reference',
                   'inputnode.spatial_reference')]),
            ])

    # SURFACES ##################################################################################
    # Freesurfer
    freesurfer_spaces = spaces.get_fs_spaces()
    if freesurfer and freesurfer_spaces:
        config.loggers.workflow.debug(
            'Creating BOLD surface-sampling workflow.')
        bold_surf_wf = init_bold_surf_wf(
            mem_gb=mem_gb['resampled'],
            surface_spaces=freesurfer_spaces,
            medial_surface_nan=config.workflow.medial_surface_nan,
            name='bold_surf_wf')
        workflow.connect([
            (inputnode, bold_surf_wf,
             [('subjects_dir', 'inputnode.subjects_dir'),
              ('subject_id', 'inputnode.subject_id'),
              ('anat2fsnative_xfm', 'inputnode.t1w2fsnative_xfm')]),
            (bold_t1_trans_wf, bold_surf_wf, [('outputnode.bold_t1',
                                               'inputnode.source_file')]),
            (bold_surf_wf, outputnode, [('outputnode.surfaces', 'surfaces')]),
            (bold_surf_wf, func_derivatives_wf, [('outputnode.target',
                                                  'inputnode.surf_refs')]),
        ])

        # CIFTI output
        if config.workflow.cifti_output:
            from .resampling import init_bold_grayords_wf
            bold_grayords_wf = init_bold_grayords_wf(
                grayord_density=config.workflow.cifti_output,
                mem_gb=mem_gb['resampled'],
                repetition_time=metadata['RepetitionTime'])

            workflow.connect([
                (inputnode, bold_grayords_wf, [('subjects_dir',
                                                'inputnode.subjects_dir')]),
                (bold_std_trans_wf, bold_grayords_wf,
                 [('outputnode.bold_std', 'inputnode.bold_std'),
                  ('outputnode.spatial_reference',
                   'inputnode.spatial_reference')]),
                (bold_surf_wf, bold_grayords_wf, [
                    ('outputnode.surfaces', 'inputnode.surf_files'),
                    ('outputnode.target', 'inputnode.surf_refs'),
                ]),
                (bold_grayords_wf, outputnode,
                 [('outputnode.cifti_bold', 'bold_cifti'),
                  ('outputnode.cifti_variant', 'cifti_variant'),
                  ('outputnode.cifti_metadata', 'cifti_metadata'),
                  ('outputnode.cifti_density', 'cifti_density')]),
            ])

    if spaces.get_spaces(nonstandard=False, dim=(3, )):
        if not config.workflow.cifti_output:
            config.loggers.workflow.critical(
                "The carpetplot requires CIFTI outputs")
        else:
            carpetplot_wf = init_carpetplot_wf(
                mem_gb=mem_gb['resampled'],
                metadata=metadata,
                cifti_output=bool(config.workflow.cifti_output),
                name='carpetplot_wf')

            workflow.connect([
                (bold_grayords_wf, carpetplot_wf, [('outputnode.cifti_bold',
                                                    'inputnode.cifti_bold')]),
                (bold_confounds_wf, carpetplot_wf,
                 [('outputnode.confounds_file', 'inputnode.confounds_file')]),
            ])

    # REPORTING ############################################################
    ds_report_summary = pe.Node(DerivativesDataSink(
        desc='summary', datatype="figures", dismiss_entities=("echo", )),
                                name='ds_report_summary',
                                run_without_submitting=True,
                                mem_gb=config.DEFAULT_MEMORY_MIN_GB)

    ds_report_validation = pe.Node(DerivativesDataSink(
        desc='validation', datatype="figures", dismiss_entities=("echo", )),
                                   name='ds_report_validation',
                                   run_without_submitting=True,
                                   mem_gb=config.DEFAULT_MEMORY_MIN_GB)

    workflow.connect([
        (summary, ds_report_summary, [('out_report', 'in_file')]),
        (val_bold, ds_report_validation, [(("out_report", pop_file), 'in_file')
                                          ]),
    ])

    # Fill-in datasinks of reportlets seen so far
    for node in workflow.list_node_names():
        if node.split('.')[-1].startswith('ds_report'):
            workflow.get_node(node).inputs.base_directory = nibabies_dir
            workflow.get_node(node).inputs.source_file = ref_file

    # Distortion correction
    if not has_fieldmap:
        # fmt: off
        # Finalize workflow with fieldmap-less connections
        workflow.connect([
            (inputnode, final_boldref_masker, [('bold_ref', 'in_file')]),
            (final_boldref_masker, bold_t1_trans_wf, [
                ('out_mask', 'inputnode.ref_bold_mask'),
                ('out_file', 'inputnode.ref_bold_brain'),
            ]),
            (final_boldref_masker, bold_reg_wf, [
                ('out_file', 'inputnode.ref_bold_brain'),
            ]),
            (final_boldref_masker, bold_confounds_wf,
             [('out_mask', 'inputnode.bold_mask')]),
        ])

        if nonstd_spaces.intersection(('T1w', 'anat')):
            workflow.connect([
                (final_boldref_masker, boldmask_to_t1w, [('out_mask',
                                                          'input_image')]),
            ])
        #         (final_boldref_wf, boldmask_to_t1w, [('outputnode.bold_mask', 'input_image')]),
        #     ])

        if nonstd_spaces.intersection(
            ('func', 'run', 'bold', 'boldref', 'sbref')):
            workflow.connect([
                (final_boldref_masker, func_derivatives_wf,
                 [('out_file', 'inputnode.bold_native_ref'),
                  ('out_mask', 'inputnode.bold_mask_native')]),
            ])
        #         (final_boldref_wf, func_derivatives_wf, [
        #             ('outputnode.ref_image', 'inputnode.bold_native_ref'),
        #             ('outputnode.bold_mask', 'inputnode.bold_mask_native')]),
        #     ])

        if spaces.get_spaces(nonstandard=False, dim=(3, )):
            workflow.connect([
                (final_boldref_masker, bold_std_trans_wf,
                 [('out_mask', 'inputnode.bold_mask')]),
            ])
        #         (final_boldref_wf, bold_std_trans_wf, [
        #             ('outputnode.bold_mask', 'inputnode.bold_mask')]),
        #     ])

        # fmt: on
        return workflow

    from niworkflows.interfaces.reportlets.registration import (
        SimpleBeforeAfterRPT as SimpleBeforeAfter, )
    from niworkflows.interfaces.utility import KeySelect
    from sdcflows.workflows.apply.registration import init_coeff2epi_wf
    from sdcflows.workflows.apply.correction import init_unwarp_wf

    coeff2epi_wf = init_coeff2epi_wf(
        debug="fieldmaps" in config.execution.debug,
        omp_nthreads=config.nipype.omp_nthreads,
        write_coeff=True,
    )
    unwarp_wf = init_unwarp_wf(debug="fieldmaps" in config.execution.debug,
                               omp_nthreads=config.nipype.omp_nthreads)
    unwarp_wf.inputs.inputnode.metadata = layout.get_metadata(str(bold_file))

    output_select = pe.Node(
        KeySelect(fields=["fmap", "fmap_ref", "fmap_coeff", "fmap_mask"]),
        name="output_select",
        run_without_submitting=True,
    )
    output_select.inputs.key = estimator_key[0]
    if len(estimator_key) > 1:
        config.loggers.workflow.warning(
            f"Several fieldmaps <{', '.join(estimator_key)}> are "
            f"'IntendedFor' <{bold_file}>, using {estimator_key[0]}")

    sdc_report = pe.Node(
        SimpleBeforeAfter(before_label="Distorted", after_label="Corrected"),
        name="sdc_report",
        mem_gb=0.1,
    )

    ds_report_sdc = pe.Node(
        DerivativesDataSink(base_directory=nibabies_dir,
                            desc="sdc",
                            suffix="bold",
                            datatype="figures",
                            dismiss_entities=("echo", )),
        name="ds_report_sdc",
        run_without_submitting=True,
    )

    unwarp_masker = pe.Node(BrainExtraction(), name='unwarp_masker')

    # fmt: off
    workflow.connect([
        (inputnode, output_select, [("fmap", "fmap"), ("fmap_ref", "fmap_ref"),
                                    ("fmap_coeff", "fmap_coeff"),
                                    ("fmap_mask", "fmap_mask"),
                                    ("fmap_id", "keys")]),
        (output_select, coeff2epi_wf, [("fmap_ref", "inputnode.fmap_ref"),
                                       ("fmap_coeff", "inputnode.fmap_coeff"),
                                       ("fmap_mask", "inputnode.fmap_mask")]),
        (inputnode, coeff2epi_wf, [("bold_ref", "inputnode.target_ref")]),
        (final_boldref_masker, coeff2epi_wf, [("out_file",
                                               "inputnode.target_mask")]),
        (inputnode, unwarp_wf, [("bold_ref", "inputnode.distorted")]),
        (coeff2epi_wf, unwarp_wf, [("outputnode.fmap_coeff",
                                    "inputnode.fmap_coeff")]),
        (inputnode, sdc_report, [("bold_ref", "before")]),
        (unwarp_wf, sdc_report, [("outputnode.corrected", "after"),
                                 ("outputnode.corrected_mask", "wm_seg")]),
        (inputnode, ds_report_sdc, [("bold_file", "source_file")]),
        (sdc_report, ds_report_sdc, [("out_report", "in_file")]),
        # remaining workflow connections
        (unwarp_wf, unwarp_masker, [('outputnode.corrected', 'in_file')]),
        (unwarp_masker, bold_confounds_wf, [('out_mask', 'inputnode.bold_mask')
                                            ]),
        (unwarp_masker, bold_t1_trans_wf,
         [('out_mask', 'inputnode.ref_bold_mask'),
          ('out_file', 'inputnode.ref_bold_brain')]),
        # (unwarp_masker, bold_bold_trans_wf, [
        #     ('out_mask', 'inputnode.bold_mask')]),  # Not used within workflow
        (unwarp_masker, bold_reg_wf, [('out_file', 'inputnode.ref_bold_brain')]
         ),
        # TODO: Add distortion correction method to sdcflow outputs?
        # (bold_sdc_wf, summary, [('outputnode.method', 'distortion_correction')]),
    ])

    if nonstd_spaces.intersection(('T1w', 'anat')):
        workflow.connect([
            (unwarp_masker, boldmask_to_t1w, [('out_mask', 'input_image')]),
        ])

    if nonstd_spaces.intersection(('func', 'run', 'bold', 'boldref', 'sbref')):
        workflow.connect([
            (unwarp_masker, func_derivatives_wf,
             [('out_file', 'inputnode.bold_native_ref'),
              ('out_mask', 'inputnode.bold_mask_native')]),
        ])

    if spaces.get_spaces(nonstandard=False, dim=(3, )):
        workflow.connect([
            (unwarp_masker, bold_std_trans_wf, [('out_mask',
                                                 'inputnode.bold_mask')]),
        ])
    # fmt: on

    # if not multiecho:
    #     (bold_sdc_wf, bold_t1_trans_wf, [
    #             ('outputnode.out_warp', 'inputnode.fieldwarp')])
    #     (bold_sdc_wf, bold_std_trans_wf, [
    #         ('outputnode.out_warp', 'inputnode.fieldwarp')]),
    # ])
    return workflow
Beispiel #8
0
def test_registration_wf(tmpdir, datadir, workdir, outdir):
    """Exercise the fieldmap-to-EPI alignment workflow end to end.

    Two ``init_magnitude_wf`` instances build the references: the EPI target
    comes from the HCP101006 single-band reference, and the fieldmap
    reference comes from its first magnitude image. Coefficients produced by
    ``_gen_coeff`` from the fieldmap reference are then aligned with
    ``init_coeff2epi_wf``.

    When ``outdir`` is set, a before/after reportlet is written there. When
    ``workdir`` is set, it becomes the workflow's base directory. The graph
    runs with the ``Linear`` plugin.
    """
    subject_root = datadir / "HCP101006" / "sub-101006"
    sbref_file = (
        subject_root / "func" / "sub-101006_task-rest_dir-LR_sbref.nii.gz"
    )
    magnitude = subject_root / "fmap" / "sub-101006_magnitude1.nii.gz"

    # EPI side: reference + mask computed from the single-band reference.
    epi_ref_wf = init_magnitude_wf(2, name="epi_ref_wf")
    epi_ref_wf.inputs.inputnode.magnitude = sbref_file

    # Fieldmap side: reference + mask computed from the magnitude image.
    fmap_ref_wf = init_magnitude_wf(2, name="fmap_ref_wf")
    fmap_ref_wf.inputs.inputnode.magnitude = magnitude

    gen_coeff = pe.Node(niu.Function(function=_gen_coeff), name="gen_coeff")
    reg_wf = init_coeff2epi_wf(2, debug=True, write_coeff=True)

    # Both reference workflows expose the same output fields. They differ
    # only in which pair of registration inputs they feed.
    epi_to_target = [
        ("outputnode.fmap_ref", "inputnode.target_ref"),
        ("outputnode.fmap_mask", "inputnode.target_mask"),
    ]
    fmap_to_moving = [
        ("outputnode.fmap_ref", "inputnode.fmap_ref"),
        ("outputnode.fmap_mask", "inputnode.fmap_mask"),
    ]

    workflow = pe.Workflow(name="test_registration_wf")
    # fmt: off
    workflow.connect([
        (epi_ref_wf, reg_wf, epi_to_target),
        (fmap_ref_wf, reg_wf, fmap_to_moving),
        (fmap_ref_wf, gen_coeff, [("outputnode.fmap_ref", "img")]),
        (gen_coeff, reg_wf, [("out", "inputnode.fmap_coeff")]),
    ])
    # fmt: on

    if outdir:
        from niworkflows.interfaces.reportlets.registration import (
            SimpleBeforeAfterRPT as SimpleBeforeAfter,
        )
        from ...outputs import DerivativesDataSink

        report = pe.Node(
            SimpleBeforeAfter(
                after_label="Target EPI",
                before_label="B0 Reference",
            ),
            name="report",
            mem_gb=0.1,
        )
        sink_interface = DerivativesDataSink(
            base_directory=str(outdir),
            suffix="fieldmap",
            space="sbref",
            datatype="figures",
            dismiss_entities=("fmap", ),
            source_file=magnitude,
        )
        ds_report = pe.Node(
            sink_interface,
            name="ds_report",
            run_without_submitting=True,
        )

        # Fieldmap reference "before", EPI target "after", then sink the SVG.
        workflow.connect(fmap_ref_wf, "outputnode.fmap_ref", report, "before")
        workflow.connect(reg_wf, "outputnode.target_ref", report, "after")
        workflow.connect(report, "out_report", ds_report, "in_file")

    if workdir:
        workflow.base_dir = str(workdir)

    workflow.run(plugin="Linear")
Beispiel #9
0
def init_dwi_preproc_wf(dwi_file, has_fieldmap=False):
    """
    Build a preprocessing workflow for one DWI run.

    The workflow does the following:

    * checks the gradient table (``CheckGradientTable``);
    * computes a :math:`b = 0` reference and mask (``init_dwi_reference_wf``);
    * when ``config.workflow.run_reconall`` is set, co-registers that
      reference to the T1w with ``init_bbreg_wf``;
    * unless ``"eddy"`` is in ``config.workflow.ignore``, runs
      ``init_eddy_wf``;
    * when a B0 fieldmap is associated to the run, applies susceptibility
      distortion correction (SDC) to the reference with SDCFlows.

    Workflow Graph
        .. workflow::
            :graph2use: orig
            :simple_form: yes

            from dmriprep.config.testing import mock_config
            from dmriprep import config
            from dmriprep.workflows.dwi.base import init_dwi_preproc_wf
            with mock_config():
                wf = init_dwi_preproc_wf(
                    f"{config.execution.layout.root}/"
                    "sub-THP0005/dwi/sub-THP0005_dwi.nii.gz"
                )

    Parameters
    ----------
    dwi_file : :obj:`str` or :obj:`os.PathLike`
        One diffusion MRI dataset to be processed. It must be located under
        ``config.execution.layout.root``, because the path is made relative
        to it when looking up fieldmaps.
    has_fieldmap : :obj:`bool`
        Build the workflow with a path to register a fieldmap to the DWI.
        If no fieldmap estimator is associated to ``dwi_file``, a critical
        message is logged and the fieldmap-less workflow is built instead.

    Inputs
    ------
    dwi_file
        dwi NIfTI file (pre-set from ``dwi_file`` at construction time)
    in_bvec
        File path of the b-vectors (pre-set from the BIDS layout)
    in_bval
        File path of the b-values (pre-set from the BIDS layout)
    fmap
        Fieldmap(s), filtered with ``fmap_id`` as keys
    fmap_ref
        Fieldmap reference(s), filtered with ``fmap_id`` as keys
    fmap_coeff
        Fieldmap coefficients, filtered with ``fmap_id`` as keys
    fmap_mask
        Fieldmap mask(s), filtered with ``fmap_id`` as keys
    fmap_id
        Identifier(s) of the available fieldmap estimators. They are used as
        the keys of a ``KeySelect`` that picks the entry matching this run's
        estimator. Presumably the ``fmap*`` inputs above are lists aligned
        with these keys (confirm with the caller).
    t1w_preproc
        T1w image, masked with ``t1w_mask`` (only when
        ``config.workflow.run_reconall`` is set)
    t1w_mask
        Brain mask of ``t1w_preproc``
    subjects_dir, subject_id, fsnative2t1w_xfm
        FreeSurfer inputs forwarded to the bbregister workflow (only when
        ``config.workflow.run_reconall`` is set). ``subject_id`` is passed
        through ``sub_prefix`` first.
    t1w_dseg, t1w_aseg, t1w_aparc, t1w_tpms, template, anat2std_xfm, std2anat_xfm, t1w2fsnative_xfm
        Declared on the input node but not connected within this function.

    Outputs
    -------
    dwi_reference
        A 3D :math:`b = 0` reference. Without a fieldmap, this is the
        reference computed by ``init_dwi_reference_wf``, i.e. *before*
        susceptibility distortion correction. With a fieldmap, this is the
        SDC-unwarped reference (``outputnode.corrected`` of
        ``init_unwarp_wf``).
    dwi_mask
        A 3D, binary mask of the ``dwi_reference`` above. With a fieldmap,
        this is ``outputnode.corrected_mask`` of ``init_unwarp_wf``.
    gradients_rasb
        A *RASb* (RAS+ coordinates, scaled b-values, normalized b-vectors, BIDS-compatible)
        gradient table.

    Returns
    -------
    workflow : :obj:`~niworkflows.engine.workflows.LiterateWorkflow`
        The assembled (not yet executed) workflow.

    See Also
    --------
    * :py:func:`~dmriprep.workflows.dwi.util.init_dwi_reference_wf`
    * :py:func:`~dmriprep.workflows.dwi.outputs.init_dwi_derivatives_wf`
    * :py:func:`~dmriprep.workflows.dwi.outputs.init_reportlets_wf`

    """
    from niworkflows.interfaces.reportlets.registration import (
        SimpleBeforeAfterRPT as SimpleBeforeAfter, )
    from ...interfaces.vectors import CheckGradientTable
    from .util import init_dwi_reference_wf
    from .outputs import init_dwi_derivatives_wf, init_reportlets_wf
    from .eddy import init_eddy_wf

    layout = config.execution.layout

    dwi_file = Path(dwi_file)
    config.loggers.workflow.debug(
        f"Creating DWI preprocessing workflow for <{dwi_file.name}>")

    if has_fieldmap:
        import re
        from sdcflows.fieldmaps import get_identifier

        # Drop the leading "sub-<label>/" so the path is relative to the
        # subject folder. This is presumably the form under which
        # get_identifier indexes fieldmap targets (TODO confirm in sdcflows).
        dwi_rel = re.sub(r"^sub-[a-zA-Z0-9]*/", "",
                         str(dwi_file.relative_to(layout.root)))
        estimator_key = get_identifier(dwi_rel)
        if not estimator_key:
            # No estimator for this run: downgrade to the fieldmap-less
            # workflow. estimator_key is only read after the early return
            # below, which this flag triggers.
            has_fieldmap = False
            config.loggers.workflow.critical(
                f"None of the available B0 fieldmaps are associated to <{dwi_rel}>"
            )

    # Build workflow
    workflow = Workflow(name=_get_wf_name(dwi_file.name))

    inputnode = pe.Node(
        niu.IdentityInterface(fields=[
            # DWI
            "dwi_file",
            "in_bvec",
            "in_bval",
            # From SDCFlows
            "fmap",
            "fmap_ref",
            "fmap_coeff",
            "fmap_mask",
            "fmap_id",
            # From anatomical
            "t1w_preproc",
            "t1w_mask",
            "t1w_dseg",
            "t1w_aseg",
            "t1w_aparc",
            "t1w_tpms",
            "template",
            "anat2std_xfm",
            "std2anat_xfm",
            "subjects_dir",
            "subject_id",
            "t1w2fsnative_xfm",
            "fsnative2t1w_xfm",
        ]),
        name="inputnode",
    )
    # The DWI run and its gradient files are fixed at construction time.
    inputnode.inputs.dwi_file = str(dwi_file.absolute())
    inputnode.inputs.in_bvec = str(layout.get_bvec(dwi_file))
    inputnode.inputs.in_bval = str(layout.get_bval(dwi_file))

    outputnode = pe.Node(
        niu.IdentityInterface(
            fields=["dwi_reference", "dwi_mask", "gradients_rasb"]),
        name="outputnode",
    )

    gradient_table = pe.Node(CheckGradientTable(), name="gradient_table")

    dwi_reference_wf = init_dwi_reference_wf(
        mem_gb=config.DEFAULT_MEMORY_MIN_GB,
        omp_nthreads=config.nipype.omp_nthreads)

    dwi_derivatives_wf = init_dwi_derivatives_wf(
        output_dir=str(config.execution.output_dir))

    # MAIN WORKFLOW STRUCTURE
    # fmt: off
    workflow.connect([
        (inputnode, gradient_table, [("dwi_file", "dwi_file"),
                                     ("in_bvec", "in_bvec"),
                                     ("in_bval", "in_bval")]),
        (inputnode, dwi_reference_wf, [("dwi_file", "inputnode.dwi_file")]),
        (inputnode, dwi_derivatives_wf, [("dwi_file", "inputnode.source_file")
                                         ]),
        # The b0 indices from the gradient check drive the reference estimation.
        (gradient_table, dwi_reference_wf, [("b0_ixs", "inputnode.b0_ixs")]),
        (gradient_table, outputnode, [("out_rasb", "gradients_rasb")]),
        # Derivatives are written from the final outputs, so they follow
        # whichever branch (with/without SDC) fills the outputnode below.
        (outputnode, dwi_derivatives_wf, [
            ("dwi_reference", "inputnode.dwi_ref"),
            ("dwi_mask", "inputnode.dwi_mask"),
        ]),
    ])
    # fmt: on

    if config.workflow.run_reconall:
        from niworkflows.interfaces.nibabel import ApplyMask
        from niworkflows.anat.coregistration import init_bbreg_wf
        from ...utils.misc import sub_prefix as _prefix

        # Mask the T1w
        # NOTE(review): t1w_brain's output is not connected to anything in
        # this function; confirm whether bbr_wf should consume it.
        t1w_brain = pe.Node(ApplyMask(), name="t1w_brain")

        bbr_wf = init_bbreg_wf(
            debug=config.execution.debug,
            epi2t1w_init=config.workflow.dwi2t1w_init,
            omp_nthreads=config.nipype.omp_nthreads,
        )

        ds_report_reg = pe.Node(
            DerivativesDataSink(
                base_directory=str(config.execution.output_dir),
                datatype="figures",
            ),
            name="ds_report_reg",
            run_without_submitting=True,
        )

        # Map bbr_wf's ``fallback`` flag to the reportlet ``desc`` entity:
        # "coreg" when it fell back, "bbregister" otherwise.
        def _bold_reg_suffix(fallback):
            return "coreg" if fallback else "bbregister"

        # fmt: off
        workflow.connect([
            (inputnode, bbr_wf, [
                ("fsnative2t1w_xfm", "inputnode.fsnative2t1w_xfm"),
                (("subject_id", _prefix), "inputnode.subject_id"),
                ("subjects_dir", "inputnode.subjects_dir"),
            ]),
            # T1w Mask
            (inputnode, t1w_brain, [("t1w_preproc", "in_file"),
                                    ("t1w_mask", "in_mask")]),
            (inputnode, ds_report_reg, [("dwi_file", "source_file")]),
            # BBRegister
            (dwi_reference_wf, bbr_wf, [("outputnode.ref_image",
                                         "inputnode.in_file")]),
            (bbr_wf, ds_report_reg, [('outputnode.out_report', 'in_file'),
                                     (('outputnode.fallback',
                                       _bold_reg_suffix), 'desc')]),
        ])
        # fmt: on

    if "eddy" not in config.workflow.ignore:
        # Eddy distortion correction
        eddy_wf = init_eddy_wf(debug=config.execution.debug)
        eddy_wf.inputs.inputnode.metadata = layout.get_metadata(str(dwi_file))

        ds_report_eddy = pe.Node(
            DerivativesDataSink(
                base_directory=str(config.execution.output_dir),
                desc="eddy",
                datatype="figures",
            ),
            name="ds_report_eddy",
            run_without_submitting=True,
        )

        eddy_report = pe.Node(
            SimpleBeforeAfter(
                before_label="Distorted",
                after_label="Eddy Corrected",
            ),
            name="eddy_report",
            mem_gb=0.1,
        )

        # Within this function, the eddy outputs only feed the before/after
        # reportlet; they are not routed to ``outputnode``.
        # fmt:off
        workflow.connect([
            (dwi_reference_wf, eddy_wf, [
                ("outputnode.dwi_file", "inputnode.dwi_file"),
                ("outputnode.dwi_mask", "inputnode.dwi_mask"),
            ]),
            (inputnode, eddy_wf, [("in_bvec", "inputnode.in_bvec"),
                                  ("in_bval", "inputnode.in_bval")]),
            (dwi_reference_wf, eddy_report, [("outputnode.ref_image", "before")
                                             ]),
            (eddy_wf, eddy_report, [('outputnode.eddy_ref_image', 'after')]),
            (dwi_reference_wf, ds_report_eddy, [("outputnode.dwi_file",
                                                 "source_file")]),
            (eddy_report, ds_report_eddy, [("out_report", "in_file")]),
        ])
        # fmt:on

    # REPORTING ############################################################
    # The SDC reportlet slot is only created when a fieldmap will be applied.
    reportlets_wf = init_reportlets_wf(
        str(config.execution.output_dir),
        sdc_report=has_fieldmap,
    )
    # fmt: off
    workflow.connect([
        (inputnode, reportlets_wf, [("dwi_file", "inputnode.source_file")]),
        (dwi_reference_wf, reportlets_wf, [
            ("outputnode.validation_report", "inputnode.validation_report"),
        ]),
        (outputnode, reportlets_wf, [
            ("dwi_reference", "inputnode.dwi_ref"),
            ("dwi_mask", "inputnode.dwi_mask"),
        ]),
    ])
    # fmt: on

    if not has_fieldmap:
        # No SDC: the uncorrected reference and mask are the final outputs.
        # fmt: off
        workflow.connect([
            (dwi_reference_wf, outputnode,
             [("outputnode.ref_image", "dwi_reference"),
              ("outputnode.dwi_mask", "dwi_mask")]),
        ])
        # fmt: on
        return workflow

    # SDC branch: only reached when an estimator was found above.
    from niworkflows.interfaces.utility import KeySelect
    from sdcflows.workflows.apply.registration import init_coeff2epi_wf
    from sdcflows.workflows.apply.correction import init_unwarp_wf

    coeff2epi_wf = init_coeff2epi_wf(
        debug=config.execution.debug,
        omp_nthreads=config.nipype.omp_nthreads,
        write_coeff=True,
    )
    unwarp_wf = init_unwarp_wf(debug=config.execution.debug,
                               omp_nthreads=config.nipype.omp_nthreads)
    unwarp_wf.inputs.inputnode.metadata = layout.get_metadata(str(dwi_file))

    # Select the fieldmap entries whose key (from ``fmap_id``) matches this
    # run's first associated estimator.
    output_select = pe.Node(
        KeySelect(fields=["fmap", "fmap_ref", "fmap_coeff", "fmap_mask"]),
        name="output_select",
        run_without_submitting=True,
    )
    output_select.inputs.key = estimator_key[0]
    if len(estimator_key) > 1:
        # Only the first estimator is used; warn so the choice is visible.
        config.loggers.workflow.warning(
            f"Several fieldmaps <{', '.join(estimator_key)}> are "
            f"'IntendedFor' <{dwi_file}>, using {estimator_key[0]}")

    sdc_report = pe.Node(
        SimpleBeforeAfter(
            before_label="Distorted",
            after_label="Corrected",
        ),
        name="sdc_report",
        mem_gb=0.1,
    )

    # fmt: off
    workflow.connect([
        (inputnode, output_select, [("fmap", "fmap"), ("fmap_ref", "fmap_ref"),
                                    ("fmap_coeff", "fmap_coeff"),
                                    ("fmap_mask", "fmap_mask"),
                                    ("fmap_id", "keys")]),
        # Register the fieldmap reference to the DWI reference ...
        (output_select, coeff2epi_wf, [("fmap_ref", "inputnode.fmap_ref"),
                                       ("fmap_coeff", "inputnode.fmap_coeff"),
                                       ("fmap_mask", "inputnode.fmap_mask")]),
        (dwi_reference_wf, coeff2epi_wf,
         [("outputnode.ref_image", "inputnode.target_ref"),
          ("outputnode.dwi_mask", "inputnode.target_mask")]),
        # ... then unwarp the DWI reference with the aligned coefficients.
        (dwi_reference_wf, unwarp_wf, [("outputnode.ref_image",
                                        "inputnode.distorted")]),
        (coeff2epi_wf, unwarp_wf, [("outputnode.fmap_coeff",
                                    "inputnode.fmap_coeff")]),
        (dwi_reference_wf, sdc_report, [("outputnode.ref_image", "before")]),
        (unwarp_wf, sdc_report, [("outputnode.corrected", "after"),
                                 ("outputnode.corrected_mask", "wm_seg")]),
        (sdc_report, reportlets_wf, [("out_report", "inputnode.sdc_report")]),
        # With SDC, the corrected reference/mask become the final outputs.
        (unwarp_wf, outputnode, [("outputnode.corrected", "dwi_reference"),
                                 ("outputnode.corrected_mask", "dwi_mask")]),
    ])
    # fmt: on

    return workflow