def antsRegistrationTemplateBuildSingleIterationWF(iterationPhasePrefix=''):
    """
    Build the nipype workflow for a single template-building iteration.

    The workflow wires the following stages together:

      1. ``BeginANTS`` (MapNode) registers each moving image to
         ``inputspec.fixed_image`` with an Affine stage followed by a SyN
         stage.
      2. ``wimtdeformed`` (MapNode) resamples each moving image onto the
         fixed image with the forward transforms, and ``AvgDeformedImages``
         averages the results.
      3. The forward transforms are split into affine and warp components;
         the affines are averaged by ``AvgAffineTransform`` and the warps by
         ``AvgWarpImages``.  The averaged warp is multiplied by
         ``-GradientStep`` and resampled with the inverted average affine
         (``UpdateTemplateShape``).
      4. ``ReshapeAverageImageWithShapeUpdate`` applies the transform list
         produced by ``MakeTransformListWithGradientWarps`` to the averaged
         image; its output is ``outputspec.template``.
      5. The passive images are resampled with the same per-subject
         transforms, averaged, and pushed through the same shape update;
         the result is ``outputspec.passive_deformed_templates``.

    Parameters::

           iterationPhasePrefix : converted with ``str()`` and used as the
               suffix of the workflow name and as part of the output file
               names written by the nodes.

    Inputs::

           inputspec.ListOfImagesDictionaries :
           inputspec.registrationImageTypes :
           inputspec.interpolationMapping :
           inputspec.fixed_image :

    Outputs::

           outputspec.template :
           outputspec.transforms_list : (declared, but never connected here)
           outputspec.passive_deformed_templates :

    Returns the assembled ``pe.Workflow``.
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name='antsRegistrationTemplateBuildSingleIterationWF_' +
        str(iterationPhasePrefix))
    # Short alias; every connection below is made in the original order.
    connect = TemplateBuildSingleIterationWF.connect

    inputSpec = pe.Node(interface=util.IdentityInterface(fields=[
        'ListOfImagesDictionaries', 'registrationImageTypes',
        'interpolationMapping', 'fixed_image'
    ]),
                        run_without_submitting=True,
                        name='inputspec')
    ## HACK: TODO: Need to move all local functions to a common utility file, or at the top of the file so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: TODO: REMOVE 'transforms_list' it is not used.  That will change all the hashes
    ## HACK: TODO: Need to run all python files through the code beautifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(interface=util.IdentityInterface(
        fields=['template', 'transforms_list', 'passive_deformed_templates']),
                         run_without_submitting=True,
                         name='outputspec')

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    BeginANTS = pe.MapNode(interface=Registration(),
                           name='BeginANTS',
                           iterfield=['moving_image'])
    BeginANTS.inputs.dimension = 3
    BeginANTS.inputs.output_transform_prefix = str(
        iterationPhasePrefix) + '_tfm'
    # Two registration stages; every per-stage list below has one entry per
    # stage, in the order [Affine, SyN].
    BeginANTS.inputs.transforms = ["Affine", "SyN"]
    BeginANTS.inputs.transform_parameters = [[0.9], [0.25, 3.0, 0.0]]
    BeginANTS.inputs.metric = ['Mattes', 'CC']
    BeginANTS.inputs.metric_weight = [1.0, 1.0]
    BeginANTS.inputs.radius_or_number_of_bins = [32, 5]
    BeginANTS.inputs.number_of_iterations = [[1000, 1000, 1000], [50, 35, 15]]
    BeginANTS.inputs.use_histogram_matching = [True, True]
    BeginANTS.inputs.use_estimate_learning_rate_once = [False, False]
    BeginANTS.inputs.shrink_factors = [[3, 2, 1], [3, 2, 1]]
    BeginANTS.inputs.smoothing_sigmas = [[3, 2, 0], [3, 2, 0]]

    GetMovingImagesNode = pe.Node(interface=util.Function(
        function=GetMovingImages,
        input_names=[
            'ListOfImagesDictionaries', 'registrationImageTypes',
            'interpolationMapping'
        ],
        output_names=['moving_images', 'moving_interpolation_type']),
                                  run_without_submitting=True,
                                  name='99_GetMovingImagesNode')
    connect(inputSpec, 'ListOfImagesDictionaries',
            GetMovingImagesNode, 'ListOfImagesDictionaries')
    connect(inputSpec, 'registrationImageTypes',
            GetMovingImagesNode, 'registrationImageTypes')
    connect(inputSpec, 'interpolationMapping',
            GetMovingImagesNode, 'interpolationMapping')

    connect(GetMovingImagesNode, 'moving_images',
            BeginANTS, 'moving_image')
    connect(GetMovingImagesNode, 'moving_interpolation_type',
            BeginANTS, 'interpolation')
    connect(inputSpec, 'fixed_image', BeginANTS, 'fixed_image')

    ## Now warp all the input_images images
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=['transforms', 'invert_transform_flags', 'input_image'],
        name='wimtdeformed')
    wimtdeformed.inputs.interpolation = 'Linear'
    # Set on ``inputs`` so the value reaches the interface; assigning it to
    # the node object itself only creates an unused Python attribute.
    wimtdeformed.inputs.default_value = 0
    connect(BeginANTS, 'forward_transforms', wimtdeformed, 'transforms')
    connect(BeginANTS, 'forward_invert_flags',
            wimtdeformed, 'invert_transform_flags')
    connect(GetMovingImagesNode, 'moving_images',
            wimtdeformed, 'input_image')
    connect(inputSpec, 'fixed_image', wimtdeformed, 'reference_image')

    ##  Shape Update Next =====
    ## Now  Average All input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(),
                                name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(
        iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    connect(wimtdeformed, "output_image", AvgDeformedImages, 'images')

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(interface=AverageAffineTransform(),
                                 name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    # NOTE: the 'Avererage_' spelling is part of the emitted file name and is
    # kept as-is so existing outputs and hashes are not disturbed.
    AvgAffineTransform.inputs.output_affine_transform = 'Avererage_' + str(
        iterationPhasePrefix) + '_Affine.mat'

    SplitAffineAndWarpsNode = pe.Node(interface=util.Function(
        function=SplitAffineAndWarpComponents,
        input_names=['list_of_transforms_lists'],
        output_names=['affine_component_list', 'warp_component_list']),
                                      run_without_submitting=True,
                                      name='99_SplitAffineAndWarpsNode')
    connect(BeginANTS, 'forward_transforms',
            SplitAffineAndWarpsNode, 'list_of_transforms_lists')
    connect(SplitAffineAndWarpsNode, 'affine_component_list',
            AvgAffineTransform, 'transforms')

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(
        iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    connect(SplitAffineAndWarpsNode, 'warp_component_list',
            AvgWarpImages, 'images')

    ## Now scale the averaged warp by the (negated) gradient step
    ## TODO:  For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(interface=MultiplyImages(),
                                    name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # The file name is derived from GradientStep so the two cannot drift
    # apart; for 0.25 this yields the same 'GradientStep0.25_' prefix.
    GradientStepWarpImage.inputs.output_product_image = (
        'GradientStep' + str(GradientStep) + '_' +
        str(iterationPhasePrefix) + '_warp.nii.gz')
    connect(AvgWarpImages, 'output_average_image',
            GradientStepWarpImage, 'first_input')

    ## Now create the new template shape based on the average of all deformed images
    UpdateTemplateShape = pe.Node(interface=ApplyTransforms(),
                                  name='UpdateTemplateShape')
    # Single transform (the averaged affine), applied inverted.
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = 'Linear'
    UpdateTemplateShape.inputs.default_value = 0

    connect(AvgDeformedImages, 'output_average_image',
            UpdateTemplateShape, 'reference_image')
    # ``transforms`` expects a list, so the single averaged affine is wrapped
    # by makeListOfOneElement while being passed along the connection.
    connect([
        (AvgAffineTransform, UpdateTemplateShape,
         [(('affine_transform', makeListOfOneElement), 'transforms')]),
    ])
    connect(GradientStepWarpImage, 'output_product_image',
            UpdateTemplateShape, 'input_image')

    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='99_MakeTransformListWithGradientWarps')
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    connect(AvgAffineTransform, 'affine_transform',
            ApplyInvAverageAndFourTimesGradientStepWarpImage,
            'averageAffineTranform')
    connect(UpdateTemplateShape, 'output_image',
            ApplyInvAverageAndFourTimesGradientStepWarpImage,
            'gradientStepWarp')

    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name='ReshapeAverageImageWithShapeUpdate')
    # Five flags: presumably one per entry of the transform list built by
    # MakeTransformListWithGradientWarps (first entry inverted) -- TODO confirm
    # against that helper.
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = 'Linear'
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = 'ReshapeAverageImageWithShapeUpdate.nii.gz'
    # The averaged image is both the image being resampled and the reference.
    connect(AvgDeformedImages, 'output_average_image',
            ReshapeAverageImageWithShapeUpdate, 'input_image')
    connect(AvgDeformedImages, 'output_average_image',
            ReshapeAverageImageWithShapeUpdate, 'reference_image')
    connect(ApplyInvAverageAndFourTimesGradientStepWarpImage,
            'TransformListWithGradientWarps',
            ReshapeAverageImageWithShapeUpdate, 'transforms')
    connect(ReshapeAverageImageWithShapeUpdate, 'output_image',
            outputSpec, 'template')

    ######
    ######  Process all the passive deformed images in a way similar to the main image used for registration
    ######
    ##############################################
    ## Now warp all the ListOfPassiveImagesDictionaries images
    FlattenTransformAndImagesListNode = pe.Node(
        Function(function=FlattenTransformAndImagesList,
                 input_names=[
                     'ListOfPassiveImagesDictionaries', 'transforms',
                     'invert_transform_flags', 'interpolationMapping'
                 ],
                 output_names=[
                     'flattened_images', 'flattened_transforms',
                     'flattened_invert_transform_flags',
                     'flattened_image_nametypes',
                     'flattened_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")

    GetPassiveImagesNode = pe.Node(interface=util.Function(
        function=GetPassiveImages,
        input_names=['ListOfImagesDictionaries', 'registrationImageTypes'],
        output_names=['ListOfPassiveImagesDictionaries']),
                                   run_without_submitting=True,
                                   name='99_GetPassiveImagesNode')
    connect(inputSpec, 'ListOfImagesDictionaries',
            GetPassiveImagesNode, 'ListOfImagesDictionaries')
    connect(inputSpec, 'registrationImageTypes',
            GetPassiveImagesNode, 'registrationImageTypes')

    connect(GetPassiveImagesNode, 'ListOfPassiveImagesDictionaries',
            FlattenTransformAndImagesListNode,
            'ListOfPassiveImagesDictionaries')
    connect(inputSpec, 'interpolationMapping',
            FlattenTransformAndImagesListNode, 'interpolationMapping')
    connect(BeginANTS, 'forward_transforms',
            FlattenTransformAndImagesListNode, 'transforms')
    connect(BeginANTS, 'forward_invert_flags',
            FlattenTransformAndImagesListNode, 'invert_transform_flags')

    wimtPassivedeformed = pe.MapNode(interface=ApplyTransforms(),
                                     iterfield=[
                                         'transforms',
                                         'invert_transform_flags',
                                         'input_image', 'interpolation'
                                     ],
                                     name='wimtPassivedeformed')
    wimtPassivedeformed.inputs.default_value = 0
    connect(AvgDeformedImages, 'output_average_image',
            wimtPassivedeformed, 'reference_image')
    connect(FlattenTransformAndImagesListNode, 'flattened_interpolation_type',
            wimtPassivedeformed, 'interpolation')
    connect(FlattenTransformAndImagesListNode, 'flattened_images',
            wimtPassivedeformed, 'input_image')
    connect(FlattenTransformAndImagesListNode, 'flattened_transforms',
            wimtPassivedeformed, 'transforms')
    connect(FlattenTransformAndImagesListNode,
            'flattened_invert_transform_flags',
            wimtPassivedeformed, 'invert_transform_flags')

    RenestDeformedPassiveImagesNode = pe.Node(
        Function(function=RenestDeformedPassiveImages,
                 input_names=[
                     'deformedPassiveImages', 'flattened_image_nametypes',
                     'interpolationMapping'
                 ],
                 output_names=[
                     'nested_imagetype_list', 'outputAverageImageName_list',
                     'image_type_list', 'nested_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    connect(inputSpec, 'interpolationMapping',
            RenestDeformedPassiveImagesNode, 'interpolationMapping')
    connect(wimtPassivedeformed, 'output_image',
            RenestDeformedPassiveImagesNode, 'deformedPassiveImages')
    connect(FlattenTransformAndImagesListNode, 'flattened_image_nametypes',
            RenestDeformedPassiveImagesNode, 'flattened_image_nametypes')

    ## Now  Average All passive input_images deformed images together to create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    # Unlike the registration-image average above, this one is not normalized.
    AvgDeformedPassiveImages.inputs.normalize = False
    connect(RenestDeformedPassiveImagesNode, "nested_imagetype_list",
            AvgDeformedPassiveImages, 'images')
    connect(RenestDeformedPassiveImagesNode, "outputAverageImageName_list",
            AvgDeformedPassiveImages, 'output_average_image')

    ## -- TODO:  Now need to reshape all the passive images as well
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'input_image', 'reference_image', 'output_image', 'interpolation'
        ],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0
    connect(RenestDeformedPassiveImagesNode, 'nested_interpolation_type',
            ReshapeAveragePassiveImageWithShapeUpdate, 'interpolation')
    connect(RenestDeformedPassiveImagesNode, 'outputAverageImageName_list',
            ReshapeAveragePassiveImageWithShapeUpdate, 'output_image')
    connect(AvgDeformedPassiveImages, 'output_average_image',
            ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    connect(AvgDeformedPassiveImages, 'output_average_image',
            ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    connect(ApplyInvAverageAndFourTimesGradientStepWarpImage,
            'TransformListWithGradientWarps',
            ReshapeAveragePassiveImageWithShapeUpdate, 'transforms')
    connect(ReshapeAveragePassiveImageWithShapeUpdate, 'output_image',
            outputSpec, 'passive_deformed_templates')

    return TemplateBuildSingleIterationWF
Exemplo n.º 2
0
Arquivo: ants.py Projeto: gkiar/C-PAC
def init_atropos_wf(name='atropos_wf',
                    use_random_seed=True,
                    omp_nthreads=None,
                    mem_gb=3.0,
                    padding=10,
                    in_segmentation_model=list(
                        ATROPOS_MODELS['T1w'].values())):
    """
    Assemble the brain-mask refinement workflow built around Atropos.

    This mirrors supersteps 6 and 7 of ``antsBrainExtraction.sh``: the mask
    obtained earlier from the spatial normalization to the template is
    refined using a tissue segmentation.

    **Parameters**
        name : str, optional
            Workflow name (default: atropos_wf)
        use_random_seed : bool
            Whether ATROPOS should generate a random seed based on the
            system's clock
        omp_nthreads : int
            Maximum number of threads an individual process may use
            (passed as ``n_procs`` of the Atropos node)
        mem_gb : float
            Estimated peak memory consumption of the most hungry nodes
            in the workflow (passed as ``mem_gb`` of the Atropos node)
        padding : int
            Pad images with zeros before processing; the same amount is
            removed again before the outputs are collected
        in_segmentation_model : sequence
            A k-means segmentation is run to find gray or white matter
            around the edge of the initial brain mask warped from the
            template.
            This produces a segmentation image with :math:`$K$` classes,
            ordered by mean intensity in increasing order.
            With this option, you can control  :math:`$K$` and tell
            the script which classes represent CSF, gray and white matter.
            Format (K, csfLabel, gmLabel, wmLabel).
            Examples:
              - ``(3,1,2,3)`` for T1 with K=3, CSF=1, GM=2, WM=3 (default)
              - ``(3,3,2,1)`` for T2 with K=3, CSF=3, GM=2, WM=1
              - ``(3,1,3,2)`` for FLAIR with K=3, CSF=1 GM=3, WM=2
              - ``(4,4,2,3)`` uses K=4, CSF=4, GM=2, WM=3
            The default is the list of values of ``ATROPOS_MODELS['T1w']``.
    **Inputs**
        in_files
            :abbr:`INU (intensity non-uniformity)`-corrected files.
        in_mask
            Brain mask calculated previously
        in_mask_dilated
            Mask handed to Atropos as its ``mask_image`` (presumably a
            dilated version of ``in_mask`` -- confirm with the caller)
    **Outputs**
        out_mask
            Refined brain mask
        out_segm
            Output segmentation
        out_tpms
            Output :abbr:`TPMs (tissue probability maps)`
    """
    workflow = pe.Workflow(name)

    # Pieces of the segmentation model that are reused further down.
    n_classes = in_segmentation_model[0]
    gm_label = in_segmentation_model[-2]
    wm_label = in_segmentation_model[-1]

    def _image_math(node_name, operation, **extra):
        # One ImageMath node; ``extra`` carries ``op2`` when the operation
        # takes a second operand.
        return pe.Node(ImageMath(operation=operation, **extra),
                       name=node_name)

    in_node = pe.Node(
        niu.IdentityInterface(
            fields=['in_files', 'in_mask', 'in_mask_dilated']),
        name='inputnode')
    out_node = pe.Node(
        niu.IdentityInterface(fields=['out_mask', 'out_segm', 'out_tpms']),
        name='outputnode')

    xform_copier = pe.Node(
        CopyXForm(fields=['out_mask', 'out_segm', 'out_tpms']),
        name='copy_xform',
        run_without_submitting=True,
        mem_gb=2.5)

    # Core node: the Atropos segmentation itself
    atropos_iface = Atropos(
        dimension=3,
        initialization='KMeans',
        number_of_tissue_classes=n_classes,
        n_iterations=3,
        convergence_threshold=0.0,
        mrf_radius=[1, 1, 1],
        mrf_smoothing_factor=0.1,
        likelihood_model='Gaussian',
        use_random_seed=use_random_seed)
    segment = pe.Node(atropos_iface,
                      name='01_atropos',
                      n_procs=omp_nthreads,
                      mem_gb=mem_gb)

    # Zero-pad the segmentation and the incoming mask
    pad_width = '%d' % padding
    padded_segm = _image_math('02_pad_segm', 'PadImage', op2=pad_width)
    padded_mask = _image_math('03_pad_mask', 'PadImage', op2=pad_width)

    # Break the segmentation up into binary masks; the labels are handed
    # over in reverse model order to line up with (wm, gm, csf).
    split_labels = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=['out_wm', 'out_gm', 'out_csf']),
        name='04_sel_labels')
    split_labels.inputs.labels = list(reversed(in_segmentation_model[1:]))

    # Keep only the largest connected component of WM and of GM
    # ImageMath ${DIMENSION} ${EXTRACTION_WM} GetLargestComponent ${EXTRACTION_WM}
    largest_wm = _image_math('05_get_wm', 'GetLargestComponent')
    largest_gm = _image_math('06_get_gm', 'GetLargestComponent')

    # Fill holes in GM and intersect with the unfilled component
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} FillHoles ${EXTRACTION_GM} 2
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${EXTRACTION_TMP} ${EXTRACTION_GM}
    filled_gm = _image_math('07_fill_gm', 'FillHoles', op2='2')
    gm_product = pe.Node(
        MultiplyImages(dimension=3,
                       output_product_image='08_mult_gm.nii.gz'),
        name='08_mult_gm')

    # MultiplyImages ${DIMENSION} ${EXTRACTION_WM} ${ATROPOS_WM_CLASS_LABEL} ${EXTRACTION_WM}
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} ME ${EXTRACTION_CSF} 10
    wm_relabel = pe.Node(
        MultiplyImages(dimension=3,
                       second_input=wm_label,
                       output_product_image='09_relabel_wm.nii.gz'),
        name='09_relabel_wm')
    eroded_csf = _image_math('10_me_csf', 'ME', op2='10')

    # ImageMath ${DIMENSION} ${EXTRACTION_GM} addtozero ${EXTRACTION_GM} ${EXTRACTION_TMP}
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${ATROPOS_GM_CLASS_LABEL} ${EXTRACTION_GM}
    # ImageMath ${DIMENSION} ${EXTRACTION_SEGMENTATION} addtozero ${EXTRACTION_WM} ${EXTRACTION_GM}
    gm_plus_csf = _image_math('11_add_gm', 'addtozero')
    gm_relabel = pe.Node(
        MultiplyImages(dimension=3,
                       second_input=gm_label,
                       output_product_image='12_relabel_gm.nii.gz'),
        name='12_relabel_gm')
    gm_plus_wm = _image_math('13_add_gm_wm', 'addtozero')

    # ---- Superstep 7 ----
    # Split the combined segmentation again, this time GM and WM only
    split_labels2 = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=['out_gm', 'out_wm']),
        name='14_sel_labels2')
    split_labels2.inputs.labels = in_segmentation_model[2:]

    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} ${EXTRACTION_TMP}
    s7_add = _image_math('15_add_7', 'addtozero')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 2
    s7_erode = _image_math('16_me_7', 'ME', op2='2')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} GetLargestComponent ${EXTRACTION_MASK}
    s7_largest = _image_math('17_comp_7', 'GetLargestComponent')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 4
    s7_dilate = _image_math('18_md_7', 'MD', op2='4')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} FillHoles ${EXTRACTION_MASK} 2
    s7_fill = _image_math('19_fill_7', 'FillHoles', op2='2')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK}
    #   ${EXTRACTION_MASK_PRIOR_WARPED}
    s7_add_prior = _image_math('20_add_7_2', 'addtozero')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 5
    s7_dilate2 = _image_math('21_md_7_2', 'MD', op2='5')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 5
    s7_erode2 = _image_math('22_me_7_2', 'ME', op2='5')

    # Undo the padding: PadImage with a negative amount
    crop_width = '-%d' % padding
    cropped_mask = _image_math('23_depad_mask', 'PadImage', op2=crop_width)
    cropped_segm = _image_math('24_depad_segm', 'PadImage', op2=crop_width)
    cropped_gm = _image_math('25_depad_gm', 'PadImage', op2=crop_width)
    cropped_wm = _image_math('26_depad_wm', 'PadImage', op2=crop_width)
    cropped_csf = _image_math('27_depad_csf', 'PadImage', op2=crop_width)

    conform_mask = pe.Node(niu.Function(function=_conform_mask),
                           name='msk_conform')
    tpm_merger = pe.Node(niu.Merge(n_classes), name='merge_tpms')

    # Wiring; the entries are kept in their original order.
    workflow.connect([
        # Inputs
        (in_node, xform_copier, [(('in_files', _pop), 'hdr_file')]),
        (in_node, padded_mask, [('in_mask', 'op1')]),
        (in_node, segment, [('in_files', 'intensity_images'),
                            ('in_mask_dilated', 'mask_image')]),
        (in_node, conform_mask, [(('in_files', _pop), 'in_reference')]),
        # Superstep 6
        (segment, padded_segm, [('classified_image', 'op1')]),
        (padded_segm, split_labels, [('output_image', 'in_segm')]),
        (split_labels, largest_wm, [('out_wm', 'op1')]),
        (split_labels, largest_gm, [('out_gm', 'op1')]),
        (largest_gm, filled_gm, [('output_image', 'op1')]),
        (largest_gm, gm_product, [('output_image', 'first_input')]),
        (filled_gm, gm_product, [('output_image', 'second_input')]),
        (largest_wm, wm_relabel, [('output_image', 'first_input')]),
        (split_labels, eroded_csf, [('out_csf', 'op1')]),
        (gm_product, gm_plus_csf, [('output_product_image', 'op1')]),
        (eroded_csf, gm_plus_csf, [('output_image', 'op2')]),
        (gm_plus_csf, gm_relabel, [('output_image', 'first_input')]),
        (wm_relabel, gm_plus_wm, [('output_product_image', 'op1')]),
        (gm_relabel, gm_plus_wm, [('output_product_image', 'op2')]),
        # Superstep 7
        (gm_plus_wm, split_labels2, [('output_image', 'in_segm')]),
        (split_labels2, s7_add, [('out_wm', 'op1'), ('out_gm', 'op2')]),
        (s7_add, s7_erode, [('output_image', 'op1')]),
        (s7_erode, s7_largest, [('output_image', 'op1')]),
        (s7_largest, s7_dilate, [('output_image', 'op1')]),
        (s7_dilate, s7_fill, [('output_image', 'op1')]),
        (s7_fill, s7_add_prior, [('output_image', 'op1')]),
        (padded_mask, s7_add_prior, [('output_image', 'op2')]),
        (s7_add_prior, s7_dilate2, [('output_image', 'op1')]),
        (s7_dilate2, s7_erode2, [('output_image', 'op1')]),
        # De-pad and collect outputs
        (s7_erode2, cropped_mask, [('output_image', 'op1')]),
        (gm_plus_wm, cropped_segm, [('output_image', 'op1')]),
        (wm_relabel, cropped_wm, [('output_product_image', 'op1')]),
        (gm_relabel, cropped_gm, [('output_product_image', 'op1')]),
        (split_labels, cropped_csf, [('out_csf', 'op1')]),
        (cropped_csf, tpm_merger, [('output_image', 'in1')]),
        (cropped_gm, tpm_merger, [('output_image', 'in2')]),
        (cropped_wm, tpm_merger, [('output_image', 'in3')]),
        (cropped_mask, conform_mask, [('output_image', 'in_mask')]),
        (conform_mask, xform_copier, [('out', 'out_mask')]),
        (cropped_segm, xform_copier, [('output_image', 'out_segm')]),
        (tpm_merger, xform_copier, [('out', 'out_tpms')]),
        (xform_copier, out_node, [('out_mask', 'out_mask'),
                                  ('out_segm', 'out_segm'),
                                  ('out_tpms', 'out_tpms')]),
    ])
    return workflow
Exemplo n.º 3
0
def BAWantsRegistrationTemplateBuildSingleIterationWF(iterationPhasePrefix,
                                                      CLUSTER_QUEUE,
                                                      CLUSTER_QUEUE_LONG):
    """
    Build the nipype workflow for a single template-building iteration.

    Every moving image is registered to ``fixed_image`` (``BeginANTS``), the
    moving images are resampled with the resulting composite transforms and
    averaged, and the average is reshaped with a transform list built from the
    averaged affine and the scaled, averaged warp field.  The "passive" images
    (as returned by ``GetPassiveImages``) are resampled with the same
    transforms, averaged per image type and reshaped the same way.

    Parameters::

           iterationPhasePrefix : converted with ``str()`` and used in the
                                  workflow name and in output file names.
           CLUSTER_QUEUE :        not referenced by this function; kept for
                                  interface compatibility with callers.
           CLUSTER_QUEUE_LONG :   not referenced by this function; kept for
                                  interface compatibility with callers.

    Inputs::

           inputspec.ListOfImagesDictionaries :
           inputspec.registrationImageTypes :
           inputspec.interpolationMapping :
           inputspec.fixed_image :

    Outputs::

           outputspec.template :
           outputspec.transforms_list : (declared, but never connected here)
           outputspec.passive_deformed_templates :

    Returns the assembled ``pe.Workflow``.
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name='antsRegistrationTemplateBuildSingleIterationWF_' +
        str(iterationPhasePrefix))

    inputSpec = pe.Node(
        interface=util.IdentityInterface(fields=[
            'ListOfImagesDictionaries',
            'registrationImageTypes',
            # 'maskRegistrationImageType',
            'interpolationMapping',
            'fixed_image'
        ]),
        run_without_submitting=True,
        name='inputspec')
    ## HACK: TODO: We need to have the AVG_AIR.nii.gz be warped with a default voxel value of 1.0
    ## HACK: TODO: Need to move all local functions to a common utility file, or at the top of the file so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: TODO: REMOVE 'transforms_list' it is not used.  That will change all the hashes
    ## HACK: TODO: Need to run all python files through the code beautifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(interface=util.IdentityInterface(
        fields=['template', 'transforms_list', 'passive_deformed_templates']),
                         run_without_submitting=True,
                         name='outputspec')

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    BeginANTS = pe.MapNode(interface=Registration(),
                           name='BeginANTS',
                           iterfield=['moving_image'])
    # SEE template.py many_cpu_BeginANTS_options_dictionary = {'qsub_args': modify_qsub_args(CLUSTER_QUEUE,4,2,8), 'overwrite': True}
    ## This is set in the template.py file BeginANTS.plugin_args = BeginANTS_cpu_sge_options_dictionary
    # All registration parameters are applied by the shared helper.
    CommonANTsRegistrationSettings(
        antsRegistrationNode=BeginANTS,
        registrationTypeDescription="SixStageAntsRegistrationT1Only",
        output_transform_prefix=str(iterationPhasePrefix) + '_tfm',
        output_warped_image='atlas2subject.nii.gz',
        output_inverse_warped_image='subject2atlas.nii.gz',
        save_state='SavedantsRegistrationNodeSyNState.h5',
        invert_initial_moving_transform=False,
        initial_moving_transform=None)

    # Select the images (and their interpolation types) that drive the registration.
    GetMovingImagesNode = pe.Node(interface=util.Function(
        function=GetMovingImages,
        input_names=[
            'ListOfImagesDictionaries', 'registrationImageTypes',
            'interpolationMapping'
        ],
        output_names=['moving_images', 'moving_interpolation_type']),
                                  run_without_submitting=True,
                                  name='99_GetMovingImagesNode')
    TemplateBuildSingleIterationWF.connect(inputSpec,
                                           'ListOfImagesDictionaries',
                                           GetMovingImagesNode,
                                           'ListOfImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'registrationImageTypes',
                                           GetMovingImagesNode,
                                           'registrationImageTypes')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           GetMovingImagesNode,
                                           'interpolationMapping')

    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_images', BeginANTS,
                                           'moving_image')
    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_interpolation_type',
                                           BeginANTS, 'interpolation')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image', BeginANTS,
                                           'fixed_image')

    ## Now warp all the input_images images
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=['transforms', 'input_image'],
        # iterfield=['transforms', 'invert_transform_flags', 'input_image'],
        name='wimtdeformed')
    wimtdeformed.inputs.interpolation = 'Linear'
    # Set on ``inputs``: assigning ``default_value`` directly on the node only
    # creates an unused attribute and never reaches the interface.
    wimtdeformed.inputs.default_value = 0.0
    # HACK: Should try using forward_composite_transform
    ## PREVIOUS: BeginANTS 'forward_transform' was connected to wimtdeformed 'transforms'
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           wimtdeformed, 'transforms')
    ## PREVIOUS: BeginANTS 'forward_invert_flags' was connected to wimtdeformed 'invert_transform_flags'
    ## NOTE: forward_invert_flags:: List of flags corresponding to the forward transforms
    # wimtdeformed.inputs.invert_transform_flags = [False,False,False,False,False]
    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_images', wimtdeformed,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image',
                                           wimtdeformed, 'reference_image')

    ##  Shape Update Next =====
    ## Now  Average All input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(),
                                name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(
        iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(wimtdeformed, "output_image",
                                           AvgDeformedImages, 'images')

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(interface=AverageAffineTransform(),
                                 name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    # NOTE: the 'Avererage_' spelling is part of the emitted file name; left
    # unchanged so existing outputs keep their names.
    AvgAffineTransform.inputs.output_affine_transform = 'Avererage_' + str(
        iterationPhasePrefix) + '_Affine.h5'

    # Split each composite transform into its affine and warp components so
    # that the two kinds can be averaged separately.
    SplitCompositeTransform = pe.MapNode(interface=util.Function(
        function=SplitCompositeToComponentTransforms,
        input_names=['transformFilename'],
        output_names=['affine_component_list', 'warp_component_list']),
                                         iterfield=['transformFilename'],
                                         run_without_submitting=True,
                                         name='99_SplitCompositeTransform')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           SplitCompositeTransform,
                                           'transformFilename')
    ## PREVIOUS: BeginANTS 'forward_transforms' was connected to SplitCompositeTransform 'transformFilename'
    TemplateBuildSingleIterationWF.connect(SplitCompositeTransform,
                                           'affine_component_list',
                                           AvgAffineTransform, 'transforms')

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(
        iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(SplitCompositeTransform,
                                           'warp_component_list',
                                           AvgWarpImages, 'images')

    ## Scale the averaged warp field by the negated gradient step
    ## TODO:  For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(interface=MultiplyImages(),
                                    name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # The file name embeds the literal "0.25"; it is not derived from GradientStep.
    GradientStepWarpImage.inputs.output_product_image = 'GradientStep0.25_' + str(
        iterationPhasePrefix) + '_warp.nii.gz'
    TemplateBuildSingleIterationWF.connect(AvgWarpImages,
                                           'output_average_image',
                                           GradientStepWarpImage,
                                           'first_input')

    ## Now create the new template shape based on the average of all deformed images
    UpdateTemplateShape = pe.Node(interface=ApplyTransforms(),
                                  name='UpdateTemplateShape')
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = 'Linear'
    UpdateTemplateShape.inputs.default_value = 0.0

    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           UpdateTemplateShape,
                                           'reference_image')
    # 'transforms' expects a list, so the single averaged affine is wrapped.
    TemplateBuildSingleIterationWF.connect([
        (AvgAffineTransform, UpdateTemplateShape,
         [(('affine_transform', makeListOfOneElement), 'transforms')]),
    ])
    TemplateBuildSingleIterationWF.connect(GradientStepWarpImage,
                                           'output_product_image',
                                           UpdateTemplateShape, 'input_image')

    # Build the transform list used to reshape the averaged images.
    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='99_MakeTransformListWithGradientWarps')
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    TemplateBuildSingleIterationWF.connect(
        AvgAffineTransform, 'affine_transform',
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'averageAffineTranform')
    TemplateBuildSingleIterationWF.connect(
        UpdateTemplateShape, 'output_image',
        ApplyInvAverageAndFourTimesGradientStepWarpImage, 'gradientStepWarp')

    # Reshape the averaged image; it serves as both input and reference image.
    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name='ReshapeAverageImageWithShapeUpdate')
    # Five flags: presumably one per entry of TransformListWithGradientWarps
    # (affine first) -- TODO confirm against MakeTransformListWithGradientWarps.
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = 'Linear'
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0.0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = 'ReshapeAverageImageWithShapeUpdate.nii.gz'
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps', ReshapeAverageImageWithShapeUpdate,
        'transforms')
    TemplateBuildSingleIterationWF.connect(ReshapeAverageImageWithShapeUpdate,
                                           'output_image', outputSpec,
                                           'template')

    ######
    ######  Process all the passive deformed images in a way similar to the main image used for registration
    ######
    ##############################################
    ## Now warp all the ListOfPassiveImagesDictionaries images
    FlattenTransformAndImagesListNode = pe.Node(
        Function(function=FlattenTransformAndImagesList,
                 input_names=[
                     'ListOfPassiveImagesDictionaries', 'transforms',
                     'interpolationMapping', 'invert_transform_flags'
                 ],
                 output_names=[
                     'flattened_images', 'flattened_transforms',
                     'flattened_invert_transform_flags',
                     'flattened_image_nametypes',
                     'flattened_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")

    GetPassiveImagesNode = pe.Node(interface=util.Function(
        function=GetPassiveImages,
        input_names=['ListOfImagesDictionaries', 'registrationImageTypes'],
        output_names=['ListOfPassiveImagesDictionaries']),
                                   run_without_submitting=True,
                                   name='99_GetPassiveImagesNode')
    TemplateBuildSingleIterationWF.connect(inputSpec,
                                           'ListOfImagesDictionaries',
                                           GetPassiveImagesNode,
                                           'ListOfImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'registrationImageTypes',
                                           GetPassiveImagesNode,
                                           'registrationImageTypes')

    TemplateBuildSingleIterationWF.connect(GetPassiveImagesNode,
                                           'ListOfPassiveImagesDictionaries',
                                           FlattenTransformAndImagesListNode,
                                           'ListOfPassiveImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           FlattenTransformAndImagesListNode,
                                           'interpolationMapping')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           FlattenTransformAndImagesListNode,
                                           'transforms')
    # NOTE: the 'invert_transform_flags' input of the flatten node is left
    # unconnected and unset in this function.
    ## FlattenTransformAndImagesListNode.inputs.invert_transform_flags = [False,False,False,False,False,False]
    ## TODO: Please check if invert_transform_flags has a fixed number.
    ## PREVIOUS: BeginANTS 'forward_invert_flags' was connected to FlattenTransformAndImagesListNode 'invert_transform_flags'
    wimtPassivedeformed = pe.MapNode(interface=ApplyTransforms(),
                                     iterfield=[
                                         'transforms',
                                         'invert_transform_flags',
                                         'input_image', 'interpolation'
                                     ],
                                     name='wimtPassivedeformed')
    wimtPassivedeformed.inputs.default_value = 0.0
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           wimtPassivedeformed,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_interpolation_type',
                                           wimtPassivedeformed,
                                           'interpolation')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_images',
                                           wimtPassivedeformed, 'input_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_transforms',
                                           wimtPassivedeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_invert_transform_flags',
                                           wimtPassivedeformed,
                                           'invert_transform_flags')

    # Regroup the flat list of deformed passive images into per-type lists.
    RenestDeformedPassiveImagesNode = pe.Node(
        Function(function=RenestDeformedPassiveImages,
                 input_names=[
                     'deformedPassiveImages', 'flattened_image_nametypes',
                     'interpolationMapping'
                 ],
                 output_names=[
                     'nested_imagetype_list', 'outputAverageImageName_list',
                     'image_type_list', 'nested_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           RenestDeformedPassiveImagesNode,
                                           'interpolationMapping')
    TemplateBuildSingleIterationWF.connect(wimtPassivedeformed, 'output_image',
                                           RenestDeformedPassiveImagesNode,
                                           'deformedPassiveImages')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_image_nametypes',
                                           RenestDeformedPassiveImagesNode,
                                           'flattened_image_nametypes')
    ## Now  Average All passive input_images deformed images together to create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    # Unlike the registration-image average above, normalization is disabled here.
    AvgDeformedPassiveImages.inputs.normalize = False
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "nested_imagetype_list",
                                           AvgDeformedPassiveImages, 'images')
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "outputAverageImageName_list",
                                           AvgDeformedPassiveImages,
                                           'output_average_image')

    ## -- TODO:  Now need to reshape all the passive images as well
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'input_image', 'reference_image', 'output_image', 'interpolation'
        ],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0.0
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, 'nested_interpolation_type',
        ReshapeAveragePassiveImageWithShapeUpdate, 'interpolation')
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, 'outputAverageImageName_list',
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps',
        ReshapeAveragePassiveImageWithShapeUpdate, 'transforms')
    TemplateBuildSingleIterationWF.connect(
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image', outputSpec,
        'passive_deformed_templates')

    return TemplateBuildSingleIterationWF


