Exemple #1
0
def init_atropos_wf(
    name="atropos_wf",
    use_random_seed=True,
    omp_nthreads=None,
    mem_gb=3.0,
    padding=10,
    in_segmentation_model=tuple(ATROPOS_MODELS["T1w"].values()),
    bspline_fitting_distance=200,
    wm_prior=False,
):
    """
    Create an ANTs' ATROPOS workflow for brain tissue segmentation.

    Re-interprets supersteps 6 and 7 of ``antsBrainExtraction.sh``,
    which refine the mask previously computed with the spatial
    normalization to the template.
    The workflow also executes steps 8 and 9 of the brain extraction
    workflow.

    Workflow Graph
        .. workflow::
            :graph2use: orig
            :simple_form: yes

            from niworkflows.anat.ants import init_atropos_wf
            wf = init_atropos_wf()

    Parameters
    ----------
    name : str, optional
        Workflow name (default: "atropos_wf").
    use_random_seed : bool
        Whether ATROPOS should generate a random seed based on the
        system's clock
    omp_nthreads : int
        Maximum number of threads an individual process may use
    mem_gb : float
        Estimated peak memory consumption of the most hungry nodes
        in the workflow
    padding : int
        Pad images with zeros before processing. The same amount is removed
        again ("de-padding") before the results are handed to the outputs.
    in_segmentation_model : tuple
        A k-means segmentation is run to find gray or white matter
        around the edge of the initial brain mask warped from the
        template.
        This produces a segmentation image with :math:`K` classes,
        ordered by mean intensity in increasing order.
        With this option, you can control :math:`K` and tell the script which
        classes represent CSF, gray and white matter.
        Format (K, csfLabel, gmLabel, wmLabel).
        Examples:
        ``(3,1,2,3)`` for T1 with K=3, CSF=1, GM=2, WM=3 (default),
        ``(3,3,2,1)`` for T2 with K=3, CSF=3, GM=2, WM=1,
        ``(3,1,3,2)`` for FLAIR with K=3, CSF=1 GM=3, WM=2,
        ``(4,4,2,3)`` uses K=4, CSF=4, GM=2, WM=3.
    bspline_fitting_distance : float
        The size of the b-spline mesh grid elements, in mm (default: 200)
    wm_prior : :obj:`bool`
        Whether the WM posterior obtained with ATROPOS should be regularized with a prior
        map (typically, mapped from the template). When ``wm_prior`` is ``True`` the input
        field ``wm_prior`` of the input node must be connected.

    Inputs
    ------
    in_files : list
        The original anatomical images passed in to the brain-extraction workflow.
        These are the images fed to the final INU correction.
    in_corrected : list
        :abbr:`INU (intensity non-uniformity)`-corrected files.
        These are the intensity images fed to ATROPOS.
    in_mask : str
        Brain mask calculated previously.
    wm_prior : :obj:`str`
        Path to the WM prior probability map, aligned with the individual data.

    Outputs
    -------
    out_file : :obj:`str`
        Path of the corrected and brain-extracted result, using the ATROPOS refinement.
    bias_corrected : :obj:`str`
        Path of the INU-corrected (but not brain-extracted) result, using the
        ATROPOS refinement.
    bias_image : :obj:`str`
        Path of the estimated INU bias field, using the ATROPOS refinement.
    out_mask : str
        Refined brain mask
    out_segm : str
        Output segmentation
    out_tpms : str
        Output :abbr:`TPMs (tissue probability maps)`, merged in the order
        CSF, GM, WM.


    """
    wf = pe.Workflow(name)

    # Fields that need the header of the input image copied back onto them
    out_fields = [
        "bias_corrected", "bias_image", "out_mask", "out_segm", "out_tpms"
    ]

    inputnode = pe.Node(
        niu.IdentityInterface(
            fields=["in_files", "in_corrected", "in_mask", "wm_prior"]),
        name="inputnode",
    )
    outputnode = pe.Node(niu.IdentityInterface(fields=["out_file"] +
                                               out_fields),
                         name="outputnode")

    copy_xform = pe.Node(CopyXForm(fields=out_fields),
                         name="copy_xform",
                         run_without_submitting=True)

    # Morphological dilation, radius=2
    dil_brainmask = pe.Node(ImageMath(operation="MD",
                                      op2="2",
                                      copy_header=True),
                            name="dil_brainmask")
    # Get largest connected component
    get_brainmask = pe.Node(
        ImageMath(operation="GetLargestComponent", copy_header=True),
        name="get_brainmask",
    )

    # Run atropos (core node)
    atropos = pe.Node(
        Atropos(
            convergence_threshold=0.0,
            dimension=3,
            initialization="KMeans",
            likelihood_model="Gaussian",
            mrf_radius=[1, 1, 1],
            mrf_smoothing_factor=0.1,
            n_iterations=3,
            number_of_tissue_classes=in_segmentation_model[0],
            save_posteriors=True,
            use_random_seed=use_random_seed,
        ),
        name="01_atropos",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )

    # massage outputs
    pad_segm = pe.Node(
        ImageMath(operation="PadImage", op2=f"{padding}", copy_header=False),
        name="02_pad_segm",
    )
    pad_mask = pe.Node(
        ImageMath(operation="PadImage", op2=f"{padding}", copy_header=False),
        name="03_pad_mask",
    )

    # Split segmentation in binary masks
    sel_labels = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=["out_wm", "out_gm", "out_csf"]),
        name="04_sel_labels",
    )
    # The model is (K, csf, gm, wm); reverse the labels so that they line up
    # with the (wm, gm, csf) order of the output names above.
    sel_labels.inputs.labels = list(reversed(in_segmentation_model[1:]))

    # Select largest components (GM, WM)
    # ImageMath ${DIMENSION} ${EXTRACTION_WM} GetLargestComponent ${EXTRACTION_WM}
    get_wm = pe.Node(ImageMath(operation="GetLargestComponent"),
                     name="05_get_wm")
    get_gm = pe.Node(ImageMath(operation="GetLargestComponent"),
                     name="06_get_gm")

    # Fill holes and calculate intersection
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} FillHoles ${EXTRACTION_GM} 2
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${EXTRACTION_TMP} ${EXTRACTION_GM}
    fill_gm = pe.Node(ImageMath(operation="FillHoles", op2="2"),
                      name="07_fill_gm")
    mult_gm = pe.Node(
        MultiplyImages(dimension=3, output_product_image="08_mult_gm.nii.gz"),
        name="08_mult_gm",
    )

    # MultiplyImages ${DIMENSION} ${EXTRACTION_WM} ${ATROPOS_WM_CLASS_LABEL} ${EXTRACTION_WM}
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} ME ${EXTRACTION_CSF} 10
    relabel_wm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-1],
            output_product_image="09_relabel_wm.nii.gz",
        ),
        name="09_relabel_wm",
    )
    me_csf = pe.Node(ImageMath(operation="ME", op2="10"), name="10_me_csf")

    # ImageMath ${DIMENSION} ${EXTRACTION_GM} addtozero ${EXTRACTION_GM} ${EXTRACTION_TMP}
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${ATROPOS_GM_CLASS_LABEL} ${EXTRACTION_GM}
    # ImageMath ${DIMENSION} ${EXTRACTION_SEGMENTATION} addtozero ${EXTRACTION_WM} ${EXTRACTION_GM}
    add_gm = pe.Node(ImageMath(operation="addtozero"), name="11_add_gm")
    relabel_gm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-2],
            output_product_image="12_relabel_gm.nii.gz",
        ),
        name="12_relabel_gm",
    )
    add_gm_wm = pe.Node(ImageMath(operation="addtozero"), name="13_add_gm_wm")

    # Superstep 7
    # Split segmentation in binary masks
    sel_labels2 = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=["out_gm", "out_wm"]),
        name="14_sel_labels2",
    )
    # Only the (gm, wm) labels are needed here, in the model's own order
    sel_labels2.inputs.labels = in_segmentation_model[2:]

    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} ${EXTRACTION_TMP}
    add_7 = pe.Node(ImageMath(operation="addtozero"), name="15_add_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 2
    me_7 = pe.Node(ImageMath(operation="ME", op2="2"), name="16_me_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} GetLargestComponent ${EXTRACTION_MASK}
    comp_7 = pe.Node(ImageMath(operation="GetLargestComponent"),
                     name="17_comp_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 4
    md_7 = pe.Node(ImageMath(operation="MD", op2="4"), name="18_md_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} FillHoles ${EXTRACTION_MASK} 2
    fill_7 = pe.Node(ImageMath(operation="FillHoles", op2="2"),
                     name="19_fill_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} \
    # ${EXTRACTION_MASK_PRIOR_WARPED}
    add_7_2 = pe.Node(ImageMath(operation="addtozero"), name="20_add_7_2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 5
    md_7_2 = pe.Node(ImageMath(operation="MD", op2="5"), name="21_md_7_2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 5
    me_7_2 = pe.Node(ImageMath(operation="ME", op2="5"), name="22_me_7_2")

    # De-pad: PadImage with a negative amount removes the padding added above
    depad_mask = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                         name="23_depad_mask")
    depad_segm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                         name="24_depad_segm")
    depad_gm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                       name="25_depad_gm")
    depad_wm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                       name="26_depad_wm")
    depad_csf = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                        name="27_depad_csf")

    msk_conform = pe.Node(niu.Function(function=_conform_mask),
                          name="msk_conform")
    merge_tpms = pe.Node(niu.Merge(in_segmentation_model[0]),
                         name="merge_tpms")

    # Pick the WM posterior out of the list of ATROPOS posteriors.
    # Without a prior, the index is known up front from the model (labels are
    # 1-based, list positions are 0-based); with a prior, the index is wired
    # in at run time from the overlap computation further below.
    sel_wm = pe.Node(niu.Select(), name="sel_wm", run_without_submitting=True)
    if not wm_prior:
        sel_wm.inputs.index = in_segmentation_model[-1] - 1

    copy_xform_wm = pe.Node(CopyXForm(fields=["wm_map"]),
                            name="copy_xform_wm",
                            run_without_submitting=True)

    # Refine INU correction
    inu_n4_final = pe.MapNode(
        N4BiasFieldCorrection(
            dimension=3,
            save_bias=True,
            copy_header=True,
            n_iterations=[50] * 5,
            convergence_threshold=1e-7,
            shrink_factor=4,
            bspline_fitting_distance=bspline_fitting_distance,
        ),
        n_procs=omp_nthreads,
        name="inu_n4_final",
        iterfield=["input_image"],
    )

    # Setting the input raises ValueError when the interface rejects it;
    # in that case only warn and carry on without intensity rescaling.
    try:
        inu_n4_final.inputs.rescale_intensities = True
    except ValueError:
        warn(
            "N4BiasFieldCorrection's --rescale-intensities option was added in ANTS 2.1.0 "
            f"({inu_n4_final.interface.version} found.) Please consider upgrading.",
            UserWarning,
        )

    # Apply mask
    apply_mask = pe.MapNode(ApplyMask(),
                            iterfield=["in_file"],
                            name="apply_mask")

    # fmt: off
    wf.connect([
        (inputnode, dil_brainmask, [("in_mask", "op1")]),
        (inputnode, copy_xform, [(("in_files", _pop), "hdr_file")]),
        (inputnode, copy_xform_wm, [(("in_files", _pop), "hdr_file")]),
        (inputnode, pad_mask, [("in_mask", "op1")]),
        (inputnode, atropos, [("in_corrected", "intensity_images")]),
        (inputnode, inu_n4_final, [("in_files", "input_image")]),
        (inputnode, msk_conform, [(("in_files", _pop), "in_reference")]),
        (dil_brainmask, get_brainmask, [("output_image", "op1")]),
        (get_brainmask, atropos, [("output_image", "mask_image")]),
        (atropos, pad_segm, [("classified_image", "op1")]),
        (pad_segm, sel_labels, [("output_image", "in_segm")]),
        (sel_labels, get_wm, [("out_wm", "op1")]),
        (sel_labels, get_gm, [("out_gm", "op1")]),
        (get_gm, fill_gm, [("output_image", "op1")]),
        (get_gm, mult_gm, [("output_image", "first_input")]),
        (fill_gm, mult_gm, [("output_image", "second_input")]),
        (get_wm, relabel_wm, [("output_image", "first_input")]),
        (sel_labels, me_csf, [("out_csf", "op1")]),
        (mult_gm, add_gm, [("output_product_image", "op1")]),
        (me_csf, add_gm, [("output_image", "op2")]),
        (add_gm, relabel_gm, [("output_image", "first_input")]),
        (relabel_wm, add_gm_wm, [("output_product_image", "op1")]),
        (relabel_gm, add_gm_wm, [("output_product_image", "op2")]),
        (add_gm_wm, sel_labels2, [("output_image", "in_segm")]),
        (sel_labels2, add_7, [("out_wm", "op1"), ("out_gm", "op2")]),
        (add_7, me_7, [("output_image", "op1")]),
        (me_7, comp_7, [("output_image", "op1")]),
        (comp_7, md_7, [("output_image", "op1")]),
        (md_7, fill_7, [("output_image", "op1")]),
        (fill_7, add_7_2, [("output_image", "op1")]),
        (pad_mask, add_7_2, [("output_image", "op2")]),
        (add_7_2, md_7_2, [("output_image", "op1")]),
        (md_7_2, me_7_2, [("output_image", "op1")]),
        (me_7_2, depad_mask, [("output_image", "op1")]),
        (add_gm_wm, depad_segm, [("output_image", "op1")]),
        (relabel_wm, depad_wm, [("output_product_image", "op1")]),
        (relabel_gm, depad_gm, [("output_product_image", "op1")]),
        (sel_labels, depad_csf, [("out_csf", "op1")]),
        (depad_csf, merge_tpms, [("output_image", "in1")]),
        (depad_gm, merge_tpms, [("output_image", "in2")]),
        (depad_wm, merge_tpms, [("output_image", "in3")]),
        (depad_mask, msk_conform, [("output_image", "in_mask")]),
        (msk_conform, copy_xform, [("out", "out_mask")]),
        (depad_segm, copy_xform, [("output_image", "out_segm")]),
        (merge_tpms, copy_xform, [("out", "out_tpms")]),
        (atropos, sel_wm, [("posteriors", "inlist")]),
        (sel_wm, copy_xform_wm, [("out", "wm_map")]),
        (copy_xform_wm, inu_n4_final, [("wm_map", "weight_image")]),
        (inu_n4_final, copy_xform, [("output_image", "bias_corrected"),
                                    ("bias_image", "bias_image")]),
        (copy_xform, apply_mask, [("bias_corrected", "in_file"),
                                  ("out_mask", "in_mask")]),
        (apply_mask, outputnode, [("out_file", "out_file")]),
        (copy_xform, outputnode, [
            ("bias_corrected", "bias_corrected"),
            ("bias_image", "bias_image"),
            ("out_mask", "out_mask"),
            ("out_segm", "out_segm"),
            ("out_tpms", "out_tpms"),
        ]),
    ])
    # fmt: on

    if wm_prior:
        from nipype.algorithms.metrics import FuzzyOverlap

        def _argmax(in_dice):
            import numpy as np

            # ``np.argmax`` returns a NumPy integer scalar; hand the ``index``
            # input of ``niu.Select`` a built-in ``int`` instead, so the value
            # does not depend on how NumPy scalars are validated or serialized.
            return int(np.argmax(in_dice))

        match_wm = pe.Node(
            niu.Function(function=_matchlen),
            name="match_wm",
            run_without_submitting=True,
        )
        overlap = pe.Node(FuzzyOverlap(),
                          name="overlap",
                          run_without_submitting=True)

        apply_wm_prior = pe.Node(niu.Function(function=_improd),
                                 name="apply_wm_prior")

        # Re-route the N4 weight image: drop the direct link from the selected
        # WM posterior and insert ``apply_wm_prior`` in between.
        # fmt: off
        wf.disconnect([
            (copy_xform_wm, inu_n4_final, [("wm_map", "weight_image")]),
        ])
        wf.connect([
            (inputnode, apply_wm_prior, [("in_mask", "in_mask"),
                                         ("wm_prior", "op2")]),
            (inputnode, match_wm, [("wm_prior", "value")]),
            (atropos, match_wm, [("posteriors", "reference")]),
            (atropos, overlap, [("posteriors", "in_ref")]),
            (match_wm, overlap, [("out", "in_tst")]),
            # The posterior with the highest per-class overlap is taken as WM
            (overlap, sel_wm, [(("class_fdi", _argmax), "index")]),
            (copy_xform_wm, apply_wm_prior, [("wm_map", "op1")]),
            (apply_wm_prior, inu_n4_final, [("out", "weight_image")]),
        ])
        # fmt: on
    return wf
Exemple #2
0
def init_atropos_wf(name='atropos_wf',
                    use_random_seed=True,
                    omp_nthreads=None,
                    mem_gb=3.0,
                    padding=10,
                    in_segmentation_model=tuple(
                        ATROPOS_MODELS['T1w'].values())):
    """
    Implements supersteps 6 and 7 of ``antsBrainExtraction.sh``,
    which refine the mask previously computed with the spatial
    normalization to the template.

    **Parameters**

        use_random_seed : bool
            Whether ATROPOS should generate a random seed based on the
            system's clock
        omp_nthreads : int
            Maximum number of threads an individual process may use
        mem_gb : float
            Estimated peak memory consumption of the most hungry nodes
            in the workflow
        padding : int
            Pad images with zeros before processing
        in_segmentation_model : tuple
            A k-means segmentation is run to find gray or white matter
            around the edge of the initial brain mask warped from the
            template.
            This produces a segmentation image with :math:`K` classes,
            ordered by mean intensity in increasing order.
            With this option, you can control :math:`K` and tell
            the script which classes represent CSF, gray and white matter.
            Format (K, csfLabel, gmLabel, wmLabel).
            Examples:

              - ``(3,1,2,3)`` for T1 with K=3, CSF=1, GM=2, WM=3 (default)
              - ``(3,3,2,1)`` for T2 with K=3, CSF=3, GM=2, WM=1
              - ``(3,1,3,2)`` for FLAIR with K=3, CSF=1 GM=3, WM=2
              - ``(4,4,2,3)`` uses K=4, CSF=4, GM=2, WM=3

        name : str, optional
            Workflow name (default: atropos_wf)

    **Inputs**

        in_files
            :abbr:`INU (intensity non-uniformity)`-corrected files.
        in_mask
            Brain mask calculated previously
        in_mask_dilated
            Mask handed to ATROPOS as its ``mask_image``. It is not computed
            inside this workflow and must be connected by the caller
            (presumably a dilated version of ``in_mask``).

    **Outputs**

        out_mask
            Refined brain mask
        out_segm
            Output segmentation
        out_tpms
            Output :abbr:`TPMs (tissue probability maps)`, merged in the
            order CSF, GM, WM.

    """
    wf = pe.Workflow(name)

    inputnode = pe.Node(niu.IdentityInterface(
        fields=['in_files', 'in_mask', 'in_mask_dilated']),
                        name='inputnode')
    outputnode = pe.Node(
        niu.IdentityInterface(fields=['out_mask', 'out_segm', 'out_tpms']),
        name='outputnode')

    copy_xform = pe.Node(
        CopyXForm(fields=['out_mask', 'out_segm', 'out_tpms']),
        name='copy_xform',
        run_without_submitting=True,
        mem_gb=2.5)

    # Run atropos (core node)
    atropos = pe.Node(Atropos(
        dimension=3,
        initialization='KMeans',
        number_of_tissue_classes=in_segmentation_model[0],
        n_iterations=3,
        convergence_threshold=0.0,
        mrf_radius=[1, 1, 1],
        mrf_smoothing_factor=0.1,
        likelihood_model='Gaussian',
        use_random_seed=use_random_seed),
                      name='01_atropos',
                      n_procs=omp_nthreads,
                      mem_gb=mem_gb)

    # massage outputs
    pad_segm = pe.Node(ImageMath(operation='PadImage', op2='%d' % padding),
                       name='02_pad_segm')
    pad_mask = pe.Node(ImageMath(operation='PadImage', op2='%d' % padding),
                       name='03_pad_mask')

    # Split segmentation in binary masks
    sel_labels = pe.Node(niu.Function(
        function=_select_labels, output_names=['out_wm', 'out_gm', 'out_csf']),
                         name='04_sel_labels')
    # The model is (K, csf, gm, wm); reverse the labels so that they line up
    # with the (wm, gm, csf) order of the output names above.
    sel_labels.inputs.labels = list(reversed(in_segmentation_model[1:]))

    # Select largest components (GM, WM)
    # ImageMath ${DIMENSION} ${EXTRACTION_WM} GetLargestComponent ${EXTRACTION_WM}
    get_wm = pe.Node(ImageMath(operation='GetLargestComponent'),
                     name='05_get_wm')
    get_gm = pe.Node(ImageMath(operation='GetLargestComponent'),
                     name='06_get_gm')

    # Fill holes and calculate intersection
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} FillHoles ${EXTRACTION_GM} 2
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${EXTRACTION_TMP} ${EXTRACTION_GM}
    fill_gm = pe.Node(ImageMath(operation='FillHoles', op2='2'),
                      name='07_fill_gm')
    mult_gm = pe.Node(MultiplyImages(dimension=3,
                                     output_product_image='08_mult_gm.nii.gz'),
                      name='08_mult_gm')

    # MultiplyImages ${DIMENSION} ${EXTRACTION_WM} ${ATROPOS_WM_CLASS_LABEL} ${EXTRACTION_WM}
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} ME ${EXTRACTION_CSF} 10
    relabel_wm = pe.Node(MultiplyImages(
        dimension=3,
        second_input=in_segmentation_model[-1],
        output_product_image='09_relabel_wm.nii.gz'),
                         name='09_relabel_wm')
    me_csf = pe.Node(ImageMath(operation='ME', op2='10'), name='10_me_csf')

    # ImageMath ${DIMENSION} ${EXTRACTION_GM} addtozero ${EXTRACTION_GM} ${EXTRACTION_TMP}
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${ATROPOS_GM_CLASS_LABEL} ${EXTRACTION_GM}
    # ImageMath ${DIMENSION} ${EXTRACTION_SEGMENTATION} addtozero ${EXTRACTION_WM} ${EXTRACTION_GM}
    add_gm = pe.Node(ImageMath(operation='addtozero'), name='11_add_gm')
    relabel_gm = pe.Node(MultiplyImages(
        dimension=3,
        second_input=in_segmentation_model[-2],
        output_product_image='12_relabel_gm.nii.gz'),
                         name='12_relabel_gm')
    add_gm_wm = pe.Node(ImageMath(operation='addtozero'), name='13_add_gm_wm')

    # Superstep 7
    # Split segmentation in binary masks
    sel_labels2 = pe.Node(niu.Function(function=_select_labels,
                                       output_names=['out_gm', 'out_wm']),
                          name='14_sel_labels2')
    # Only the (gm, wm) labels are needed here. The slice is converted to a
    # list so that the node input keeps the same type now that the default
    # model is an (immutable) tuple rather than a list.
    sel_labels2.inputs.labels = list(in_segmentation_model[2:])

    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} ${EXTRACTION_TMP}
    add_7 = pe.Node(ImageMath(operation='addtozero'), name='15_add_7')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 2
    me_7 = pe.Node(ImageMath(operation='ME', op2='2'), name='16_me_7')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} GetLargestComponent ${EXTRACTION_MASK}
    comp_7 = pe.Node(ImageMath(operation='GetLargestComponent'),
                     name='17_comp_7')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 4
    md_7 = pe.Node(ImageMath(operation='MD', op2='4'), name='18_md_7')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} FillHoles ${EXTRACTION_MASK} 2
    fill_7 = pe.Node(ImageMath(operation='FillHoles', op2='2'),
                     name='19_fill_7')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} \
    # ${EXTRACTION_MASK_PRIOR_WARPED}
    add_7_2 = pe.Node(ImageMath(operation='addtozero'), name='20_add_7_2')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 5
    md_7_2 = pe.Node(ImageMath(operation='MD', op2='5'), name='21_md_7_2')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 5
    me_7_2 = pe.Node(ImageMath(operation='ME', op2='5'), name='22_me_7_2')

    # De-pad: PadImage with a negative amount removes the padding added above
    depad_mask = pe.Node(ImageMath(operation='PadImage', op2='-%d' % padding),
                         name='23_depad_mask')
    depad_segm = pe.Node(ImageMath(operation='PadImage', op2='-%d' % padding),
                         name='24_depad_segm')
    depad_gm = pe.Node(ImageMath(operation='PadImage', op2='-%d' % padding),
                       name='25_depad_gm')
    depad_wm = pe.Node(ImageMath(operation='PadImage', op2='-%d' % padding),
                       name='26_depad_wm')
    depad_csf = pe.Node(ImageMath(operation='PadImage', op2='-%d' % padding),
                        name='27_depad_csf')

    msk_conform = pe.Node(niu.Function(function=_conform_mask),
                          name='msk_conform')
    merge_tpms = pe.Node(niu.Merge(in_segmentation_model[0]),
                         name='merge_tpms')
    wf.connect([
        (inputnode, copy_xform, [(('in_files', _pop), 'hdr_file')]),
        (inputnode, pad_mask, [('in_mask', 'op1')]),
        (inputnode, atropos, [('in_files', 'intensity_images'),
                              ('in_mask_dilated', 'mask_image')]),
        (inputnode, msk_conform, [(('in_files', _pop), 'in_reference')]),
        (atropos, pad_segm, [('classified_image', 'op1')]),
        (pad_segm, sel_labels, [('output_image', 'in_segm')]),
        (sel_labels, get_wm, [('out_wm', 'op1')]),
        (sel_labels, get_gm, [('out_gm', 'op1')]),
        (get_gm, fill_gm, [('output_image', 'op1')]),
        (get_gm, mult_gm, [('output_image', 'first_input')]),
        (fill_gm, mult_gm, [('output_image', 'second_input')]),
        (get_wm, relabel_wm, [('output_image', 'first_input')]),
        (sel_labels, me_csf, [('out_csf', 'op1')]),
        (mult_gm, add_gm, [('output_product_image', 'op1')]),
        (me_csf, add_gm, [('output_image', 'op2')]),
        (add_gm, relabel_gm, [('output_image', 'first_input')]),
        (relabel_wm, add_gm_wm, [('output_product_image', 'op1')]),
        (relabel_gm, add_gm_wm, [('output_product_image', 'op2')]),
        (add_gm_wm, sel_labels2, [('output_image', 'in_segm')]),
        (sel_labels2, add_7, [('out_wm', 'op1'), ('out_gm', 'op2')]),
        (add_7, me_7, [('output_image', 'op1')]),
        (me_7, comp_7, [('output_image', 'op1')]),
        (comp_7, md_7, [('output_image', 'op1')]),
        (md_7, fill_7, [('output_image', 'op1')]),
        (fill_7, add_7_2, [('output_image', 'op1')]),
        (pad_mask, add_7_2, [('output_image', 'op2')]),
        (add_7_2, md_7_2, [('output_image', 'op1')]),
        (md_7_2, me_7_2, [('output_image', 'op1')]),
        (me_7_2, depad_mask, [('output_image', 'op1')]),
        (add_gm_wm, depad_segm, [('output_image', 'op1')]),
        (relabel_wm, depad_wm, [('output_product_image', 'op1')]),
        (relabel_gm, depad_gm, [('output_product_image', 'op1')]),
        (sel_labels, depad_csf, [('out_csf', 'op1')]),
        (depad_csf, merge_tpms, [('output_image', 'in1')]),
        (depad_gm, merge_tpms, [('output_image', 'in2')]),
        (depad_wm, merge_tpms, [('output_image', 'in3')]),
        (depad_mask, msk_conform, [('output_image', 'in_mask')]),
        (msk_conform, copy_xform, [('out', 'out_mask')]),
        (depad_segm, copy_xform, [('output_image', 'out_segm')]),
        (merge_tpms, copy_xform, [('out', 'out_tpms')]),
        (copy_xform, outputnode, [('out_mask', 'out_mask'),
                                  ('out_segm', 'out_segm'),
                                  ('out_tpms', 'out_tpms')]),
    ])
    return wf