# Example no. 4
def init_atropos_wf(
    name="atropos_wf",
    use_random_seed=True,
    omp_nthreads=None,
    mem_gb=3.0,
    padding=10,
    in_segmentation_model=tuple(ATROPOS_MODELS["T1w"].values()),
    bspline_fitting_distance=200,
    wm_prior=False,
):
    """
    Create an ANTs' ATROPOS workflow for brain tissue segmentation.

    Re-interprets supersteps 6 and 7 of ``antsBrainExtraction.sh``,
    which refine the mask previously computed with the spatial
    normalization to the template.
    The workflow also executes steps 8 and 9 of the brain extraction
    workflow.

    Workflow Graph
        .. workflow::
            :graph2use: orig
            :simple_form: yes

            from niworkflows.anat.ants import init_atropos_wf
            wf = init_atropos_wf()

    Parameters
    ----------
    name : str, optional
        Workflow name (default: "atropos_wf").
    use_random_seed : bool
        Whether ATROPOS should generate a random seed based on the
        system's clock
    omp_nthreads : int
        Maximum number of threads an individual process may use
    mem_gb : float
        Estimated peak memory consumption of the most hungry nodes
        in the workflow
    padding : int
        Pad images with zeros before processing
    in_segmentation_model : tuple
        A k-means segmentation is run to find gray or white matter
        around the edge of the initial brain mask warped from the
        template.
        This produces a segmentation image with :math:`$K$` classes,
        ordered by mean intensity in increasing order.
        With this option, you can control  :math:`$K$` and tell the script which
        classes represent CSF, gray and white matter.
        Format (K, csfLabel, gmLabel, wmLabel).
        Examples:
        ``(3,1,2,3)`` for T1 with K=3, CSF=1, GM=2, WM=3 (default),
        ``(3,3,2,1)`` for T2 with K=3, CSF=3, GM=2, WM=1,
        ``(3,1,3,2)`` for FLAIR with K=3, CSF=1 GM=3, WM=2,
        ``(4,4,2,3)`` uses K=4, CSF=4, GM=2, WM=3.
    bspline_fitting_distance : float
        The size of the b-spline mesh grid elements, in mm (default: 200)
    wm_prior : :obj:`bool`
        Whether the WM posterior obtained with ATROPOS should be regularized with a prior
        map (typically, mapped from the template). When ``wm_prior`` is ``True`` the input
        field ``wm_prior`` of the input node must be connected.

    Inputs
    ------
    in_files : list
        The original anatomical images passed in to the brain-extraction workflow.
    in_corrected : list
        :abbr:`INU (intensity non-uniformity)`-corrected files.
    in_mask : str
        Brain mask calculated previously.
    wm_prior : :obj:`str`
        Path to the WM prior probability map, aligned with the individual data.

    Outputs
    -------
    out_file : :obj:`str`
        Path of the corrected and brain-extracted result, using the ATROPOS refinement.
    bias_corrected : :obj:`str`
        Path of the corrected and result, using the ATROPOS refinement.
    bias_image : :obj:`str`
        Path of the estimated INU bias field, using the ATROPOS refinement.
    out_mask : str
        Refined brain mask
    out_segm : str
        Output segmentation
    out_tpms : str
        Output :abbr:`TPMs (tissue probability maps)`


    """
    wf = pe.Workflow(name)

    out_fields = [
        "bias_corrected", "bias_image", "out_mask", "out_segm", "out_tpms"
    ]

    inputnode = pe.Node(
        niu.IdentityInterface(
            fields=["in_files", "in_corrected", "in_mask", "wm_prior"]),
        name="inputnode",
    )
    outputnode = pe.Node(niu.IdentityInterface(fields=["out_file"] +
                                               out_fields),
                         name="outputnode")

    copy_xform = pe.Node(CopyXForm(fields=out_fields),
                         name="copy_xform",
                         run_without_submitting=True)

    # Morphological dilation, radius=2
    dil_brainmask = pe.Node(ImageMath(operation="MD",
                                      op2="2",
                                      copy_header=True),
                            name="dil_brainmask")
    # Get largest connected component
    get_brainmask = pe.Node(
        ImageMath(operation="GetLargestComponent", copy_header=True),
        name="get_brainmask",
    )

    # Run atropos (core node)
    atropos = pe.Node(
        Atropos(
            convergence_threshold=0.0,
            dimension=3,
            initialization="KMeans",
            likelihood_model="Gaussian",
            mrf_radius=[1, 1, 1],
            mrf_smoothing_factor=0.1,
            n_iterations=3,
            number_of_tissue_classes=in_segmentation_model[0],
            save_posteriors=True,
            use_random_seed=use_random_seed,
        ),
        name="01_atropos",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )

    # massage outputs
    pad_segm = pe.Node(
        ImageMath(operation="PadImage", op2=f"{padding}", copy_header=False),
        name="02_pad_segm",
    )
    pad_mask = pe.Node(
        ImageMath(operation="PadImage", op2=f"{padding}", copy_header=False),
        name="03_pad_mask",
    )

    # Split segmentation in binary masks
    sel_labels = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=["out_wm", "out_gm", "out_csf"]),
        name="04_sel_labels",
    )
    sel_labels.inputs.labels = list(reversed(in_segmentation_model[1:]))

    # Select largest components (GM, WM)
    # ImageMath ${DIMENSION} ${EXTRACTION_WM} GetLargestComponent ${EXTRACTION_WM}
    get_wm = pe.Node(ImageMath(operation="GetLargestComponent"),
                     name="05_get_wm")
    get_gm = pe.Node(ImageMath(operation="GetLargestComponent"),
                     name="06_get_gm")

    # Fill holes and calculate intersection
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} FillHoles ${EXTRACTION_GM} 2
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${EXTRACTION_TMP} ${EXTRACTION_GM}
    fill_gm = pe.Node(ImageMath(operation="FillHoles", op2="2"),
                      name="07_fill_gm")
    mult_gm = pe.Node(
        MultiplyImages(dimension=3, output_product_image="08_mult_gm.nii.gz"),
        name="08_mult_gm",
    )

    # MultiplyImages ${DIMENSION} ${EXTRACTION_WM} ${ATROPOS_WM_CLASS_LABEL} ${EXTRACTION_WM}
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} ME ${EXTRACTION_CSF} 10
    relabel_wm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-1],
            output_product_image="09_relabel_wm.nii.gz",
        ),
        name="09_relabel_wm",
    )
    me_csf = pe.Node(ImageMath(operation="ME", op2="10"), name="10_me_csf")

    # ImageMath ${DIMENSION} ${EXTRACTION_GM} addtozero ${EXTRACTION_GM} ${EXTRACTION_TMP}
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${ATROPOS_GM_CLASS_LABEL} ${EXTRACTION_GM}
    # ImageMath ${DIMENSION} ${EXTRACTION_SEGMENTATION} addtozero ${EXTRACTION_WM} ${EXTRACTION_GM}
    add_gm = pe.Node(ImageMath(operation="addtozero"), name="11_add_gm")
    relabel_gm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-2],
            output_product_image="12_relabel_gm.nii.gz",
        ),
        name="12_relabel_gm",
    )
    add_gm_wm = pe.Node(ImageMath(operation="addtozero"), name="13_add_gm_wm")

    # Superstep 7
    # Split segmentation in binary masks
    sel_labels2 = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=["out_gm", "out_wm"]),
        name="14_sel_labels2",
    )
    sel_labels2.inputs.labels = in_segmentation_model[2:]

    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} ${EXTRACTION_TMP}
    add_7 = pe.Node(ImageMath(operation="addtozero"), name="15_add_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 2
    me_7 = pe.Node(ImageMath(operation="ME", op2="2"), name="16_me_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} GetLargestComponent ${EXTRACTION_MASK}
    comp_7 = pe.Node(ImageMath(operation="GetLargestComponent"),
                     name="17_comp_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 4
    md_7 = pe.Node(ImageMath(operation="MD", op2="4"), name="18_md_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} FillHoles ${EXTRACTION_MASK} 2
    fill_7 = pe.Node(ImageMath(operation="FillHoles", op2="2"),
                     name="19_fill_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} \
    # ${EXTRACTION_MASK_PRIOR_WARPED}
    add_7_2 = pe.Node(ImageMath(operation="addtozero"), name="20_add_7_2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 5
    md_7_2 = pe.Node(ImageMath(operation="MD", op2="5"), name="21_md_7_2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 5
    me_7_2 = pe.Node(ImageMath(operation="ME", op2="5"), name="22_me_7_2")

    # De-pad
    depad_mask = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                         name="23_depad_mask")
    depad_segm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                         name="24_depad_segm")
    depad_gm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                       name="25_depad_gm")
    depad_wm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                       name="26_depad_wm")
    depad_csf = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                        name="27_depad_csf")

    msk_conform = pe.Node(niu.Function(function=_conform_mask),
                          name="msk_conform")
    merge_tpms = pe.Node(niu.Merge(in_segmentation_model[0]),
                         name="merge_tpms")

    sel_wm = pe.Node(niu.Select(), name="sel_wm", run_without_submitting=True)
    if not wm_prior:
        sel_wm.inputs.index = in_segmentation_model[-1] - 1

    copy_xform_wm = pe.Node(CopyXForm(fields=["wm_map"]),
                            name="copy_xform_wm",
                            run_without_submitting=True)

    # Refine INU correction
    inu_n4_final = pe.MapNode(
        N4BiasFieldCorrection(
            dimension=3,
            save_bias=True,
            copy_header=True,
            n_iterations=[50] * 5,
            convergence_threshold=1e-7,
            shrink_factor=4,
            bspline_fitting_distance=bspline_fitting_distance,
        ),
        n_procs=omp_nthreads,
        name="inu_n4_final",
        iterfield=["input_image"],
    )

    try:
        inu_n4_final.inputs.rescale_intensities = True
    except ValueError:
        warn(
            "N4BiasFieldCorrection's --rescale-intensities option was added in ANTS 2.1.0 "
            f"({inu_n4_final.interface.version} found.) Please consider upgrading.",
            UserWarning,
        )

    # Apply mask
    apply_mask = pe.MapNode(ApplyMask(),
                            iterfield=["in_file"],
                            name="apply_mask")

    # fmt: off
    wf.connect([
        (inputnode, dil_brainmask, [("in_mask", "op1")]),
        (inputnode, copy_xform, [(("in_files", _pop), "hdr_file")]),
        (inputnode, copy_xform_wm, [(("in_files", _pop), "hdr_file")]),
        (inputnode, pad_mask, [("in_mask", "op1")]),
        (inputnode, atropos, [("in_corrected", "intensity_images")]),
        (inputnode, inu_n4_final, [("in_files", "input_image")]),
        (inputnode, msk_conform, [(("in_files", _pop), "in_reference")]),
        (dil_brainmask, get_brainmask, [("output_image", "op1")]),
        (get_brainmask, atropos, [("output_image", "mask_image")]),
        (atropos, pad_segm, [("classified_image", "op1")]),
        (pad_segm, sel_labels, [("output_image", "in_segm")]),
        (sel_labels, get_wm, [("out_wm", "op1")]),
        (sel_labels, get_gm, [("out_gm", "op1")]),
        (get_gm, fill_gm, [("output_image", "op1")]),
        (get_gm, mult_gm, [("output_image", "first_input")]),
        (fill_gm, mult_gm, [("output_image", "second_input")]),
        (get_wm, relabel_wm, [("output_image", "first_input")]),
        (sel_labels, me_csf, [("out_csf", "op1")]),
        (mult_gm, add_gm, [("output_product_image", "op1")]),
        (me_csf, add_gm, [("output_image", "op2")]),
        (add_gm, relabel_gm, [("output_image", "first_input")]),
        (relabel_wm, add_gm_wm, [("output_product_image", "op1")]),
        (relabel_gm, add_gm_wm, [("output_product_image", "op2")]),
        (add_gm_wm, sel_labels2, [("output_image", "in_segm")]),
        (sel_labels2, add_7, [("out_wm", "op1"), ("out_gm", "op2")]),
        (add_7, me_7, [("output_image", "op1")]),
        (me_7, comp_7, [("output_image", "op1")]),
        (comp_7, md_7, [("output_image", "op1")]),
        (md_7, fill_7, [("output_image", "op1")]),
        (fill_7, add_7_2, [("output_image", "op1")]),
        (pad_mask, add_7_2, [("output_image", "op2")]),
        (add_7_2, md_7_2, [("output_image", "op1")]),
        (md_7_2, me_7_2, [("output_image", "op1")]),
        (me_7_2, depad_mask, [("output_image", "op1")]),
        (add_gm_wm, depad_segm, [("output_image", "op1")]),
        (relabel_wm, depad_wm, [("output_product_image", "op1")]),
        (relabel_gm, depad_gm, [("output_product_image", "op1")]),
        (sel_labels, depad_csf, [("out_csf", "op1")]),
        (depad_csf, merge_tpms, [("output_image", "in1")]),
        (depad_gm, merge_tpms, [("output_image", "in2")]),
        (depad_wm, merge_tpms, [("output_image", "in3")]),
        (depad_mask, msk_conform, [("output_image", "in_mask")]),
        (msk_conform, copy_xform, [("out", "out_mask")]),
        (depad_segm, copy_xform, [("output_image", "out_segm")]),
        (merge_tpms, copy_xform, [("out", "out_tpms")]),
        (atropos, sel_wm, [("posteriors", "inlist")]),
        (sel_wm, copy_xform_wm, [("out", "wm_map")]),
        (copy_xform_wm, inu_n4_final, [("wm_map", "weight_image")]),
        (inu_n4_final, copy_xform, [("output_image", "bias_corrected"),
                                    ("bias_image", "bias_image")]),
        (copy_xform, apply_mask, [("bias_corrected", "in_file"),
                                  ("out_mask", "in_mask")]),
        (apply_mask, outputnode, [("out_file", "out_file")]),
        (copy_xform, outputnode, [
            ("bias_corrected", "bias_corrected"),
            ("bias_image", "bias_image"),
            ("out_mask", "out_mask"),
            ("out_segm", "out_segm"),
            ("out_tpms", "out_tpms"),
        ]),
    ])
    # fmt: on

    if wm_prior:
        from nipype.algorithms.metrics import FuzzyOverlap

        def _argmax(in_dice):
            import numpy as np

            return np.argmax(in_dice)

        match_wm = pe.Node(
            niu.Function(function=_matchlen),
            name="match_wm",
            run_without_submitting=True,
        )
        overlap = pe.Node(FuzzyOverlap(),
                          name="overlap",
                          run_without_submitting=True)

        apply_wm_prior = pe.Node(niu.Function(function=_improd),
                                 name="apply_wm_prior")

        # fmt: off
        wf.disconnect([
            (copy_xform_wm, inu_n4_final, [("wm_map", "weight_image")]),
        ])
        wf.connect([
            (inputnode, apply_wm_prior, [("in_mask", "in_mask"),
                                         ("wm_prior", "op2")]),
            (inputnode, match_wm, [("wm_prior", "value")]),
            (atropos, match_wm, [("posteriors", "reference")]),
            (atropos, overlap, [("posteriors", "in_ref")]),
            (match_wm, overlap, [("out", "in_tst")]),
            (overlap, sel_wm, [(("class_fdi", _argmax), "index")]),
            (copy_xform_wm, apply_wm_prior, [("wm_map", "op1")]),
            (apply_wm_prior, inu_n4_final, [("out", "weight_image")]),
        ])
        # fmt: on
    return wf
# Example no. 5
# 0
def BAWantsRegistrationTemplateBuildSingleIterationWF(iterationPhasePrefix=''):
    """
    Build the nipype workflow for a single template-building iteration.

    The workflow registers every moving image to ``fixed_image`` (MapNode
    ``BeginANTS``), warps and averages the moving images, averages the affine
    and warp components of the resulting composite transforms, and reshapes
    the averaged image with the transform list produced by
    ``MakeTransformListWithGradientWarps``.  The images returned by
    ``GetPassiveImages`` are pushed through the same transforms, averaged and
    reshaped in the same way.

    ``iterationPhasePrefix`` is converted with ``str()`` and used as the
    suffix of the workflow name and as part of several output file names.

    Inputs::

           inputspec.ListOfImagesDictionaries :
           inputspec.registrationImageTypes :
           inputspec.interpolationMapping :
           inputspec.fixed_image :

    Outputs::

           outputspec.template :
           outputspec.transforms_list : (declared, but nothing is connected to it)
           outputspec.passive_deformed_templates :

    Returns the assembled ``pe.Workflow``.
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name='antsRegistrationTemplateBuildSingleIterationWF_' +
        str(iterationPhasePrefix))
    # Short alias; every edge below is added to this one workflow.
    connect = TemplateBuildSingleIterationWF.connect

    inputSpec = pe.Node(
        interface=util.IdentityInterface(fields=[
            'ListOfImagesDictionaries',
            'registrationImageTypes',
            #'maskRegistrationImageType',
            'interpolationMapping',
            'fixed_image'
        ]),
        run_without_submitting=True,
        name='inputspec')
    ## HACK: TODO: We need to have the AVG_AIR.nii.gz be warped with a default voxel value of 1.0
    ## HACK: TODO: Need to move all local functions to a common utility file, or to the top of the file, so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: TODO: REMOVE 'transforms_list'; it is not used.  That will change all the hashes.
    ## HACK: TODO: Need to run all python files through the code beautifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(
        interface=util.IdentityInterface(
            fields=['template', 'transforms_list',
                    'passive_deformed_templates']),
        run_without_submitting=True,
        name='outputspec')

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    BeginANTS = pe.MapNode(interface=Registration(),
                           name='BeginANTS',
                           iterfield=['moving_image'])
    BeginANTS.inputs.dimension = 3
    # This is the recommended set of parameters from the ANTS developers.
    # Five stages: Rigid, Affine, then three SyN stages; every per-stage list
    # below has one entry per stage.
    BeginANTS.inputs.output_transform_prefix = str(
        iterationPhasePrefix) + '_tfm'
    BeginANTS.inputs.transforms = ["Rigid", "Affine", "SyN", "SyN", "SyN"]
    BeginANTS.inputs.transform_parameters = [[0.1], [0.1], [0.1, 3.0, 0.0],
                                             [0.1, 3.0, 0.0], [0.1, 3.0, 0.0]]
    BeginANTS.inputs.metric = ['MI', 'MI', 'CC', 'CC', 'CC']
    BeginANTS.inputs.sampling_strategy = [
        'Regular', 'Regular', None, None, None
    ]
    BeginANTS.inputs.sampling_percentage = [0.27, 0.27, 1.0, 1.0, 1.0]
    BeginANTS.inputs.metric_weight = [1.0, 1.0, 1.0, 1.0, 1.0]
    BeginANTS.inputs.radius_or_number_of_bins = [32, 32, 4, 4, 4]
    BeginANTS.inputs.number_of_iterations = [[1000, 1000, 1000, 1000],
                                             [1000, 1000, 1000, 1000],
                                             [1000, 250], [140], [25]]
    BeginANTS.inputs.convergence_threshold = [5e-8, 5e-8, 5e-7, 5e-6, 5e-5]
    BeginANTS.inputs.convergence_window_size = [10, 10, 10, 10, 10]
    BeginANTS.inputs.use_histogram_matching = [True, True, True, True, True]
    BeginANTS.inputs.shrink_factors = [[8, 4, 2, 1], [8, 4, 2, 1], [8, 4], [2],
                                       [1]]
    BeginANTS.inputs.smoothing_sigmas = [[3, 2, 1, 0], [3, 2, 1, 0], [3, 2],
                                         [1], [0]]
    BeginANTS.inputs.sigma_units = ["vox", "vox", "vox", "vox", "vox"]
    BeginANTS.inputs.use_estimate_learning_rate_once = [
        False, False, False, False, False
    ]
    BeginANTS.inputs.write_composite_transform = True
    BeginANTS.inputs.collapse_output_transforms = False
    BeginANTS.inputs.initialize_transforms_per_stage = True
    BeginANTS.inputs.winsorize_lower_quantile = 0.01
    BeginANTS.inputs.winsorize_upper_quantile = 0.99
    BeginANTS.inputs.output_warped_image = 'atlas2subject.nii.gz'
    BeginANTS.inputs.output_inverse_warped_image = 'subject2atlas.nii.gz'
    BeginANTS.inputs.save_state = 'SavedBeginANTSSyNState.h5'
    BeginANTS.inputs.float = True

    ## Select the moving images (and their interpolation types) for registration
    GetMovingImagesNode = pe.Node(
        interface=util.Function(
            function=GetMovingImages,
            input_names=[
                'ListOfImagesDictionaries', 'registrationImageTypes',
                'interpolationMapping'
            ],
            output_names=['moving_images', 'moving_interpolation_type']),
        run_without_submitting=True,
        name='99_GetMovingImagesNode')
    connect(inputSpec, 'ListOfImagesDictionaries',
            GetMovingImagesNode, 'ListOfImagesDictionaries')
    connect(inputSpec, 'registrationImageTypes',
            GetMovingImagesNode, 'registrationImageTypes')
    connect(inputSpec, 'interpolationMapping',
            GetMovingImagesNode, 'interpolationMapping')

    connect(GetMovingImagesNode, 'moving_images',
            BeginANTS, 'moving_image')
    connect(GetMovingImagesNode, 'moving_interpolation_type',
            BeginANTS, 'interpolation')
    connect(inputSpec, 'fixed_image', BeginANTS, 'fixed_image')

    ## Now warp all the input_images images
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=['transforms', 'input_image'],
        #iterfield=['transforms', 'invert_transform_flags', 'input_image'],
        name='wimtdeformed')
    wimtdeformed.inputs.interpolation = 'Linear'
    # Set on `.inputs`: assigning `default_value` on the node object itself
    # only creates a Python attribute and never reaches the interface.
    wimtdeformed.inputs.default_value = 0
    # The composite transform is used here; an earlier revision connected
    # 'forward_transforms' together with 'forward_invert_flags' instead.
    connect(BeginANTS, 'composite_transform', wimtdeformed, 'transforms')
    connect(GetMovingImagesNode, 'moving_images',
            wimtdeformed, 'input_image')
    connect(inputSpec, 'fixed_image', wimtdeformed, 'reference_image')

    ##  Shape Update Next =====
    ## Now  Average All input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(),
                                name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(
        iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    connect(wimtdeformed, "output_image", AvgDeformedImages, 'images')

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(interface=AverageAffineTransform(),
                                 name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    # NOTE: 'Avererage_' is misspelled, but it is part of the output file
    # name, so it is left untouched to keep existing outputs stable.
    AvgAffineTransform.inputs.output_affine_transform = 'Avererage_' + str(
        iterationPhasePrefix) + '_Affine.h5'

    # Split each composite transform into its affine and warp components
    SplitCompositeTransform = pe.MapNode(
        interface=util.Function(
            function=SplitCompositeToComponentTransforms,
            input_names=['composite_transform_as_list'],
            output_names=['affine_component_list', 'warp_component_list']),
        iterfield=['composite_transform_as_list'],
        run_without_submitting=True,
        name='99_SplitCompositeTransform')
    connect(BeginANTS, 'composite_transform',
            SplitCompositeTransform, 'composite_transform_as_list')
    connect(SplitCompositeTransform, 'affine_component_list',
            AvgAffineTransform, 'transforms')

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(
        iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    connect(SplitCompositeTransform, 'warp_component_list',
            AvgWarpImages, 'images')

    ## Scale the averaged warp by the (negated) gradient step
    ## TODO:  For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(interface=MultiplyImages(),
                                    name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # The file name is derived from GradientStep so the two cannot drift
    # apart; with the current value it is still 'GradientStep0.25_...'.
    GradientStepWarpImage.inputs.output_product_image = (
        'GradientStep' + str(GradientStep) + '_' +
        str(iterationPhasePrefix) + '_warp.nii.gz')
    connect(AvgWarpImages, 'output_average_image',
            GradientStepWarpImage, 'first_input')

    ## Now create the new template shape based on the average of all deformed images
    UpdateTemplateShape = pe.Node(interface=ApplyTransforms(),
                                  name='UpdateTemplateShape')
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = 'Linear'
    UpdateTemplateShape.inputs.default_value = 0

    connect(AvgDeformedImages, 'output_average_image',
            UpdateTemplateShape, 'reference_image')
    # 'transforms' expects a list, so the single averaged affine is wrapped
    # by makeListOfOneElement on the way through.
    connect([
        (AvgAffineTransform, UpdateTemplateShape,
         [(('affine_transform', makeListOfOneElement), 'transforms')]),
    ])
    connect(GradientStepWarpImage, 'output_product_image',
            UpdateTemplateShape, 'input_image')

    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='99_MakeTransformListWithGradientWarps')
    # NOTE(review): `ignore_exception` appears to be deprecated in newer
    # nipype releases -- confirm against the pinned nipype version.
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    connect(AvgAffineTransform, 'affine_transform',
            ApplyInvAverageAndFourTimesGradientStepWarpImage,
            'averageAffineTranform')
    connect(UpdateTemplateShape, 'output_image',
            ApplyInvAverageAndFourTimesGradientStepWarpImage,
            'gradientStepWarp')

    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name='ReshapeAverageImageWithShapeUpdate')
    # Only the first transform of the list is inverted.  Five flags are given,
    # so the list built by MakeTransformListWithGradientWarps is presumably
    # five entries long -- verify against that helper.
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = 'Linear'
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = 'ReshapeAverageImageWithShapeUpdate.nii.gz'
    # The averaged image is both the image being resampled and its own
    # reference grid.
    connect(AvgDeformedImages, 'output_average_image',
            ReshapeAverageImageWithShapeUpdate, 'input_image')
    connect(AvgDeformedImages, 'output_average_image',
            ReshapeAverageImageWithShapeUpdate, 'reference_image')
    connect(ApplyInvAverageAndFourTimesGradientStepWarpImage,
            'TransformListWithGradientWarps',
            ReshapeAverageImageWithShapeUpdate, 'transforms')
    connect(ReshapeAverageImageWithShapeUpdate, 'output_image',
            outputSpec, 'template')

    ######
    ######  Process all the passive deformed images in a way similar to the main image used for registration
    ######
    ##############################################
    ## Now warp all the ListOfPassiveImagesDictionaries images
    FlattenTransformAndImagesListNode = pe.Node(
        Function(function=FlattenTransformAndImagesList,
                 input_names=[
                     'ListOfPassiveImagesDictionaries', 'transforms',
                     'interpolationMapping', 'invert_transform_flags'
                 ],
                 output_names=[
                     'flattened_images', 'flattened_transforms',
                     'flattened_invert_transform_flags',
                     'flattened_image_nametypes',
                     'flattened_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")

    GetPassiveImagesNode = pe.Node(
        interface=util.Function(
            function=GetPassiveImages,
            input_names=['ListOfImagesDictionaries', 'registrationImageTypes'],
            output_names=['ListOfPassiveImagesDictionaries']),
        run_without_submitting=True,
        name='99_GetPassiveImagesNode')
    connect(inputSpec, 'ListOfImagesDictionaries',
            GetPassiveImagesNode, 'ListOfImagesDictionaries')
    connect(inputSpec, 'registrationImageTypes',
            GetPassiveImagesNode, 'registrationImageTypes')

    connect(GetPassiveImagesNode, 'ListOfPassiveImagesDictionaries',
            FlattenTransformAndImagesListNode,
            'ListOfPassiveImagesDictionaries')
    connect(inputSpec, 'interpolationMapping',
            FlattenTransformAndImagesListNode, 'interpolationMapping')
    connect(BeginANTS, 'composite_transform',
            FlattenTransformAndImagesListNode, 'transforms')
    ## NOTE: the 'invert_transform_flags' input of the flatten node is never
    ##       set or connected here (it used to come from BeginANTS
    ##       'forward_invert_flags').
    ## TODO: Please check whether invert_transform_flags has a fixed number.
    wimtPassivedeformed = pe.MapNode(interface=ApplyTransforms(),
                                     iterfield=[
                                         'transforms',
                                         'invert_transform_flags',
                                         'input_image', 'interpolation'
                                     ],
                                     name='wimtPassivedeformed')
    wimtPassivedeformed.inputs.default_value = 0
    connect(AvgDeformedImages, 'output_average_image',
            wimtPassivedeformed, 'reference_image')
    connect(FlattenTransformAndImagesListNode, 'flattened_interpolation_type',
            wimtPassivedeformed, 'interpolation')
    connect(FlattenTransformAndImagesListNode, 'flattened_images',
            wimtPassivedeformed, 'input_image')
    connect(FlattenTransformAndImagesListNode, 'flattened_transforms',
            wimtPassivedeformed, 'transforms')
    connect(FlattenTransformAndImagesListNode,
            'flattened_invert_transform_flags',
            wimtPassivedeformed, 'invert_transform_flags')

    # Regroup the flat list of warped passive images for per-group averaging
    RenestDeformedPassiveImagesNode = pe.Node(
        Function(function=RenestDeformedPassiveImages,
                 input_names=[
                     'deformedPassiveImages', 'flattened_image_nametypes',
                     'interpolationMapping'
                 ],
                 output_names=[
                     'nested_imagetype_list', 'outputAverageImageName_list',
                     'image_type_list', 'nested_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    connect(inputSpec, 'interpolationMapping',
            RenestDeformedPassiveImagesNode, 'interpolationMapping')
    connect(wimtPassivedeformed, 'output_image',
            RenestDeformedPassiveImagesNode, 'deformedPassiveImages')
    connect(FlattenTransformAndImagesListNode, 'flattened_image_nametypes',
            RenestDeformedPassiveImagesNode, 'flattened_image_nametypes')

    ## Now  Average All passive input_images deformed images together to create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    connect(RenestDeformedPassiveImagesNode, "nested_imagetype_list",
            AvgDeformedPassiveImages, 'images')
    connect(RenestDeformedPassiveImagesNode, "outputAverageImageName_list",
            AvgDeformedPassiveImages, 'output_average_image')

    ## -- TODO:  Now need to reshape all the passive images as well
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'input_image', 'reference_image', 'output_image', 'interpolation'
        ],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    # Same flags as ReshapeAverageImageWithShapeUpdate: first transform inverted.
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0
    connect(RenestDeformedPassiveImagesNode, 'nested_interpolation_type',
            ReshapeAveragePassiveImageWithShapeUpdate, 'interpolation')
    connect(RenestDeformedPassiveImagesNode, 'outputAverageImageName_list',
            ReshapeAveragePassiveImageWithShapeUpdate, 'output_image')
    connect(AvgDeformedPassiveImages, 'output_average_image',
            ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    connect(AvgDeformedPassiveImages, 'output_average_image',
            ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    connect(ApplyInvAverageAndFourTimesGradientStepWarpImage,
            'TransformListWithGradientWarps',
            ReshapeAveragePassiveImageWithShapeUpdate, 'transforms')
    connect(ReshapeAveragePassiveImageWithShapeUpdate, 'output_image',
            outputSpec, 'passive_deformed_templates')

    return TemplateBuildSingleIterationWF
# Example no. 6
# 0
def ANTSTemplateBuildSingleIterationWF(iterationPhasePrefix=''):
    """
    Assemble the nipype workflow for one ANTS-based template-building iteration.

    The workflow is named ``'ANTSTemplateBuildSingleIterationWF_'`` followed by
    the stringified ``iterationPhasePrefix``; the same prefix is embedded in the
    output file names configured on the nodes below.

    Inputs::

           inputspec.images : routed to BeginANTS.moving_image and wimtdeformed.input_image
           inputspec.fixed_image : routed to BeginANTS.fixed_image
           inputspec.ListOfPassiveImagesDictionaries : routed to the 99_FlattenTransformAndImagesList node

    Outputs::

           outputspec.template : output_image of ReshapeAverageImageWithShapeUpdate
           outputspec.transforms_list : declared, but nothing in this workflow is connected to it
           outputspec.passive_deformed_templates : output_image of ReshapeAveragePassiveImageWithShapeUpdate

    :param iterationPhasePrefix: value converted with ``str()`` and used in the
        workflow name and in generated file names.
    :return: the configured ``pe.Workflow`` instance.
    """
    phase_tag = str(iterationPhasePrefix)

    wf = pe.Workflow(name="ANTSTemplateBuildSingleIterationWF_" + phase_tag)
    link = wf.connect  # every edge below is added through this bound method

    input_spec = pe.Node(
        interface=util.IdentityInterface(
            fields=["images", "fixed_image", "ListOfPassiveImagesDictionaries"]
        ),
        run_without_submitting=True,
        name="inputspec",
    )
    ## HACK: TODO: Move all local functions to a common utility file (or to the top of the file) so
    ##             that re-indenting for flow control does not change their hash.
    ## HACK: TODO: REMOVE 'transforms_list', it is not used.  That will change all the hashes.
    output_spec = pe.Node(
        interface=util.IdentityInterface(
            fields=["template", "transforms_list", "passive_deformed_templates"]
        ),
        run_without_submitting=True,
        name="outputspec",
    )

    ### NOTE MAP NODE! One ANTS run per moving image, each against the provided fixed_image.
    begin_ants = pe.MapNode(
        interface=ANTS(), name="BeginANTS", iterfield=["moving_image"]
    )
    ants_in = begin_ants.inputs
    ants_in.dimension = 3
    ants_in.output_transform_prefix = phase_tag + "_tfm"
    ants_in.metric = ["CC"]
    ants_in.metric_weight = [1.0]
    ants_in.radius = [5]
    ants_in.transformation_model = "SyN"
    ants_in.gradient_step_length = 0.25
    ants_in.number_of_iterations = [50, 35, 15]
    ants_in.number_of_affine_iterations = [10000] * 5
    ants_in.use_histogram_matching = True
    ants_in.mi_option = [32, 16000]
    ants_in.regularization = "Gauss"
    ants_in.regularization_gradient_field_sigma = 3
    ants_in.regularization_deformation_field_sigma = 0
    link(input_spec, "images", begin_ants, "moving_image")
    link(input_spec, "fixed_image", begin_ants, "fixed_image")

    # Combine the warp and affine outputs of BeginANTS via MakeListsOfTransformLists.
    make_transform_lists = pe.Node(
        interface=util.Function(
            function=MakeListsOfTransformLists,
            input_names=["warpTransformList", "AffineTransformList"],
            output_names=["out"],
        ),
        run_without_submitting=True,
        name="MakeTransformsLists",
    )
    make_transform_lists.inputs.ignore_exception = True
    link(begin_ants, "warp_transform", make_transform_lists, "warpTransformList")
    link(begin_ants, "affine_transform", make_transform_lists,
         "AffineTransformList")

    ## Warp every input image with its own transformation series.
    warp_inputs = pe.MapNode(
        interface=WarpImageMultiTransform(),
        iterfield=["transformation_series", "input_image"],
        name="wimtdeformed",
    )
    link(input_spec, "images", warp_inputs, "input_image")
    link(make_transform_lists, "out", warp_inputs, "transformation_series")

    ##  Shape Update Next =====
    ## Average all deformed input images into an updated template average.
    avg_deformed = pe.Node(interface=AverageImages(), name="AvgDeformedImages")
    avg_deformed.inputs.dimension = 3
    avg_deformed.inputs.output_average_image = phase_tag + ".nii.gz"
    avg_deformed.inputs.normalize = True
    link(warp_inputs, "output_image", avg_deformed, "images")

    ## Average all affine transforms.
    avg_affine = pe.Node(
        interface=AverageAffineTransform(), name="AvgAffineTransform"
    )
    avg_affine.inputs.dimension = 3
    avg_affine.inputs.output_affine_transform = (
        "Avererage_" + phase_tag + "_Affine.mat"
    )
    link(begin_ants, "affine_transform", avg_affine, "transforms")

    ## Average the warp fields.
    avg_warp = pe.Node(interface=AverageImages(), name="AvgWarpImages")
    avg_warp.inputs.dimension = 3
    avg_warp.inputs.output_average_image = phase_tag + "warp.nii.gz"
    avg_warp.inputs.normalize = True
    link(begin_ants, "warp_transform", avg_warp, "images")

    ## Scale the averaged warp by the negated gradient step.
    ## TODO:  For now the gradient step is hard coded to 0.25.
    gradient_step = 0.25
    scaled_warp = pe.Node(
        interface=MultiplyImages(), name="GradientStepWarpImage"
    )
    scaled_warp.inputs.dimension = 3
    scaled_warp.inputs.second_input = -1.0 * gradient_step
    scaled_warp.inputs.output_product_image = (
        "GradientStep0.25_" + phase_tag + "_warp.nii.gz"
    )
    link(avg_warp, "output_average_image", scaled_warp, "first_input")

    ## Create the new template shape based on the average of all deformed images.
    update_shape = pe.Node(
        interface=WarpImageMultiTransform(), name="UpdateTemplateShape"
    )
    update_shape.inputs.invert_affine = [1]
    link(avg_deformed, "output_average_image", update_shape, "reference_image")
    link(avg_affine, "affine_transform", update_shape, "transformation_series")
    link(scaled_warp, "output_product_image", update_shape, "input_image")

    # Build the transform list (average affine + gradient-step warp) that is
    # shared by the template reshape and the passive-image reshape below.
    gradient_warp_list = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=["averageAffineTranform", "gradientStepWarp"],
            output_names=["TransformListWithGradientWarps"],
        ),
        run_without_submitting=True,
        name="MakeTransformListWithGradientWarps",
    )
    gradient_warp_list.inputs.ignore_exception = True
    link(avg_affine, "affine_transform", gradient_warp_list,
         "averageAffineTranform")
    link(update_shape, "output_image", gradient_warp_list, "gradientStepWarp")

    # The averaged image is both the input and the reference of the reshape.
    reshape_template = pe.Node(
        interface=WarpImageMultiTransform(),
        name="ReshapeAverageImageWithShapeUpdate",
    )
    reshape_template.inputs.invert_affine = [1]
    reshape_template.inputs.out_postfix = "_Reshaped"
    link(avg_deformed, "output_average_image", reshape_template, "input_image")
    link(avg_deformed, "output_average_image", reshape_template,
         "reference_image")
    link(gradient_warp_list, "TransformListWithGradientWarps",
         reshape_template, "transformation_series")
    link(reshape_template, "output_image", output_spec, "template")

    ##############################################
    ## Process all the passive images in a way similar to the main image used
    ## for registration: flatten, warp, re-nest, average, reshape.
    flatten_passive = pe.Node(
        Function(
            function=FlattenTransformAndImagesList,
            input_names=[
                "ListOfPassiveImagesDictionaries",
                "transformation_series",
            ],
            output_names=[
                "flattened_images",
                "flattened_transforms",
                "flattened_image_nametypes",
            ],
        ),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList",
    )
    link(input_spec, "ListOfPassiveImagesDictionaries", flatten_passive,
         "ListOfPassiveImagesDictionaries")
    link(make_transform_lists, "out", flatten_passive, "transformation_series")

    warp_passive = pe.MapNode(
        interface=WarpImageMultiTransform(),
        iterfield=["transformation_series", "input_image"],
        name="wimtPassivedeformed",
    )
    link(avg_deformed, "output_average_image", warp_passive, "reference_image")
    link(flatten_passive, "flattened_images", warp_passive, "input_image")
    link(flatten_passive, "flattened_transforms", warp_passive,
         "transformation_series")

    renest_passive = pe.Node(
        Function(
            function=RenestDeformedPassiveImages,
            input_names=["deformedPassiveImages", "flattened_image_nametypes"],
            output_names=[
                "nested_imagetype_list",
                "outputAverageImageName_list",
                "image_type_list",
            ],
        ),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages",
    )
    link(warp_passive, "output_image", renest_passive, "deformedPassiveImages")
    link(flatten_passive, "flattened_image_nametypes", renest_passive,
         "flattened_image_nametypes")

    ## Average the deformed passive images, one average per re-nested group.
    avg_passive = pe.MapNode(
        interface=AverageImages(),
        iterfield=["images", "output_average_image"],
        name="AvgDeformedPassiveImages",
    )
    avg_passive.inputs.dimension = 3
    avg_passive.inputs.normalize = False
    link(renest_passive, "nested_imagetype_list", avg_passive, "images")
    link(renest_passive, "outputAverageImageName_list", avg_passive,
         "output_average_image")

    ## Reshape each passive average with the same transform list as the template.
    reshape_passive = pe.MapNode(
        interface=WarpImageMultiTransform(),
        iterfield=["input_image", "reference_image", "out_postfix"],
        name="ReshapeAveragePassiveImageWithShapeUpdate",
    )
    reshape_passive.inputs.invert_affine = [1]
    link(renest_passive, "image_type_list", reshape_passive, "out_postfix")
    link(avg_passive, "output_average_image", reshape_passive, "input_image")
    link(avg_passive, "output_average_image", reshape_passive,
         "reference_image")
    link(gradient_warp_list, "TransformListWithGradientWarps",
         reshape_passive, "transformation_series")
    link(reshape_passive, "output_image", output_spec,
         "passive_deformed_templates")

    return wf
def baw_ants_registration_template_build_single_iteration_wf(
    iterationPhasePrefix, CLUSTER_QUEUE, CLUSTER_QUEUE_LONG
):
    """
    Build the nipype workflow for one antsRegistration-based template-building iteration.

    Inputs::

           inputspec.ListOfImagesDictionaries : passed to get_moving_images and get_passive_images
           inputspec.registrationImageTypes : passed to get_moving_images and get_passive_images
           inputspec.interpolationMapping : passed to get_moving_images,
                flatten_transform_and_images_list and renest_deformed_passive_images
           inputspec.fixed_image : fixed image of BeginANTS and reference image of wimtdeformed

    Outputs::

           outputspec.template : output_image of ReshapeAverageImageWithShapeUpdate
           outputspec.transforms_list : declared, but nothing in this workflow is connected to it
           outputspec.passive_deformed_templates : output_image of
                ReshapeAveragePassiveImageWithShapeUpdate

    :param iterationPhasePrefix: value converted with ``str()`` and used in the
        workflow name, the transform prefix and generated file names.
    :param CLUSTER_QUEUE: not referenced in this function body; kept for
        interface compatibility with callers.
    :param CLUSTER_QUEUE_LONG: not referenced in this function body; kept for
        interface compatibility with callers.
    :return: the configured ``pe.Workflow`` instance.
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name="antsRegistrationTemplateBuildSingleIterationWF_"
        + str(iterationPhasePrefix)
    )
    connect = TemplateBuildSingleIterationWF.connect

    inputSpec = pe.Node(
        interface=util.IdentityInterface(
            fields=[
                "ListOfImagesDictionaries",
                "registrationImageTypes",
                # 'maskRegistrationImageType',
                "interpolationMapping",
                "fixed_image",
            ]
        ),
        run_without_submitting=True,
        name="inputspec",
    )
    ## HACK: INFO: We need to have the AVG_AIR.nii.gz be warped with a default voxel value of 1.0
    ## HACK: INFO: Need to move all local functions to a common utility file, or at the top of the file so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: INFO: REMOVE 'transforms_list' it is not used.  That will change all the hashes
    ## HACK: INFO: Need to run all python files through the code beautifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(
        interface=util.IdentityInterface(
            fields=["template", "transforms_list", "passive_deformed_templates"]
        ),
        run_without_submitting=True,
        name="outputspec",
    )

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    BeginANTS = pe.MapNode(
        interface=Registration(), name="BeginANTS", iterfield=["moving_image"]
    )
    # SEE template.py many_cpu_BeginANTS_options_dictionary = {'qsub_args': modify_qsub_args(CLUSTER_QUEUE,4,2,8), 'overwrite': True}
    ## This is set in the template.py file BeginANTS.plugin_args = BeginANTS_cpu_sge_options_dictionary
    common_ants_registration_settings(
        antsRegistrationNode=BeginANTS,
        registrationTypeDescription="SixStageAntsRegistrationT1Only",
        output_transform_prefix=str(iterationPhasePrefix) + "_tfm",
        output_warped_image="atlas2subject.nii.gz",
        output_inverse_warped_image="subject2atlas.nii.gz",
        save_state="SavedantsRegistrationNodeSyNState.h5",
        invert_initial_moving_transform=False,
        initial_moving_transform=None,
    )

    # Select the moving images (and their interpolation types) used for registration.
    GetMovingImagesNode = pe.Node(
        interface=util.Function(
            function=get_moving_images,
            input_names=[
                "ListOfImagesDictionaries",
                "registrationImageTypes",
                "interpolationMapping",
            ],
            output_names=["moving_images", "moving_interpolation_type"],
        ),
        run_without_submitting=True,
        name="99_GetMovingImagesNode",
    )
    connect(
        inputSpec,
        "ListOfImagesDictionaries",
        GetMovingImagesNode,
        "ListOfImagesDictionaries",
    )
    connect(
        inputSpec,
        "registrationImageTypes",
        GetMovingImagesNode,
        "registrationImageTypes",
    )
    connect(
        inputSpec, "interpolationMapping", GetMovingImagesNode, "interpolationMapping"
    )

    connect(GetMovingImagesNode, "moving_images", BeginANTS, "moving_image")
    connect(
        GetMovingImagesNode, "moving_interpolation_type", BeginANTS, "interpolation"
    )
    connect(inputSpec, "fixed_image", BeginANTS, "fixed_image")

    ## Now warp all the input_images images
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=["transforms", "input_image"],
        # iterfield=['transforms', 'invert_transform_flags', 'input_image'],
        name="wimtdeformed",
    )
    wimtdeformed.inputs.interpolation = "Linear"
    # default_value is an interface input, so it has to be set through
    # `.inputs`; assigning it on the node object itself only creates an
    # attribute that the interface never reads.
    wimtdeformed.inputs.default_value = 0
    # HACK: Should try using forward_composite_transform
    ##PREVIOUS connect(BeginANTS, 'forward_transform', wimtdeformed, 'transforms')
    connect(BeginANTS, "composite_transform", wimtdeformed, "transforms")
    ##PREVIOUS connect(BeginANTS, 'forward_invert_flags', wimtdeformed, 'invert_transform_flags')
    ## NOTE: forward_invert_flags:: List of flags corresponding to the forward transforms
    # wimtdeformed.inputs.invert_transform_flags = [False,False,False,False,False]
    connect(GetMovingImagesNode, "moving_images", wimtdeformed, "input_image")
    connect(inputSpec, "fixed_image", wimtdeformed, "reference_image")

    ##  Shape Update Next =====
    ## Now  Average All input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(), name="AvgDeformedImages")
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = (
        str(iterationPhasePrefix) + ".nii.gz"
    )
    AvgDeformedImages.inputs.normalize = True
    connect(wimtdeformed, "output_image", AvgDeformedImages, "images")

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(
        interface=AverageAffineTransform(), name="AvgAffineTransform"
    )
    AvgAffineTransform.inputs.dimension = 3
    AvgAffineTransform.inputs.output_affine_transform = (
        "Avererage_" + str(iterationPhasePrefix) + "_Affine.h5"
    )

    # Split each composite transform from BeginANTS into its affine and warp components.
    SplitCompositeTransform = pe.MapNode(
        interface=util.Function(
            function=split_composite_to_component_transform,
            input_names=["transformFilename"],
            output_names=["affine_component_list", "warp_component_list"],
        ),
        iterfield=["transformFilename"],
        run_without_submitting=True,
        name="99_SplitCompositeTransform",
    )
    connect(
        BeginANTS, "composite_transform", SplitCompositeTransform, "transformFilename"
    )
    ## PREVIOUS connect(BeginANTS, 'forward_transforms', SplitCompositeTransform, 'transformFilename')
    connect(
        SplitCompositeTransform,
        "affine_component_list",
        AvgAffineTransform,
        "transforms",
    )

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name="AvgWarpImages")
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = (
        str(iterationPhasePrefix) + "warp.nii.gz"
    )
    AvgWarpImages.inputs.normalize = True
    connect(SplitCompositeTransform, "warp_component_list", AvgWarpImages, "images")

    ## Scale the averaged warp by the negated gradient step.
    ## INFO:  For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(
        interface=MultiplyImages(), name="GradientStepWarpImage"
    )
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    GradientStepWarpImage.inputs.output_product_image = (
        "GradientStep0.25_" + str(iterationPhasePrefix) + "_warp.nii.gz"
    )
    connect(
        AvgWarpImages, "output_average_image", GradientStepWarpImage, "first_input"
    )

    ## Now create the new template shape based on the average of all deformed images
    UpdateTemplateShape = pe.Node(
        interface=ApplyTransforms(), name="UpdateTemplateShape"
    )
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = "Linear"
    UpdateTemplateShape.inputs.default_value = 0

    connect(
        AvgDeformedImages,
        "output_average_image",
        UpdateTemplateShape,
        "reference_image",
    )
    # The single averaged affine is passed through make_list_of_one_element
    # before it reaches the 'transforms' input.
    connect(
        [
            (
                AvgAffineTransform,
                UpdateTemplateShape,
                [(("affine_transform", make_list_of_one_element), "transforms")],
            )
        ]
    )
    connect(
        GradientStepWarpImage,
        "output_product_image",
        UpdateTemplateShape,
        "input_image",
    )

    # Build the transform list (average affine + gradient-step warp) shared by
    # the template reshape and the passive-image reshape below.
    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=make_transform_list_with_gradient_warps,
            input_names=["averageAffineTranform", "gradientStepWarp"],
            output_names=["TransformListWithGradientWarps"],
        ),
        run_without_submitting=True,
        name="99_MakeTransformListWithGradientWarps",
    )
    # ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    connect(
        AvgAffineTransform,
        "affine_transform",
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        "averageAffineTranform",
    )
    connect(
        UpdateTemplateShape,
        "output_image",
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        "gradientStepWarp",
    )

    # The averaged image is both the input and the reference of the reshape.
    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name="ReshapeAverageImageWithShapeUpdate"
    )
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True,
        False,
        False,
        False,
        False,
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = "Linear"
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = (
        "ReshapeAverageImageWithShapeUpdate.nii.gz"
    )
    connect(
        AvgDeformedImages,
        "output_average_image",
        ReshapeAverageImageWithShapeUpdate,
        "input_image",
    )
    connect(
        AvgDeformedImages,
        "output_average_image",
        ReshapeAverageImageWithShapeUpdate,
        "reference_image",
    )
    connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        "TransformListWithGradientWarps",
        ReshapeAverageImageWithShapeUpdate,
        "transforms",
    )
    connect(ReshapeAverageImageWithShapeUpdate, "output_image", outputSpec, "template")

    ##############################################
    ## Process all the passive deformed images in a way similar to the main
    ## image used for registration.
    ## Now warp all the ListOfPassiveImagesDictionaries images
    FlattenTransformAndImagesListNode = pe.Node(
        Function(
            function=flatten_transform_and_images_list,
            input_names=[
                "ListOfPassiveImagesDictionaries",
                "transforms",
                "interpolationMapping",
                "invert_transform_flags",
            ],
            output_names=[
                "flattened_images",
                "flattened_transforms",
                "flattened_invert_transform_flags",
                "flattened_image_nametypes",
                "flattened_interpolation_type",
            ],
        ),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList",
    )

    GetPassiveImagesNode = pe.Node(
        interface=util.Function(
            function=get_passive_images,
            input_names=["ListOfImagesDictionaries", "registrationImageTypes"],
            output_names=["ListOfPassiveImagesDictionaries"],
        ),
        run_without_submitting=True,
        name="99_GetPassiveImagesNode",
    )
    connect(
        inputSpec,
        "ListOfImagesDictionaries",
        GetPassiveImagesNode,
        "ListOfImagesDictionaries",
    )
    connect(
        inputSpec,
        "registrationImageTypes",
        GetPassiveImagesNode,
        "registrationImageTypes",
    )

    connect(
        GetPassiveImagesNode,
        "ListOfPassiveImagesDictionaries",
        FlattenTransformAndImagesListNode,
        "ListOfPassiveImagesDictionaries",
    )
    connect(
        inputSpec,
        "interpolationMapping",
        FlattenTransformAndImagesListNode,
        "interpolationMapping",
    )
    connect(
        BeginANTS,
        "composite_transform",
        FlattenTransformAndImagesListNode,
        "transforms",
    )
    # NOTE(review): the 'invert_transform_flags' input of the flatten node is
    # declared above but never set or connected here -- confirm that
    # flatten_transform_and_images_list tolerates it being left unset.
    ## FlattenTransformAndImagesListNode.inputs.invert_transform_flags = [False,False,False,False,False,False]
    ## INFO: Please check if invert_transform_flags has a fixed number.
    ## PREVIOUS connect(BeginANTS, 'forward_invert_flags', FlattenTransformAndImagesListNode, 'invert_transform_flags')
    wimtPassivedeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            "transforms",
            "invert_transform_flags",
            "input_image",
            "interpolation",
        ],
        name="wimtPassivedeformed",
    )
    wimtPassivedeformed.inputs.default_value = 0
    connect(
        AvgDeformedImages,
        "output_average_image",
        wimtPassivedeformed,
        "reference_image",
    )
    connect(
        FlattenTransformAndImagesListNode,
        "flattened_interpolation_type",
        wimtPassivedeformed,
        "interpolation",
    )
    connect(
        FlattenTransformAndImagesListNode,
        "flattened_images",
        wimtPassivedeformed,
        "input_image",
    )
    connect(
        FlattenTransformAndImagesListNode,
        "flattened_transforms",
        wimtPassivedeformed,
        "transforms",
    )
    connect(
        FlattenTransformAndImagesListNode,
        "flattened_invert_transform_flags",
        wimtPassivedeformed,
        "invert_transform_flags",
    )

    RenestDeformedPassiveImagesNode = pe.Node(
        Function(
            function=renest_deformed_passive_images,
            input_names=[
                "deformedPassiveImages",
                "flattened_image_nametypes",
                "interpolationMapping",
            ],
            output_names=[
                "nested_imagetype_list",
                "outputAverageImageName_list",
                "image_type_list",
                "nested_interpolation_type",
            ],
        ),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages",
    )
    connect(
        inputSpec,
        "interpolationMapping",
        RenestDeformedPassiveImagesNode,
        "interpolationMapping",
    )
    connect(
        wimtPassivedeformed,
        "output_image",
        RenestDeformedPassiveImagesNode,
        "deformedPassiveImages",
    )
    connect(
        FlattenTransformAndImagesListNode,
        "flattened_image_nametypes",
        RenestDeformedPassiveImagesNode,
        "flattened_image_nametypes",
    )
    ## Now  Average All passive input_images deformed images together to create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=["images", "output_average_image"],
        name="AvgDeformedPassiveImages",
    )
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    connect(
        RenestDeformedPassiveImagesNode,
        "nested_imagetype_list",
        AvgDeformedPassiveImages,
        "images",
    )
    connect(
        RenestDeformedPassiveImagesNode,
        "outputAverageImageName_list",
        AvgDeformedPassiveImages,
        "output_average_image",
    )

    ## -- INFO:  Now need to reshape all the passive images as well
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=["input_image", "reference_image", "output_image", "interpolation"],
        name="ReshapeAveragePassiveImageWithShapeUpdate",
    )
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True,
        False,
        False,
        False,
        False,
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0
    connect(
        RenestDeformedPassiveImagesNode,
        "nested_interpolation_type",
        ReshapeAveragePassiveImageWithShapeUpdate,
        "interpolation",
    )
    connect(
        RenestDeformedPassiveImagesNode,
        "outputAverageImageName_list",
        ReshapeAveragePassiveImageWithShapeUpdate,
        "output_image",
    )
    connect(
        AvgDeformedPassiveImages,
        "output_average_image",
        ReshapeAveragePassiveImageWithShapeUpdate,
        "input_image",
    )
    connect(
        AvgDeformedPassiveImages,
        "output_average_image",
        ReshapeAveragePassiveImageWithShapeUpdate,
        "reference_image",
    )
    connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        "TransformListWithGradientWarps",
        ReshapeAveragePassiveImageWithShapeUpdate,
        "transforms",
    )
    connect(
        ReshapeAveragePassiveImageWithShapeUpdate,
        "output_image",
        outputSpec,
        "passive_deformed_templates",
    )

    return TemplateBuildSingleIterationWF