Exemple #3
0
def init_atropos_wf(
        name="atropos_wf",
        use_random_seed=True,
        omp_nthreads=None,
        mem_gb=3.0,
        padding=10,
        in_segmentation_model=tuple(ATROPOS_MODELS["T1w"].values()),
):
    """
    Create an ANTs' ATROPOS workflow for brain tissue segmentation.

    Implements supersteps 6 and 7 of ``antsBrainExtraction.sh``,
    which refine the mask previously computed with the spatial
    normalization to the template.

    Workflow Graph
        .. workflow::
            :graph2use: orig
            :simple_form: yes

            from niworkflows.anat.ants import init_atropos_wf
            wf = init_atropos_wf()

    Parameters
    ----------
    use_random_seed : bool
        Whether ATROPOS should generate a random seed based on the
        system's clock
    omp_nthreads : int
        Maximum number of threads an individual process may use
    mem_gb : float
        Estimated peak memory consumption of the most hungry nodes
        in the workflow
    padding : int
        Pad images with zeros before processing
    in_segmentation_model : tuple
        A k-means segmentation is run to find gray or white matter
        around the edge of the initial brain mask warped from the
        template.
        This produces a segmentation image with :math:`$K$` classes,
        ordered by mean intensity in increasing order.
        With this option, you can control  :math:`$K$` and tell the script which
        classes represent CSF, gray and white matter.
        Format (K, csfLabel, gmLabel, wmLabel).
        Examples:
        ``(3,1,2,3)`` for T1 with K=3, CSF=1, GM=2, WM=3 (default),
        ``(3,3,2,1)`` for T2 with K=3, CSF=3, GM=2, WM=1,
        ``(3,1,3,2)`` for FLAIR with K=3, CSF=1 GM=3, WM=2,
        ``(4,4,2,3)`` uses K=4, CSF=4, GM=2, WM=3.
    name : str, optional
        Workflow name (default: "atropos_wf").

    Inputs
    ------
    in_files : list
        :abbr:`INU (intensity non-uniformity)`-corrected files.
    in_mask : str
        Brain mask calculated previously.

    Outputs
    -------
    out_mask : str
        Refined brain mask
    out_segm : str
        Output segmentation
    out_tpms : str
        Output :abbr:`TPMs (tissue probability maps)`


    """
    wf = pe.Workflow(name)

    inputnode = pe.Node(
        niu.IdentityInterface(
            fields=["in_files", "in_mask", "in_mask_dilated"]),
        name="inputnode",
    )
    outputnode = pe.Node(
        niu.IdentityInterface(fields=["out_mask", "out_segm", "out_tpms"]),
        name="outputnode",
    )

    copy_xform = pe.Node(
        CopyXForm(fields=["out_mask", "out_segm", "out_tpms"]),
        name="copy_xform",
        run_without_submitting=True,
    )

    # Run atropos (core node)
    atropos = pe.Node(
        Atropos(
            dimension=3,
            initialization="KMeans",
            number_of_tissue_classes=in_segmentation_model[0],
            n_iterations=3,
            convergence_threshold=0.0,
            mrf_radius=[1, 1, 1],
            mrf_smoothing_factor=0.1,
            likelihood_model="Gaussian",
            use_random_seed=use_random_seed,
        ),
        name="01_atropos",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )

    # massage outputs
    pad_segm = pe.Node(ImageMath(operation="PadImage", op2="%d" % padding),
                       name="02_pad_segm")
    pad_mask = pe.Node(ImageMath(operation="PadImage", op2="%d" % padding),
                       name="03_pad_mask")

    # Split segmentation in binary masks
    sel_labels = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=["out_wm", "out_gm", "out_csf"]),
        name="04_sel_labels",
    )
    sel_labels.inputs.labels = list(reversed(in_segmentation_model[1:]))

    # Select largest components (GM, WM)
    # ImageMath ${DIMENSION} ${EXTRACTION_WM} GetLargestComponent ${EXTRACTION_WM}
    get_wm = pe.Node(ImageMath(operation="GetLargestComponent"),
                     name="05_get_wm")
    get_gm = pe.Node(ImageMath(operation="GetLargestComponent"),
                     name="06_get_gm")

    # Fill holes and calculate intersection
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} FillHoles ${EXTRACTION_GM} 2
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${EXTRACTION_TMP} ${EXTRACTION_GM}
    fill_gm = pe.Node(ImageMath(operation="FillHoles", op2="2"),
                      name="07_fill_gm")
    mult_gm = pe.Node(
        MultiplyImages(dimension=3, output_product_image="08_mult_gm.nii.gz"),
        name="08_mult_gm",
    )

    # MultiplyImages ${DIMENSION} ${EXTRACTION_WM} ${ATROPOS_WM_CLASS_LABEL} ${EXTRACTION_WM}
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} ME ${EXTRACTION_CSF} 10
    relabel_wm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-1],
            output_product_image="09_relabel_wm.nii.gz",
        ),
        name="09_relabel_wm",
    )
    me_csf = pe.Node(ImageMath(operation="ME", op2="10"), name="10_me_csf")

    # ImageMath ${DIMENSION} ${EXTRACTION_GM} addtozero ${EXTRACTION_GM} ${EXTRACTION_TMP}
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${ATROPOS_GM_CLASS_LABEL} ${EXTRACTION_GM}
    # ImageMath ${DIMENSION} ${EXTRACTION_SEGMENTATION} addtozero ${EXTRACTION_WM} ${EXTRACTION_GM}
    add_gm = pe.Node(ImageMath(operation="addtozero"), name="11_add_gm")
    relabel_gm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-2],
            output_product_image="12_relabel_gm.nii.gz",
        ),
        name="12_relabel_gm",
    )
    add_gm_wm = pe.Node(ImageMath(operation="addtozero"), name="13_add_gm_wm")

    # Superstep 7
    # Split segmentation into binary masks
    sel_labels2 = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=["out_gm", "out_wm"]),
        name="14_sel_labels2",
    )
    sel_labels2.inputs.labels = in_segmentation_model[2:]

    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} ${EXTRACTION_TMP}
    add_7 = pe.Node(ImageMath(operation="addtozero"), name="15_add_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 2
    me_7 = pe.Node(ImageMath(operation="ME", op2="2"), name="16_me_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} GetLargestComponent ${EXTRACTION_MASK}
    comp_7 = pe.Node(ImageMath(operation="GetLargestComponent"),
                     name="17_comp_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 4
    md_7 = pe.Node(ImageMath(operation="MD", op2="4"), name="18_md_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} FillHoles ${EXTRACTION_MASK} 2
    fill_7 = pe.Node(ImageMath(operation="FillHoles", op2="2"),
                     name="19_fill_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} \
    # ${EXTRACTION_MASK_PRIOR_WARPED}
    add_7_2 = pe.Node(ImageMath(operation="addtozero"), name="20_add_7_2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 5
    md_7_2 = pe.Node(ImageMath(operation="MD", op2="5"), name="21_md_7_2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 5
    me_7_2 = pe.Node(ImageMath(operation="ME", op2="5"), name="22_me_7_2")

    # De-pad
    depad_mask = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                         name="23_depad_mask")
    depad_segm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                         name="24_depad_segm")
    depad_gm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                       name="25_depad_gm")
    depad_wm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                       name="26_depad_wm")
    depad_csf = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                        name="27_depad_csf")

    msk_conform = pe.Node(niu.Function(function=_conform_mask),
                          name="msk_conform")
    merge_tpms = pe.Node(niu.Merge(in_segmentation_model[0]),
                         name="merge_tpms")
    # fmt: off
    wf.connect([
        (inputnode, copy_xform, [(("in_files", _pop), "hdr_file")]),
        (inputnode, pad_mask, [("in_mask", "op1")]),
        (inputnode, atropos, [
            ("in_files", "intensity_images"),
            ("in_mask_dilated", "mask_image"),
        ]),
        (inputnode, msk_conform, [(("in_files", _pop), "in_reference")]),
        (atropos, pad_segm, [("classified_image", "op1")]),
        (pad_segm, sel_labels, [("output_image", "in_segm")]),
        (sel_labels, get_wm, [("out_wm", "op1")]),
        (sel_labels, get_gm, [("out_gm", "op1")]),
        (get_gm, fill_gm, [("output_image", "op1")]),
        (get_gm, mult_gm, [("output_image", "first_input")]),
        (fill_gm, mult_gm, [("output_image", "second_input")]),
        (get_wm, relabel_wm, [("output_image", "first_input")]),
        (sel_labels, me_csf, [("out_csf", "op1")]),
        (mult_gm, add_gm, [("output_product_image", "op1")]),
        (me_csf, add_gm, [("output_image", "op2")]),
        (add_gm, relabel_gm, [("output_image", "first_input")]),
        (relabel_wm, add_gm_wm, [("output_product_image", "op1")]),
        (relabel_gm, add_gm_wm, [("output_product_image", "op2")]),
        (add_gm_wm, sel_labels2, [("output_image", "in_segm")]),
        (sel_labels2, add_7, [("out_wm", "op1"), ("out_gm", "op2")]),
        (add_7, me_7, [("output_image", "op1")]),
        (me_7, comp_7, [("output_image", "op1")]),
        (comp_7, md_7, [("output_image", "op1")]),
        (md_7, fill_7, [("output_image", "op1")]),
        (fill_7, add_7_2, [("output_image", "op1")]),
        (pad_mask, add_7_2, [("output_image", "op2")]),
        (add_7_2, md_7_2, [("output_image", "op1")]),
        (md_7_2, me_7_2, [("output_image", "op1")]),
        (me_7_2, depad_mask, [("output_image", "op1")]),
        (add_gm_wm, depad_segm, [("output_image", "op1")]),
        (relabel_wm, depad_wm, [("output_product_image", "op1")]),
        (relabel_gm, depad_gm, [("output_product_image", "op1")]),
        (sel_labels, depad_csf, [("out_csf", "op1")]),
        (depad_csf, merge_tpms, [("output_image", "in1")]),
        (depad_gm, merge_tpms, [("output_image", "in2")]),
        (depad_wm, merge_tpms, [("output_image", "in3")]),
        (depad_mask, msk_conform, [("output_image", "in_mask")]),
        (msk_conform, copy_xform, [("out", "out_mask")]),
        (depad_segm, copy_xform, [("output_image", "out_segm")]),
        (merge_tpms, copy_xform, [("out", "out_tpms")]),
        (copy_xform, outputnode, [
            ("out_mask", "out_mask"),
            ("out_segm", "out_segm"),
            ("out_tpms", "out_tpms"),
        ]),
    ])
    # fmt: on
    return wf