# Example no. 8
# 0
def init_atropos_wf(
        name="atropos_wf",
        use_random_seed=True,
        omp_nthreads=None,
        mem_gb=3.0,
        padding=10,
        in_segmentation_model=tuple(ATROPOS_MODELS["T1w"].values()),
):
    """
    Build the ANTs ATROPOS tissue-segmentation workflow used to refine a brain mask.

    Corresponds to supersteps 6 and 7 of ``antsBrainExtraction.sh``: the mask
    obtained earlier from the spatial normalization to the template is refined
    with the help of a K-means tissue segmentation.

    Workflow Graph
        .. workflow::
            :graph2use: orig
            :simple_form: yes

            from niworkflows.anat.ants import init_atropos_wf
            wf = init_atropos_wf()

    Parameters
    ----------
    name : str, optional
        Name given to the workflow (default: "atropos_wf").
    use_random_seed : bool
        Forwarded to ATROPOS; whether it should seed itself from the
        system's clock.
    omp_nthreads : int
        Upper bound on the threads a single process may use (set as
        ``n_procs`` of the ATROPOS node).
    mem_gb : float
        Estimated peak memory of the hungriest node (set as ``mem_gb`` of
        the ATROPOS node).
    padding : int
        Width of the zero padding applied before processing; the same
        amount is removed again from the outputs.
    in_segmentation_model : tuple
        Describes the K-means segmentation that is run to find gray or white
        matter around the edge of the initial brain mask warped from the
        template. The segmentation has :math:`$K$` classes ordered by
        increasing mean intensity; this argument sets :math:`$K$` and says
        which classes are CSF, gray and white matter.
        Format (K, csfLabel, gmLabel, wmLabel).
        Examples:
        ``(3,1,2,3)`` for T1 with K=3, CSF=1, GM=2, WM=3 (default),
        ``(3,3,2,1)`` for T2 with K=3, CSF=3, GM=2, WM=1,
        ``(3,1,3,2)`` for FLAIR with K=3, CSF=1 GM=3, WM=2,
        ``(4,4,2,3)`` uses K=4, CSF=4, GM=2, WM=3.

    Inputs
    ------
    in_files : list
        :abbr:`INU (intensity non-uniformity)`-corrected files.
    in_mask : str
        Brain mask calculated previously.
    in_mask_dilated : str
        Mask handed to ATROPOS as its ``mask_image``.

    Outputs
    -------
    out_mask : str
        Refined brain mask
    out_segm : str
        Output segmentation
    out_tpms : list
        Output :abbr:`TPMs (tissue probability maps)`, merged in the order
        CSF, GM, WM

    """
    workflow = pe.Workflow(name)

    # Unpack the pieces of the (K, csfLabel, gmLabel, wmLabel) model used below.
    n_classes = in_segmentation_model[0]
    gm_label = in_segmentation_model[-2]
    wm_label = in_segmentation_model[-1]

    # PadImage arguments: positive to pad, negated to strip the padding again.
    pad_arg = "%d" % padding
    depad_arg = "-%d" % padding

    def _imath(node_name, **inputs):
        """Return a node wrapping an ``ImageMath`` configured with *inputs*."""
        return pe.Node(ImageMath(**inputs), name=node_name)

    # ------------------------------------------------------------------
    # I/O nodes
    inputnode = pe.Node(
        niu.IdentityInterface(
            fields=["in_files", "in_mask", "in_mask_dilated"]),
        name="inputnode",
    )
    outputnode = pe.Node(
        niu.IdentityInterface(fields=["out_mask", "out_segm", "out_tpms"]),
        name="outputnode",
    )
    copy_xform = pe.Node(
        CopyXForm(fields=["out_mask", "out_segm", "out_tpms"]),
        run_without_submitting=True,
        name="copy_xform",
    )

    # ------------------------------------------------------------------
    # Core node: ATROPOS
    atropos_iface = Atropos(
        dimension=3,
        initialization="KMeans",
        number_of_tissue_classes=n_classes,
        n_iterations=3,
        convergence_threshold=0.0,
        mrf_radius=[1, 1, 1],
        mrf_smoothing_factor=0.1,
        likelihood_model="Gaussian",
        use_random_seed=use_random_seed,
    )
    atropos = pe.Node(
        atropos_iface,
        name="01_atropos",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )

    # ------------------------------------------------------------------
    # Superstep 6: pad, split into per-tissue masks, clean up, relabel
    pad_segm = _imath("02_pad_segm", operation="PadImage", op2=pad_arg)
    pad_mask = _imath("03_pad_mask", operation="PadImage", op2=pad_arg)

    # Labels are handed over reversed (wm, gm, csf) to match the output names.
    sel_labels = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=["out_wm", "out_gm", "out_csf"]),
        name="04_sel_labels",
    )
    sel_labels.inputs.labels = list(reversed(in_segmentation_model[1:]))

    # ImageMath ${DIMENSION} ${EXTRACTION_WM} GetLargestComponent ${EXTRACTION_WM}
    get_wm = _imath("05_get_wm", operation="GetLargestComponent")
    get_gm = _imath("06_get_gm", operation="GetLargestComponent")

    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} FillHoles ${EXTRACTION_GM} 2
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${EXTRACTION_TMP} ${EXTRACTION_GM}
    fill_gm = _imath("07_fill_gm", operation="FillHoles", op2="2")
    mult_gm = pe.Node(
        MultiplyImages(dimension=3, output_product_image="08_mult_gm.nii.gz"),
        name="08_mult_gm",
    )

    # MultiplyImages ${DIMENSION} ${EXTRACTION_WM} ${ATROPOS_WM_CLASS_LABEL} ${EXTRACTION_WM}
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} ME ${EXTRACTION_CSF} 10
    relabel_wm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=wm_label,
            output_product_image="09_relabel_wm.nii.gz",
        ),
        name="09_relabel_wm",
    )
    me_csf = _imath("10_me_csf", operation="ME", op2="10")

    # ImageMath ${DIMENSION} ${EXTRACTION_GM} addtozero ${EXTRACTION_GM} ${EXTRACTION_TMP}
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${ATROPOS_GM_CLASS_LABEL} ${EXTRACTION_GM}
    # ImageMath ${DIMENSION} ${EXTRACTION_SEGMENTATION} addtozero ${EXTRACTION_WM} ${EXTRACTION_GM}
    add_gm = _imath("11_add_gm", operation="addtozero")
    relabel_gm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=gm_label,
            output_product_image="12_relabel_gm.nii.gz",
        ),
        name="12_relabel_gm",
    )
    add_gm_wm = _imath("13_add_gm_wm", operation="addtozero")

    # ------------------------------------------------------------------
    # Superstep 7: rebuild the mask from the GM/WM labels
    sel_labels2 = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=["out_gm", "out_wm"]),
        name="14_sel_labels2",
    )
    sel_labels2.inputs.labels = in_segmentation_model[2:]

    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} ${EXTRACTION_TMP}
    add_7 = _imath("15_add_7", operation="addtozero")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 2
    me_7 = _imath("16_me_7", operation="ME", op2="2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} GetLargestComponent ${EXTRACTION_MASK}
    comp_7 = _imath("17_comp_7", operation="GetLargestComponent")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 4
    md_7 = _imath("18_md_7", operation="MD", op2="4")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} FillHoles ${EXTRACTION_MASK} 2
    fill_7 = _imath("19_fill_7", operation="FillHoles", op2="2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} \
    # ${EXTRACTION_MASK_PRIOR_WARPED}
    add_7_2 = _imath("20_add_7_2", operation="addtozero")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 5
    md_7_2 = _imath("21_md_7_2", operation="MD", op2="5")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 5
    me_7_2 = _imath("22_me_7_2", operation="ME", op2="5")

    # ------------------------------------------------------------------
    # Strip the padding from everything that is exported
    depad_mask = _imath("23_depad_mask", operation="PadImage", op2=depad_arg)
    depad_segm = _imath("24_depad_segm", operation="PadImage", op2=depad_arg)
    depad_gm = _imath("25_depad_gm", operation="PadImage", op2=depad_arg)
    depad_wm = _imath("26_depad_wm", operation="PadImage", op2=depad_arg)
    depad_csf = _imath("27_depad_csf", operation="PadImage", op2=depad_arg)

    msk_conform = pe.Node(niu.Function(function=_conform_mask),
                          name="msk_conform")
    merge_tpms = pe.Node(niu.Merge(n_classes), name="merge_tpms")

    # ------------------------------------------------------------------
    # Wiring.  The list is assembled stage by stage but still handed to a
    # single ``connect`` call; every entry owns its own (fresh) list.
    # fmt: off
    links = [
        # workflow inputs
        (inputnode, copy_xform, [(("in_files", _pop), "hdr_file")]),
        (inputnode, pad_mask, [("in_mask", "op1")]),
        (inputnode, atropos, [
            ("in_files", "intensity_images"),
            ("in_mask_dilated", "mask_image"),
        ]),
        (inputnode, msk_conform, [(("in_files", _pop), "in_reference")]),
    ]
    links += [
        # superstep 6
        (atropos, pad_segm, [("classified_image", "op1")]),
        (pad_segm, sel_labels, [("output_image", "in_segm")]),
        (sel_labels, get_wm, [("out_wm", "op1")]),
        (sel_labels, get_gm, [("out_gm", "op1")]),
        (get_gm, fill_gm, [("output_image", "op1")]),
        (get_gm, mult_gm, [("output_image", "first_input")]),
        (fill_gm, mult_gm, [("output_image", "second_input")]),
        (get_wm, relabel_wm, [("output_image", "first_input")]),
        (sel_labels, me_csf, [("out_csf", "op1")]),
        (mult_gm, add_gm, [("output_product_image", "op1")]),
        (me_csf, add_gm, [("output_image", "op2")]),
        (add_gm, relabel_gm, [("output_image", "first_input")]),
        (relabel_wm, add_gm_wm, [("output_product_image", "op1")]),
        (relabel_gm, add_gm_wm, [("output_product_image", "op2")]),
    ]
    links += [
        # superstep 7
        (add_gm_wm, sel_labels2, [("output_image", "in_segm")]),
        (sel_labels2, add_7, [("out_wm", "op1"), ("out_gm", "op2")]),
        (add_7, me_7, [("output_image", "op1")]),
        (me_7, comp_7, [("output_image", "op1")]),
        (comp_7, md_7, [("output_image", "op1")]),
        (md_7, fill_7, [("output_image", "op1")]),
        (fill_7, add_7_2, [("output_image", "op1")]),
        (pad_mask, add_7_2, [("output_image", "op2")]),
        (add_7_2, md_7_2, [("output_image", "op1")]),
        (md_7_2, me_7_2, [("output_image", "op1")]),
    ]
    links += [
        # de-padding, TPM merge and export
        (me_7_2, depad_mask, [("output_image", "op1")]),
        (add_gm_wm, depad_segm, [("output_image", "op1")]),
        (relabel_wm, depad_wm, [("output_product_image", "op1")]),
        (relabel_gm, depad_gm, [("output_product_image", "op1")]),
        (sel_labels, depad_csf, [("out_csf", "op1")]),
        (depad_csf, merge_tpms, [("output_image", "in1")]),
        (depad_gm, merge_tpms, [("output_image", "in2")]),
        (depad_wm, merge_tpms, [("output_image", "in3")]),
        (depad_mask, msk_conform, [("output_image", "in_mask")]),
        (msk_conform, copy_xform, [("out", "out_mask")]),
        (depad_segm, copy_xform, [("output_image", "out_segm")]),
        (merge_tpms, copy_xform, [("out", "out_tpms")]),
        (copy_xform, outputnode, [
            ("out_mask", "out_mask"),
            ("out_segm", "out_segm"),
            ("out_tpms", "out_tpms"),
        ]),
    ]
    # fmt: on
    workflow.connect(links)
    return workflow
def ANTs_cortical_thickness(subject_list, directory):
    """Assemble the ``antsthickness`` nipype workflow and run it.

    The workflow iterates over ``subject_list`` and, for every subject id:
    selects the T1w image, aligns it to the OASIS template with
    ``own_nipype.ants_QuickSyN`` (``transform_type='r'``), runs
    ``antsCorticalThickness``, renders a tiled QC mosaic of the thickness
    map, computes a GM density image (``own_nipype.GM_DENSITY``) that is fed
    to ``ApplyTransforms``, and multiplies the template-normalised thickness
    with the log-Jacobian image.

    Parameters
    ----------
    subject_list : list
        Values the ``subject_id`` field iterates over.
    directory : str
        Made absolute and used both as ``SelectFiles`` base directory and as
        the workflow's ``base_dir``.

    Returns
    -------
    None
        The function writes the workflow graph and executes the workflow
        with the ``'PBSGraph'`` plugin.

    Notes
    -----
    All atlas/template locations are hard-coded absolute paths.
    NOTE(review): the T1 template is itself an absolute path, so
    ``base_directory`` presumably has no effect on it -- confirm.
    """
    # ---------------------------------------------------------------
    # Dependencies (function-local)
    import os

    import nipype.interfaces.io as nio  # not referenced below; import kept
    import nipype.interfaces.utility as util
    import nipype.pipeline.engine as pe
    from nipype import SelectFiles
    from nipype.interfaces.ants import ApplyTransforms
    from nipype.interfaces.ants import MultiplyImages
    from nipype.interfaces.ants.segmentation import antsCorticalThickness
    from nipype.interfaces.ants.visualization import ConvertScalarImageToRGB
    from nipype.interfaces.ants.visualization import CreateTiledMosaic
    from nipype.interfaces.utility import Function
    from nipype.interfaces.utility import Select

    import own_nipype
    from own_nipype import GM_DENSITY

    # Atlas files that are referenced more than once.
    oasis_template = '/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template/T_template0.nii.gz'
    oasis_brain_cerebellum = '/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template/T_template0_BrainCerebellum.nii.gz'

    # ---------------------------------------------------------------
    # Node definitions

    # Subject iterator
    infosource = pe.Node(
        interface=util.IdentityInterface(fields=['subject_id']),
        name='infosource')
    infosource.iterables = ('subject_id', subject_list)

    # Locate each subject's T1-weighted image
    templates = {
        'T1':
        '/imaging/jb07/CALM/CALM_BIDS/{subject_id}/anat/{subject_id}_T1w.nii.gz'
    }
    selectfiles = pe.Node(SelectFiles(templates), name="selectfiles")
    selectfiles.inputs.base_directory = os.path.abspath(directory)

    # Rigid alignment with the template space
    rigid_iface = own_nipype.ants_QuickSyN(image_dimensions=3,
                                           transform_type='r')
    T1_rigid_quickSyN = pe.Node(interface=rigid_iface,
                                name='T1_rigid_quickSyN')
    T1_rigid_quickSyN.inputs.fixed_image = oasis_template

    # Cortical thickness estimation
    corticalthickness = pe.Node(interface=antsCorticalThickness(),
                                name='corticalthickness')
    ct_inputs = corticalthickness.inputs
    ct_inputs.brain_probability_mask = '/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template/T_template0_BrainCerebellumProbabilityMask.nii.gz'
    ct_inputs.brain_template = oasis_template
    ct_inputs.segmentation_priors = [
        '/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template/Priors2/priors1.nii.gz',
        '/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template/Priors2/priors2.nii.gz',
        '/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template/Priors2/priors3.nii.gz',
        '/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template/Priors2/priors4.nii.gz',
        '/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template/Priors2/priors5.nii.gz',
        '/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template/Priors2/priors6.nii.gz'
    ]
    ct_inputs.extraction_registration_mask = '/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template/T_template0_BrainCerebellumExtractionMask.nii.gz'
    ct_inputs.t1_registration_template = oasis_brain_cerebellum

    # Quality-control visualisation: colour-map the thickness image ...
    converter = pe.Node(
        interface=ConvertScalarImageToRGB(
            dimension=3,
            colormap='cool',
            minimum_input=0,
            maximum_input=5,
        ),
        name='converter')

    # ... and tile it over the anatomical image
    mosaic_slicer = pe.Node(
        interface=CreateTiledMosaic(
            pad_or_crop='mask',
            slices='[4 ,mask , mask]',
            direction=1,
            alpha_value=0.5,
        ),
        name='mosaic_slicer')

    # GM density image; ``sl`` picks element 1 of the posteriors list as mask
    gm_density = pe.Node(interface=GM_DENSITY(), name='gm_density')
    sl = pe.Node(interface=Select(index=1), name='sl')

    # Transformation of the GM density image
    at = pe.Node(
        interface=ApplyTransforms(
            dimension=3,
            reference_image=oasis_brain_cerebellum,
            interpolation='Linear',
            default_value=0,
            invert_transform_flags=False,
        ),
        name='at')

    # Product of the normalised thickness image and the Jacobian image
    multiply_images = pe.Node(interface=MultiplyImages(dimension=3),
                              name='multiply_images')

    # File name for the product image.  The function below is left exactly
    # as it was: nipype's Function interface captures its source text.
    def generate_filename(subject_id):
        return subject_id + '_multiplied.nii.gz'

    filename_node = pe.Node(
        interface=Function(input_names=["subject_id"],
                           output_names=["out_filename"],
                           function=generate_filename),
        name='generate_filename')

    # ---------------------------------------------------------------
    # Workflow wiring
    antsthickness = pe.Workflow(name='antsthickness')
    antsthickness.connect([
        (infosource, selectfiles, [('subject_id', 'subject_id')]),
        (selectfiles, T1_rigid_quickSyN, [('T1', 'moving_image')]),
        (infosource, T1_rigid_quickSyN, [('subject_id', 'output_prefix')]),
        (T1_rigid_quickSyN, corticalthickness,
         [('warped_image', 'anatomical_image')]),
        (infosource, corticalthickness, [('subject_id', 'out_prefix')]),
        (corticalthickness, converter,
         [('CorticalThickness', 'input_image')]),
        (converter, mosaic_slicer, [('output_image', 'rgb_image')]),
        (corticalthickness, mosaic_slicer, [
            ('BrainSegmentationN4', 'input_image'),
            ('BrainExtractionMask', 'mask_image'),
        ]),
        (corticalthickness, gm_density,
         [('BrainSegmentationN4', 'in_file')]),
        (corticalthickness, sl, [('BrainSegmentationPosteriors', 'inlist')]),
        (sl, gm_density, [('out', 'mask_file')]),
        (corticalthickness, at, [('SubjectToTemplate1Warp', 'transforms')]),
        (gm_density, at, [('out_file', 'input_image')]),
        (corticalthickness, multiply_images, [
            ('SubjectToTemplateLogJacobian', 'second_input'),
            ('CorticalThicknessNormedToTemplate', 'first_input'),
        ]),
        (infosource, filename_node, [('subject_id', 'subject_id')]),
        (filename_node, multiply_images,
         [('out_filename', 'output_product_image')]),
    ])

    # ---------------------------------------------------------------
    # Execution
    antsthickness.base_dir = os.path.abspath(directory)
    antsthickness.write_graph()
    antsthickness.run('PBSGraph')