Exemplo n.º 1
0
def average_images(img_list, idx, tname):
    """

    Parameters
    ----------
    img_list: list of strings
    idx: numpy array
    tname: string

    Returns
    -------

    """

    # Init subsampled image list
    sample_imgs = []

    # Save image list to sidecar .txt file
    tname_txt = tname.replace('nii.gz','txt')
    with open(tname_txt, 'w') as fd:
        for ii in idx:
            sample_imgs.append(img_list[ii])
            fd.write(img_list[ii] + '\n')

    # Call ANTs AverageImages command using nipype
    avg = AverageImages()
    avg.inputs.dimension = 3
    avg.inputs.output_average_image = tname
    avg.inputs.normalize = False
    avg.inputs.images = sample_imgs
    avg.run()

    return img_list
Exemplo n.º 2
0
def average_images(img_list, idx, tname):
    """

    Parameters
    ----------
    img_list: list of strings
    idx: numpy array
    tname: string

    Returns
    -------

    """

    # Init subsampled image list
    sample_imgs = []

    # Save image list to sidecar .txt file
    tname_txt = tname.replace('nii.gz', 'txt')
    with open(tname_txt, 'w') as fd:
        for ii in idx:
            sample_imgs.append(img_list[ii])
            fd.write(img_list[ii] + '\n')

    # Call ANTs AverageImages command using nipype
    avg = AverageImages()
    avg.inputs.dimension = 3
    avg.inputs.output_average_image = tname
    avg.inputs.normalize = False
    avg.inputs.images = sample_imgs
    avg.run()

    return img_list
def antsRegistrationTemplateBuildSingleIterationWF(iterationPhasePrefix='',
                                                  gradient_step=0.25):
    """Build one iteration of the ANTs template-building workflow.

    Each moving image is registered (Affine + SyN) to ``fixed_image`` and
    warped into its space; the warped images are averaged.  The affine and
    warp components of the registrations are averaged separately, the
    average warp is scaled by ``-gradient_step``, and the resulting
    transform list is applied to the average image to produce the updated
    template.  Passive (non-registration) images are pushed through the
    same per-subject transforms, regrouped by
    ``RenestDeformedPassiveImages``, averaged, and reshaped the same way.

    Parameters
    ----------
    iterationPhasePrefix : str
        Used in the workflow name and in output/transform file names.
    gradient_step : float, optional
        Scale of the shape-update step: the averaged warp field is multiplied
        by ``-gradient_step``.  Defaults to 0.25 (the value previously hard
        coded), which also keeps the ``GradientStep0.25_...`` file name.

    Inputs::

           inputspec.ListOfImagesDictionaries :
           inputspec.registrationImageTypes :
           inputspec.interpolationMapping :
           inputspec.fixed_image :

    Outputs::

           outputspec.template :
           outputspec.transforms_list : (declared but never connected)
           outputspec.passive_deformed_templates :

    Returns
    -------
    pe.Workflow
        The assembled (not yet run) workflow.
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name='antsRegistrationTemplateBuildSingleIterationWF_' +
        str(iterationPhasePrefix))

    inputSpec = pe.Node(interface=util.IdentityInterface(fields=[
        'ListOfImagesDictionaries', 'registrationImageTypes',
        'interpolationMapping', 'fixed_image'
    ]),
                        run_without_submitting=True,
                        name='inputspec')
    ## HACK: TODO: Need to move all local functions to a common utility file, or at the top of the file so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: TODO: REMOVE 'transforms_list' it is not used.  That will change all the hashes
    ## HACK: TODO: Need to run all python files through the code beautifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(interface=util.IdentityInterface(
        fields=['template', 'transforms_list', 'passive_deformed_templates']),
                         run_without_submitting=True,
                         name='outputspec')

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    BeginANTS = pe.MapNode(interface=Registration(),
                           name='BeginANTS',
                           iterfield=['moving_image'])
    BeginANTS.inputs.dimension = 3
    BeginANTS.inputs.output_transform_prefix = str(
        iterationPhasePrefix) + '_tfm'
    BeginANTS.inputs.transforms = ["Affine", "SyN"]
    BeginANTS.inputs.transform_parameters = [[0.9], [0.25, 3.0, 0.0]]
    BeginANTS.inputs.metric = ['Mattes', 'CC']
    BeginANTS.inputs.metric_weight = [1.0, 1.0]
    BeginANTS.inputs.radius_or_number_of_bins = [32, 5]
    BeginANTS.inputs.number_of_iterations = [[1000, 1000, 1000], [50, 35, 15]]
    BeginANTS.inputs.use_histogram_matching = [True, True]
    BeginANTS.inputs.use_estimate_learning_rate_once = [False, False]
    BeginANTS.inputs.shrink_factors = [[3, 2, 1], [3, 2, 1]]
    BeginANTS.inputs.smoothing_sigmas = [[3, 2, 0], [3, 2, 0]]

    GetMovingImagesNode = pe.Node(interface=util.Function(
        function=GetMovingImages,
        input_names=[
            'ListOfImagesDictionaries', 'registrationImageTypes',
            'interpolationMapping'
        ],
        output_names=['moving_images', 'moving_interpolation_type']),
                                  run_without_submitting=True,
                                  name='99_GetMovingImagesNode')
    TemplateBuildSingleIterationWF.connect(inputSpec,
                                           'ListOfImagesDictionaries',
                                           GetMovingImagesNode,
                                           'ListOfImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'registrationImageTypes',
                                           GetMovingImagesNode,
                                           'registrationImageTypes')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           GetMovingImagesNode,
                                           'interpolationMapping')

    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_images', BeginANTS,
                                           'moving_image')
    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_interpolation_type',
                                           BeginANTS, 'interpolation')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image', BeginANTS,
                                           'fixed_image')

    ## Now warp all the input_images images
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=['transforms', 'invert_transform_flags', 'input_image'],
        name='wimtdeformed')
    wimtdeformed.inputs.interpolation = 'Linear'
    # Interface inputs must be set through ``.inputs``; a plain attribute on
    # the node object is silently ignored by nipype.
    wimtdeformed.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_transforms',
                                           wimtdeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_invert_flags',
                                           wimtdeformed,
                                           'invert_transform_flags')
    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_images', wimtdeformed,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image',
                                           wimtdeformed, 'reference_image')

    ##  Shape Update Next =====
    ## Now  Average All input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(),
                                name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(
        iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(wimtdeformed, "output_image",
                                           AvgDeformedImages, 'images')

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(interface=AverageAffineTransform(),
                                 name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    # 'Avererage_' (sic) is kept so existing output names / cache hashes stay valid.
    AvgAffineTransform.inputs.output_affine_transform = 'Avererage_' + str(
        iterationPhasePrefix) + '_Affine.mat'

    SplitAffineAndWarpsNode = pe.Node(interface=util.Function(
        function=SplitAffineAndWarpComponents,
        input_names=['list_of_transforms_lists'],
        output_names=['affine_component_list', 'warp_component_list']),
                                      run_without_submitting=True,
                                      name='99_SplitAffineAndWarpsNode')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_transforms',
                                           SplitAffineAndWarpsNode,
                                           'list_of_transforms_lists')
    TemplateBuildSingleIterationWF.connect(SplitAffineAndWarpsNode,
                                           'affine_component_list',
                                           AvgAffineTransform, 'transforms')

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(
        iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(SplitAffineAndWarpsNode,
                                           'warp_component_list',
                                           AvgWarpImages, 'images')

    ## Scale the average warp by -gradient_step (the shape-update step)
    GradientStepWarpImage = pe.Node(interface=MultiplyImages(),
                                    name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * gradient_step
    # With the default 0.25 this yields 'GradientStep0.25_...' as before.
    GradientStepWarpImage.inputs.output_product_image = (
        'GradientStep' + str(gradient_step) + '_' +
        str(iterationPhasePrefix) + '_warp.nii.gz')
    TemplateBuildSingleIterationWF.connect(AvgWarpImages,
                                           'output_average_image',
                                           GradientStepWarpImage,
                                           'first_input')

    ## Now create the new template shape based on the average of all deformed images
    # Applies the inverted average affine (a one-element transform list) to
    # the scaled average warp image.
    UpdateTemplateShape = pe.Node(interface=ApplyTransforms(),
                                  name='UpdateTemplateShape')
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = 'Linear'
    UpdateTemplateShape.inputs.default_value = 0

    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           UpdateTemplateShape,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect([
        (AvgAffineTransform, UpdateTemplateShape,
         [(('affine_transform', makeListOfOneElement), 'transforms')]),
    ])
    TemplateBuildSingleIterationWF.connect(GradientStepWarpImage,
                                           'output_product_image',
                                           UpdateTemplateShape, 'input_image')

    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='99_MakeTransformListWithGradientWarps')
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    TemplateBuildSingleIterationWF.connect(
        AvgAffineTransform, 'affine_transform',
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'averageAffineTranform')
    TemplateBuildSingleIterationWF.connect(
        UpdateTemplateShape, 'output_image',
        ApplyInvAverageAndFourTimesGradientStepWarpImage, 'gradientStepWarp')

    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name='ReshapeAverageImageWithShapeUpdate')
    # NOTE(review): five flags (only the first inverted) assume
    # MakeTransformListWithGradientWarps returns five transforms with the
    # average affine first -- confirm against that helper.
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = 'Linear'
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = 'ReshapeAverageImageWithShapeUpdate.nii.gz'
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps', ReshapeAverageImageWithShapeUpdate,
        'transforms')
    TemplateBuildSingleIterationWF.connect(ReshapeAverageImageWithShapeUpdate,
                                           'output_image', outputSpec,
                                           'template')

    ######
    ######
    ######  Process all the passive deformed images in a way similar to the main image used for registration
    ######
    ######
    ######
    ##############################################
    ## Now warp all the ListOfPassiveImagesDictionaries images
    FlattenTransformAndImagesListNode = pe.Node(
        Function(function=FlattenTransformAndImagesList,
                 input_names=[
                     'ListOfPassiveImagesDictionaries', 'transforms',
                     'invert_transform_flags', 'interpolationMapping'
                 ],
                 output_names=[
                     'flattened_images', 'flattened_transforms',
                     'flattened_invert_transform_flags',
                     'flattened_image_nametypes',
                     'flattened_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")

    GetPassiveImagesNode = pe.Node(interface=util.Function(
        function=GetPassiveImages,
        input_names=['ListOfImagesDictionaries', 'registrationImageTypes'],
        output_names=['ListOfPassiveImagesDictionaries']),
                                   run_without_submitting=True,
                                   name='99_GetPassiveImagesNode')
    TemplateBuildSingleIterationWF.connect(inputSpec,
                                           'ListOfImagesDictionaries',
                                           GetPassiveImagesNode,
                                           'ListOfImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'registrationImageTypes',
                                           GetPassiveImagesNode,
                                           'registrationImageTypes')

    TemplateBuildSingleIterationWF.connect(GetPassiveImagesNode,
                                           'ListOfPassiveImagesDictionaries',
                                           FlattenTransformAndImagesListNode,
                                           'ListOfPassiveImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           FlattenTransformAndImagesListNode,
                                           'interpolationMapping')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_transforms',
                                           FlattenTransformAndImagesListNode,
                                           'transforms')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_invert_flags',
                                           FlattenTransformAndImagesListNode,
                                           'invert_transform_flags')
    wimtPassivedeformed = pe.MapNode(interface=ApplyTransforms(),
                                     iterfield=[
                                         'transforms',
                                         'invert_transform_flags',
                                         'input_image', 'interpolation'
                                     ],
                                     name='wimtPassivedeformed')
    wimtPassivedeformed.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           wimtPassivedeformed,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_interpolation_type',
                                           wimtPassivedeformed,
                                           'interpolation')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_images',
                                           wimtPassivedeformed, 'input_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_transforms',
                                           wimtPassivedeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_invert_transform_flags',
                                           wimtPassivedeformed,
                                           'invert_transform_flags')

    RenestDeformedPassiveImagesNode = pe.Node(
        Function(function=RenestDeformedPassiveImages,
                 input_names=[
                     'deformedPassiveImages', 'flattened_image_nametypes',
                     'interpolationMapping'
                 ],
                 output_names=[
                     'nested_imagetype_list', 'outputAverageImageName_list',
                     'image_type_list', 'nested_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           RenestDeformedPassiveImagesNode,
                                           'interpolationMapping')
    TemplateBuildSingleIterationWF.connect(wimtPassivedeformed, 'output_image',
                                           RenestDeformedPassiveImagesNode,
                                           'deformedPassiveImages')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_image_nametypes',
                                           RenestDeformedPassiveImagesNode,
                                           'flattened_image_nametypes')
    ## Now  Average All passive input_images deformed images together to create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "nested_imagetype_list",
                                           AvgDeformedPassiveImages, 'images')
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "outputAverageImageName_list",
                                           AvgDeformedPassiveImages,
                                           'output_average_image')

    ## Reshape each passive average with the same transform list used for the template
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'input_image', 'reference_image', 'output_image', 'interpolation'
        ],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, 'nested_interpolation_type',
        ReshapeAveragePassiveImageWithShapeUpdate, 'interpolation')
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, 'outputAverageImageName_list',
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps',
        ReshapeAveragePassiveImageWithShapeUpdate, 'transforms')
    TemplateBuildSingleIterationWF.connect(
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image', outputSpec,
        'passive_deformed_templates')

    return TemplateBuildSingleIterationWF
Exemplo n.º 4
0
def BAWantsRegistrationTemplateBuildSingleIterationWF(iterationPhasePrefix,
                                                      CLUSTER_QUEUE,
                                                      CLUSTER_QUEUE_LONG):
    """

    Inputs::

           inputspec.images :
           inputspec.fixed_image :
           inputspec.ListOfPassiveImagesDictionaries :
           inputspec.interpolationMapping :

    Outputs::

           outputspec.template :
           outputspec.transforms_list :
           outputspec.passive_deformed_templates :
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name='antsRegistrationTemplateBuildSingleIterationWF_' +
        str(iterationPhasePrefix))

    inputSpec = pe.Node(
        interface=util.IdentityInterface(fields=[
            'ListOfImagesDictionaries',
            'registrationImageTypes',
            # 'maskRegistrationImageType',
            'interpolationMapping',
            'fixed_image'
        ]),
        run_without_submitting=True,
        name='inputspec')
    ## HACK: TODO: We need to have the AVG_AIR.nii.gz be warped with a default voxel value of 1.0
    ## HACK: TODO: Need to move all local functions to a common untility file, or at the top of the file so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: TODO: REMOVE 'transforms_list' it is not used.  That will change all the hashes
    ## HACK: TODO: Need to run all python files through the code beutifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(interface=util.IdentityInterface(
        fields=['template', 'transforms_list', 'passive_deformed_templates']),
                         run_without_submitting=True,
                         name='outputspec')

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    BeginANTS = pe.MapNode(interface=Registration(),
                           name='BeginANTS',
                           iterfield=['moving_image'])
    # SEE template.py many_cpu_BeginANTS_options_dictionary = {'qsub_args': modify_qsub_args(CLUSTER_QUEUE,4,2,8), 'overwrite': True}
    ## This is set in the template.py file BeginANTS.plugin_args = BeginANTS_cpu_sge_options_dictionary
    CommonANTsRegistrationSettings(
        antsRegistrationNode=BeginANTS,
        registrationTypeDescription="SixStageAntsRegistrationT1Only",
        output_transform_prefix=str(iterationPhasePrefix) + '_tfm',
        output_warped_image='atlas2subject.nii.gz',
        output_inverse_warped_image='subject2atlas.nii.gz',
        save_state='SavedantsRegistrationNodeSyNState.h5',
        invert_initial_moving_transform=False,
        initial_moving_transform=None)

    GetMovingImagesNode = pe.Node(interface=util.Function(
        function=GetMovingImages,
        input_names=[
            'ListOfImagesDictionaries', 'registrationImageTypes',
            'interpolationMapping'
        ],
        output_names=['moving_images', 'moving_interpolation_type']),
                                  run_without_submitting=True,
                                  name='99_GetMovingImagesNode')
    TemplateBuildSingleIterationWF.connect(inputSpec,
                                           'ListOfImagesDictionaries',
                                           GetMovingImagesNode,
                                           'ListOfImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'registrationImageTypes',
                                           GetMovingImagesNode,
                                           'registrationImageTypes')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           GetMovingImagesNode,
                                           'interpolationMapping')

    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_images', BeginANTS,
                                           'moving_image')
    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_interpolation_type',
                                           BeginANTS, 'interpolation')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image', BeginANTS,
                                           'fixed_image')

    ## Now warp all the input_images images
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=['transforms', 'input_image'],
        # iterfield=['transforms', 'invert_transform_flags', 'input_image'],
        name='wimtdeformed')
    wimtdeformed.inputs.interpolation = 'Linear'
    wimtdeformed.default_value = 0
    # HACK: Should try using forward_composite_transform
    ##PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_transform', wimtdeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           wimtdeformed, 'transforms')
    ##PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_invert_flags', wimtdeformed, 'invert_transform_flags')
    ## NOTE: forward_invert_flags:: List of flags corresponding to the forward transforms
    # wimtdeformed.inputs.invert_transform_flags = [False,False,False,False,False]
    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_images', wimtdeformed,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image',
                                           wimtdeformed, 'reference_image')

    ##  Shape Update Next =====
    ## Now  Average All input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(),
                                name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(
        iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(wimtdeformed, "output_image",
                                           AvgDeformedImages, 'images')

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(interface=AverageAffineTransform(),
                                 name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    AvgAffineTransform.inputs.output_affine_transform = 'Avererage_' + str(
        iterationPhasePrefix) + '_Affine.h5'

    SplitCompositeTransform = pe.MapNode(interface=util.Function(
        function=SplitCompositeToComponentTransforms,
        input_names=['transformFilename'],
        output_names=['affine_component_list', 'warp_component_list']),
                                         iterfield=['transformFilename'],
                                         run_without_submitting=True,
                                         name='99_SplitCompositeTransform')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           SplitCompositeTransform,
                                           'transformFilename')
    ## PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_transforms', SplitCompositeTransform, 'transformFilename')
    TemplateBuildSingleIterationWF.connect(SplitCompositeTransform,
                                           'affine_component_list',
                                           AvgAffineTransform, 'transforms')

    ## Now average the warp fields togther
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(
        iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(SplitCompositeTransform,
                                           'warp_component_list',
                                           AvgWarpImages, 'images')

    ## Now average the images together
    ## TODO:  For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(interface=MultiplyImages(),
                                    name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    GradientStepWarpImage.inputs.output_product_image = 'GradientStep0.25_' + str(
        iterationPhasePrefix) + '_warp.nii.gz'
    TemplateBuildSingleIterationWF.connect(AvgWarpImages,
                                           'output_average_image',
                                           GradientStepWarpImage,
                                           'first_input')

    ## Now create the new template shape based on the average of all deformed images
    UpdateTemplateShape = pe.Node(interface=ApplyTransforms(),
                                  name='UpdateTemplateShape')
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = 'Linear'
    UpdateTemplateShape.default_value = 0

    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           UpdateTemplateShape,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect([
        (AvgAffineTransform, UpdateTemplateShape,
         [(('affine_transform', makeListOfOneElement), 'transforms')]),
    ])
    TemplateBuildSingleIterationWF.connect(GradientStepWarpImage,
                                           'output_product_image',
                                           UpdateTemplateShape, 'input_image')

    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='99_MakeTransformListWithGradientWarps')
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    TemplateBuildSingleIterationWF.connect(
        AvgAffineTransform, 'affine_transform',
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'averageAffineTranform')
    TemplateBuildSingleIterationWF.connect(
        UpdateTemplateShape, 'output_image',
        ApplyInvAverageAndFourTimesGradientStepWarpImage, 'gradientStepWarp')

    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name='ReshapeAverageImageWithShapeUpdate')
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = 'Linear'
    ReshapeAverageImageWithShapeUpdate.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = 'ReshapeAverageImageWithShapeUpdate.nii.gz'
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps', ReshapeAverageImageWithShapeUpdate,
        'transforms')
    TemplateBuildSingleIterationWF.connect(ReshapeAverageImageWithShapeUpdate,
                                           'output_image', outputSpec,
                                           'template')

    ######
    ######
    ######  Process all the passive deformed images in a way similar to the main image used for registration
    ######
    ######
    ######
    ##############################################
    ## Now warp all the ListOfPassiveImagesDictionaries images
    FlattenTransformAndImagesListNode = pe.Node(
        Function(function=FlattenTransformAndImagesList,
                 input_names=[
                     'ListOfPassiveImagesDictionaries', 'transforms',
                     'interpolationMapping', 'invert_transform_flags'
                 ],
                 output_names=[
                     'flattened_images', 'flattened_transforms',
                     'flattened_invert_transform_flags',
                     'flattened_image_nametypes',
                     'flattened_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")

    GetPassiveImagesNode = pe.Node(interface=util.Function(
        function=GetPassiveImages,
        input_names=['ListOfImagesDictionaries', 'registrationImageTypes'],
        output_names=['ListOfPassiveImagesDictionaries']),
                                   run_without_submitting=True,
                                   name='99_GetPassiveImagesNode')
    TemplateBuildSingleIterationWF.connect(inputSpec,
                                           'ListOfImagesDictionaries',
                                           GetPassiveImagesNode,
                                           'ListOfImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'registrationImageTypes',
                                           GetPassiveImagesNode,
                                           'registrationImageTypes')

    TemplateBuildSingleIterationWF.connect(GetPassiveImagesNode,
                                           'ListOfPassiveImagesDictionaries',
                                           FlattenTransformAndImagesListNode,
                                           'ListOfPassiveImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           FlattenTransformAndImagesListNode,
                                           'interpolationMapping')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           FlattenTransformAndImagesListNode,
                                           'transforms')
    ## FlattenTransformAndImagesListNode.inputs.invert_transform_flags = [False,False,False,False,False,False]
    ## TODO: Please check of invert_transform_flags has a fixed number.
    ## PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_invert_flags', FlattenTransformAndImagesListNode, 'invert_transform_flags')
    wimtPassivedeformed = pe.MapNode(interface=ApplyTransforms(),
                                     iterfield=[
                                         'transforms',
                                         'invert_transform_flags',
                                         'input_image', 'interpolation'
                                     ],
                                     name='wimtPassivedeformed')
    wimtPassivedeformed.default_value = 0
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           wimtPassivedeformed,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_interpolation_type',
                                           wimtPassivedeformed,
                                           'interpolation')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_images',
                                           wimtPassivedeformed, 'input_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_transforms',
                                           wimtPassivedeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_invert_transform_flags',
                                           wimtPassivedeformed,
                                           'invert_transform_flags')

    RenestDeformedPassiveImagesNode = pe.Node(
        Function(function=RenestDeformedPassiveImages,
                 input_names=[
                     'deformedPassiveImages', 'flattened_image_nametypes',
                     'interpolationMapping'
                 ],
                 output_names=[
                     'nested_imagetype_list', 'outputAverageImageName_list',
                     'image_type_list', 'nested_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           RenestDeformedPassiveImagesNode,
                                           'interpolationMapping')
    TemplateBuildSingleIterationWF.connect(wimtPassivedeformed, 'output_image',
                                           RenestDeformedPassiveImagesNode,
                                           'deformedPassiveImages')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_image_nametypes',
                                           RenestDeformedPassiveImagesNode,
                                           'flattened_image_nametypes')
    ## Now  Average All passive input_images deformed images together to create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "nested_imagetype_list",
                                           AvgDeformedPassiveImages, 'images')
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "outputAverageImageName_list",
                                           AvgDeformedPassiveImages,
                                           'output_average_image')

    ## -- TODO:  Now neeed to reshape all the passive images as well
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'input_image', 'reference_image', 'output_image', 'interpolation'
        ],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.default_value = 0
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, 'nested_interpolation_type',
        ReshapeAveragePassiveImageWithShapeUpdate, 'interpolation')
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, 'outputAverageImageName_list',
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps',
        ReshapeAveragePassiveImageWithShapeUpdate, 'transforms')
    TemplateBuildSingleIterationWF.connect(
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image', outputSpec,
        'passive_deformed_templates')

    return TemplateBuildSingleIterationWF
Exemplo n.º 5
0
def BAWantsRegistrationTemplateBuildSingleIterationWF(iterationPhasePrefix='',
                                                      gradientStep=0.25):
    """Build one iteration of the antsRegistration-based template workflow.

    Every moving image is registered (``Registration`` MapNode) to
    ``inputspec.fixed_image``. The warped images are averaged into a new
    average image. The affine and warp components of the composite transforms
    are averaged separately. The average warp is scaled by ``-gradientStep``
    and pushed through the inverted average affine. The resulting transform
    list (built by ``MakeTransformListWithGradientWarps``; presumably the
    inverted average affine followed by repeated gradient warps, judging by
    the node name and the five invert flags -- TODO confirm) reshapes the
    average image into the new template. Passive (non-registration) images
    are warped with the same per-subject transforms, averaged per image type,
    and reshaped the same way.

    Parameters
    ----------
    iterationPhasePrefix : str
        Appended to the workflow name and used in several output filenames.
    gradientStep : float, optional
        Fraction of the average warp used in the shape update. It is applied
        as ``-gradientStep`` and also embedded in the scaled-warp filename.
        Defaults to 0.25, the previously hard-coded value.

    Inputs::

           inputspec.ListOfImagesDictionaries :
           inputspec.registrationImageTypes :
           inputspec.interpolationMapping :
           inputspec.fixed_image :

    Outputs::

           outputspec.template :
           outputspec.transforms_list : (declared, never connected here)
           outputspec.passive_deformed_templates :

    Returns
    -------
    pe.Workflow
        The assembled (not yet run) single-iteration workflow.
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name='antsRegistrationTemplateBuildSingleIterationWF_' +
        str(iterationPhasePrefix))

    inputSpec = pe.Node(
        interface=util.IdentityInterface(fields=[
            'ListOfImagesDictionaries',
            'registrationImageTypes',
            #'maskRegistrationImageType',
            'interpolationMapping',
            'fixed_image'
        ]),
        run_without_submitting=True,
        name='inputspec')
    ## HACK: TODO: We need to have the AVG_AIR.nii.gz be warped with a default voxel value of 1.0
    ## HACK: TODO: Need to move all local functions to a common utility file, or at the top of the file so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: TODO: REMOVE 'transforms_list' it is not used.  That will change all the hashes
    ## HACK: TODO: Need to run all python files through the code beautifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(interface=util.IdentityInterface(
        fields=['template', 'transforms_list', 'passive_deformed_templates']),
                         run_without_submitting=True,
                         name='outputspec')

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    BeginANTS = pe.MapNode(interface=Registration(),
                           name='BeginANTS',
                           iterfield=['moving_image'])
    BeginANTS.inputs.dimension = 3
    # This is the recommended set of parameters from the ANTS developers:
    # Rigid -> Affine -> three SyN stages (coarse to fine).
    BeginANTS.inputs.output_transform_prefix = str(
        iterationPhasePrefix) + '_tfm'
    BeginANTS.inputs.transforms = ["Rigid", "Affine", "SyN", "SyN", "SyN"]
    BeginANTS.inputs.transform_parameters = [[0.1], [0.1], [0.1, 3.0, 0.0],
                                             [0.1, 3.0, 0.0], [0.1, 3.0, 0.0]]
    BeginANTS.inputs.metric = ['MI', 'MI', 'CC', 'CC', 'CC']
    BeginANTS.inputs.sampling_strategy = [
        'Regular', 'Regular', None, None, None
    ]
    BeginANTS.inputs.sampling_percentage = [0.27, 0.27, 1.0, 1.0, 1.0]
    BeginANTS.inputs.metric_weight = [1.0, 1.0, 1.0, 1.0, 1.0]
    BeginANTS.inputs.radius_or_number_of_bins = [32, 32, 4, 4, 4]
    BeginANTS.inputs.number_of_iterations = [[1000, 1000, 1000, 1000],
                                             [1000, 1000, 1000, 1000],
                                             [1000, 250], [140], [25]]
    BeginANTS.inputs.convergence_threshold = [5e-8, 5e-8, 5e-7, 5e-6, 5e-5]
    BeginANTS.inputs.convergence_window_size = [10, 10, 10, 10, 10]
    BeginANTS.inputs.use_histogram_matching = [True, True, True, True, True]
    BeginANTS.inputs.shrink_factors = [[8, 4, 2, 1], [8, 4, 2, 1], [8, 4], [2],
                                       [1]]
    BeginANTS.inputs.smoothing_sigmas = [[3, 2, 1, 0], [3, 2, 1, 0], [3, 2],
                                         [1], [0]]
    BeginANTS.inputs.sigma_units = ["vox", "vox", "vox", "vox", "vox"]
    BeginANTS.inputs.use_estimate_learning_rate_once = [
        False, False, False, False, False
    ]
    BeginANTS.inputs.write_composite_transform = True
    BeginANTS.inputs.collapse_output_transforms = False
    BeginANTS.inputs.initialize_transforms_per_stage = True
    BeginANTS.inputs.winsorize_lower_quantile = 0.01
    BeginANTS.inputs.winsorize_upper_quantile = 0.99
    BeginANTS.inputs.output_warped_image = 'atlas2subject.nii.gz'
    BeginANTS.inputs.output_inverse_warped_image = 'subject2atlas.nii.gz'
    BeginANTS.inputs.save_state = 'SavedBeginANTSSyNState.h5'
    BeginANTS.inputs.float = True

    GetMovingImagesNode = pe.Node(interface=util.Function(
        function=GetMovingImages,
        input_names=[
            'ListOfImagesDictionaries', 'registrationImageTypes',
            'interpolationMapping'
        ],
        output_names=['moving_images', 'moving_interpolation_type']),
                                  run_without_submitting=True,
                                  name='99_GetMovingImagesNode')
    TemplateBuildSingleIterationWF.connect(inputSpec,
                                           'ListOfImagesDictionaries',
                                           GetMovingImagesNode,
                                           'ListOfImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'registrationImageTypes',
                                           GetMovingImagesNode,
                                           'registrationImageTypes')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           GetMovingImagesNode,
                                           'interpolationMapping')

    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_images', BeginANTS,
                                           'moving_image')
    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_interpolation_type',
                                           BeginANTS, 'interpolation')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image', BeginANTS,
                                           'fixed_image')

    ## Now warp all the input_images images
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=['transforms', 'input_image'],
        #iterfield=['transforms', 'invert_transform_flags', 'input_image'],
        name='wimtdeformed')
    wimtdeformed.inputs.interpolation = 'Linear'
    # Must be set on .inputs; assigning wimtdeformed.default_value only adds
    # an unused attribute to the node object and never reaches ApplyTransforms.
    wimtdeformed.inputs.default_value = 0
    # HACK: Should try using forward_composite_transform
    ##PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_transform', wimtdeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           wimtdeformed, 'transforms')
    ##PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_invert_flags', wimtdeformed, 'invert_transform_flags')
    ## NOTE: forward_invert_flags:: List of flags corresponding to the forward transforms
    #wimtdeformed.inputs.invert_transform_flags = [False,False,False,False,False]
    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_images', wimtdeformed,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image',
                                           wimtdeformed, 'reference_image')

    ##  Shape Update Next =====
    ## Now average all input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(),
                                name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(
        iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(wimtdeformed, "output_image",
                                           AvgDeformedImages, 'images')

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(interface=AverageAffineTransform(),
                                 name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    # NOTE(review): 'Avererage' looks like a typo; it is kept because renaming
    # it would change the output filename of every existing run.
    AvgAffineTransform.inputs.output_affine_transform = 'Avererage_' + str(
        iterationPhasePrefix) + '_Affine.h5'

    SplitCompositeTransform = pe.MapNode(
        interface=util.Function(
            function=SplitCompositeToComponentTransforms,
            input_names=['composite_transform_as_list'],
            output_names=['affine_component_list', 'warp_component_list']),
        iterfield=['composite_transform_as_list'],
        run_without_submitting=True,
        name='99_SplitCompositeTransform')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           SplitCompositeTransform,
                                           'composite_transform_as_list')
    ## PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_transforms', SplitCompositeTransform, 'composite_transform_as_list')
    TemplateBuildSingleIterationWF.connect(SplitCompositeTransform,
                                           'affine_component_list',
                                           AvgAffineTransform, 'transforms')

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(
        iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(SplitCompositeTransform,
                                           'warp_component_list',
                                           AvgWarpImages, 'images')

    ## Now scale the average warp by the negated gradient step
    GradientStep = gradientStep
    GradientStepWarpImage = pe.Node(interface=MultiplyImages(),
                                    name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # Derive the filename from the step actually used so it cannot drift from
    # the value; str(0.25) == '0.25' keeps the default filename unchanged.
    GradientStepWarpImage.inputs.output_product_image = (
        'GradientStep' + str(GradientStep) + '_' + str(iterationPhasePrefix) +
        '_warp.nii.gz')
    TemplateBuildSingleIterationWF.connect(AvgWarpImages,
                                           'output_average_image',
                                           GradientStepWarpImage,
                                           'first_input')

    ## Now create the new template shape based on the average of all deformed images
    UpdateTemplateShape = pe.Node(interface=ApplyTransforms(),
                                  name='UpdateTemplateShape')
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = 'Linear'
    UpdateTemplateShape.inputs.default_value = 0

    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           UpdateTemplateShape,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect([
        (AvgAffineTransform, UpdateTemplateShape,
         [(('affine_transform', makeListOfOneElement), 'transforms')]),
    ])
    TemplateBuildSingleIterationWF.connect(GradientStepWarpImage,
                                           'output_product_image',
                                           UpdateTemplateShape, 'input_image')

    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='99_MakeTransformListWithGradientWarps')
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    TemplateBuildSingleIterationWF.connect(
        AvgAffineTransform, 'affine_transform',
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'averageAffineTranform')
    TemplateBuildSingleIterationWF.connect(
        UpdateTemplateShape, 'output_image',
        ApplyInvAverageAndFourTimesGradientStepWarpImage, 'gradientStepWarp')

    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name='ReshapeAverageImageWithShapeUpdate')
    # Five flags: the first (inverted) transform is presumably the average
    # affine; the rest come from MakeTransformListWithGradientWarps -- TODO
    # confirm the list length matches.
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = 'Linear'
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = 'ReshapeAverageImageWithShapeUpdate.nii.gz'
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps', ReshapeAverageImageWithShapeUpdate,
        'transforms')
    TemplateBuildSingleIterationWF.connect(ReshapeAverageImageWithShapeUpdate,
                                           'output_image', outputSpec,
                                           'template')

    ######
    ######
    ######  Process all the passive deformed images in a way similar to the main image used for registration
    ######
    ######
    ######
    ##############################################
    ## Now warp all the ListOfPassiveImagesDictionaries images
    FlattenTransformAndImagesListNode = pe.Node(
        Function(function=FlattenTransformAndImagesList,
                 input_names=[
                     'ListOfPassiveImagesDictionaries', 'transforms',
                     'interpolationMapping', 'invert_transform_flags'
                 ],
                 output_names=[
                     'flattened_images', 'flattened_transforms',
                     'flattened_invert_transform_flags',
                     'flattened_image_nametypes',
                     'flattened_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")

    GetPassiveImagesNode = pe.Node(interface=util.Function(
        function=GetPassiveImages,
        input_names=['ListOfImagesDictionaries', 'registrationImageTypes'],
        output_names=['ListOfPassiveImagesDictionaries']),
                                   run_without_submitting=True,
                                   name='99_GetPassiveImagesNode')
    TemplateBuildSingleIterationWF.connect(inputSpec,
                                           'ListOfImagesDictionaries',
                                           GetPassiveImagesNode,
                                           'ListOfImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'registrationImageTypes',
                                           GetPassiveImagesNode,
                                           'registrationImageTypes')

    TemplateBuildSingleIterationWF.connect(GetPassiveImagesNode,
                                           'ListOfPassiveImagesDictionaries',
                                           FlattenTransformAndImagesListNode,
                                           'ListOfPassiveImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           FlattenTransformAndImagesListNode,
                                           'interpolationMapping')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           FlattenTransformAndImagesListNode,
                                           'transforms')
    ## NOTE(review): the 'invert_transform_flags' input of
    ## FlattenTransformAndImagesListNode is neither set nor connected here.
    ## FlattenTransformAndImagesListNode.inputs.invert_transform_flags = [False,False,False,False,False,False]
    ## TODO: Please check if invert_transform_flags has a fixed number.
    ## PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_invert_flags', FlattenTransformAndImagesListNode, 'invert_transform_flags')
    wimtPassivedeformed = pe.MapNode(interface=ApplyTransforms(),
                                     iterfield=[
                                         'transforms',
                                         'invert_transform_flags',
                                         'input_image', 'interpolation'
                                     ],
                                     name='wimtPassivedeformed')
    wimtPassivedeformed.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           wimtPassivedeformed,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_interpolation_type',
                                           wimtPassivedeformed,
                                           'interpolation')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_images',
                                           wimtPassivedeformed, 'input_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_transforms',
                                           wimtPassivedeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_invert_transform_flags',
                                           wimtPassivedeformed,
                                           'invert_transform_flags')

    RenestDeformedPassiveImagesNode = pe.Node(
        Function(function=RenestDeformedPassiveImages,
                 input_names=[
                     'deformedPassiveImages', 'flattened_image_nametypes',
                     'interpolationMapping'
                 ],
                 output_names=[
                     'nested_imagetype_list', 'outputAverageImageName_list',
                     'image_type_list', 'nested_interpolation_type'
                 ]),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           RenestDeformedPassiveImagesNode,
                                           'interpolationMapping')
    TemplateBuildSingleIterationWF.connect(wimtPassivedeformed, 'output_image',
                                           RenestDeformedPassiveImagesNode,
                                           'deformedPassiveImages')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_image_nametypes',
                                           RenestDeformedPassiveImagesNode,
                                           'flattened_image_nametypes')
    ## Now average all passive input_images deformed images together to create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "nested_imagetype_list",
                                           AvgDeformedPassiveImages, 'images')
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "outputAverageImageName_list",
                                           AvgDeformedPassiveImages,
                                           'output_average_image')

    ## -- Reshape all the passive images with the same shape update as the template
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'input_image', 'reference_image', 'output_image', 'interpolation'
        ],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, 'nested_interpolation_type',
        ReshapeAveragePassiveImageWithShapeUpdate, 'interpolation')
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, 'outputAverageImageName_list',
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps',
        ReshapeAveragePassiveImageWithShapeUpdate, 'transforms')
    TemplateBuildSingleIterationWF.connect(
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image', outputSpec,
        'passive_deformed_templates')

    return TemplateBuildSingleIterationWF
Exemplo n.º 6
0
def ANTSTemplateBuildSingleIterationWF(iterationPhasePrefix=''):
    """Build one iteration of the legacy ``ANTS``-based template-building workflow.

    Every image in ``inputspec.images`` is registered (SyN, CC metric) to
    ``inputspec.fixed_image``.  The warped images, the per-image affine
    transforms and the per-image warp fields are each averaged.  The average
    warp field is multiplied by ``-GradientStep`` (hard-coded 0.25) and
    resampled through the inverse of the average affine to form a shape
    update; the averaged image is then reshaped with a transform list built
    from the average affine and that shape update, giving the new template.

    Passive images (``inputspec.ListOfPassiveImagesDictionaries``) are warped
    with the same per-image transforms, regrouped by
    ``RenestDeformedPassiveImages`` (presumably per image type -- confirm),
    averaged per group and reshaped with the same transform list.

    Parameters
    ----------
    iterationPhasePrefix : object
        Iteration label; converted with ``str()`` and used in the workflow
        name and in several intermediate output file names.

    Inputs::

           inputspec.images : moving images registered to the fixed image
           inputspec.fixed_image : registration target (current template)
           inputspec.ListOfPassiveImagesDictionaries : passive images to carry
               along with the same transforms (format defined by
               ``FlattenTransformAndImagesList``)

    Outputs::

           outputspec.template : reshaped average image (updated template)
           outputspec.transforms_list : declared but never connected here
           outputspec.passive_deformed_templates : reshaped averages of the
               passive images

    Returns
    -------
    pe.Workflow
        The assembled (not yet run) nipype workflow.
    """

    # NOTE(review): the nested str() is redundant but harmless; left as-is so
    # the workflow name is unchanged.
    TemplateBuildSingleIterationWF = pe.Workflow(
        name='ANTSTemplateBuildSingleIterationWF_' +
        str(str(iterationPhasePrefix)))

    inputSpec = pe.Node(interface=util.IdentityInterface(
        fields=['images', 'fixed_image', 'ListOfPassiveImagesDictionaries']),
                        run_without_submitting=True,
                        name='inputspec')
    ## HACK: TODO: Need to move all local functions to a common utility file, or at the top of the file so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: TODO: REMOVE 'transforms_list' it is not used.  That will change all the hashes
    ## HACK: TODO: Need to run all python files through the code beautifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(interface=util.IdentityInterface(
        fields=['template', 'transforms_list', 'passive_deformed_templates']),
                         run_without_submitting=True,
                         name='outputspec')

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    # One ANTS registration per moving image; all other settings are shared.
    BeginANTS = pe.MapNode(interface=ANTS(),
                           name='BeginANTS',
                           iterfield=['moving_image'])
    BeginANTS.inputs.dimension = 3
    BeginANTS.inputs.output_transform_prefix = str(
        iterationPhasePrefix) + '_tfm'
    # Single cross-correlation metric, weight 1.0, radius 5.
    BeginANTS.inputs.metric = ['CC']
    BeginANTS.inputs.metric_weight = [1.0]
    BeginANTS.inputs.radius = [5]
    BeginANTS.inputs.transformation_model = 'SyN'
    BeginANTS.inputs.gradient_step_length = 0.25
    # Three deformable levels (50, 35, 15 iterations) and five affine levels.
    BeginANTS.inputs.number_of_iterations = [50, 35, 15]
    BeginANTS.inputs.number_of_affine_iterations = [
        10000, 10000, 10000, 10000, 10000
    ]
    BeginANTS.inputs.use_histogram_matching = True
    BeginANTS.inputs.mi_option = [32, 16000]
    # Gaussian regularization: sigma 3 on the gradient field, 0 on the deformation field.
    BeginANTS.inputs.regularization = 'Gauss'
    BeginANTS.inputs.regularization_gradient_field_sigma = 3
    BeginANTS.inputs.regularization_deformation_field_sigma = 0
    TemplateBuildSingleIterationWF.connect(inputSpec, 'images', BeginANTS,
                                           'moving_image')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image', BeginANTS,
                                           'fixed_image')

    # Pair each image's warp and affine transforms into one transform list per image.
    MakeTransformsLists = pe.Node(interface=util.Function(
        function=MakeListsOfTransformLists,
        input_names=['warpTransformList', 'AffineTransformList'],
        output_names=['out']),
                                  run_without_submitting=True,
                                  name='MakeTransformsLists')
    # NOTE(review): ignore_exception is a legacy nipype input; newer nipype
    # releases may ignore it -- verify against the installed version.
    MakeTransformsLists.inputs.ignore_exception = True
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'warp_transform',
                                           MakeTransformsLists,
                                           'warpTransformList')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'affine_transform',
                                           MakeTransformsLists,
                                           'AffineTransformList')

    ## Now warp all the input_images images
    # NOTE(review): no reference_image is connected here, so WarpImageMultiTransform
    # falls back to its own default reference behavior -- confirm this is intended.
    wimtdeformed = pe.MapNode(
        interface=WarpImageMultiTransform(),
        iterfield=['transformation_series', 'input_image'],
        name='wimtdeformed')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'images', wimtdeformed,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(MakeTransformsLists, 'out',
                                           wimtdeformed,
                                           'transformation_series')

    ##  Shape Update Next =====
    ## Now  Average All input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(),
                                name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(
        iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(wimtdeformed, "output_image",
                                           AvgDeformedImages, 'images')

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(interface=AverageAffineTransform(),
                                 name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    # The misspelled 'Avererage_' prefix is part of the output file name; it is
    # kept so existing file names do not change.
    AvgAffineTransform.inputs.output_affine_transform = 'Avererage_' + str(
        iterationPhasePrefix) + '_Affine.mat'
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'affine_transform',
                                           AvgAffineTransform, 'transforms')

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(
        iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'warp_transform',
                                           AvgWarpImages, 'images')

    ## Now scale the average warp field by the (negative) gradient step
    ## TODO:  For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(interface=MultiplyImages(),
                                    name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    # Multiply the average warp by -GradientStep (i.e. -0.25).
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # NOTE: the '0.25' in this file name is hard-coded and does not follow GradientStep.
    GradientStepWarpImage.inputs.output_product_image = 'GradientStep0.25_' + str(
        iterationPhasePrefix) + '_warp.nii.gz'
    TemplateBuildSingleIterationWF.connect(AvgWarpImages,
                                           'output_average_image',
                                           GradientStepWarpImage,
                                           'first_input')

    ## Now create the new template shape based on the average of all deformed images
    # Resample the scaled warp field through the inverted average affine
    # (invert_affine=[1]), on the grid of the averaged deformed image.
    UpdateTemplateShape = pe.Node(interface=WarpImageMultiTransform(),
                                  name='UpdateTemplateShape')
    UpdateTemplateShape.inputs.invert_affine = [1]
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           UpdateTemplateShape,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(AvgAffineTransform,
                                           'affine_transform',
                                           UpdateTemplateShape,
                                           'transformation_series')
    TemplateBuildSingleIterationWF.connect(GradientStepWarpImage,
                                           'output_product_image',
                                           UpdateTemplateShape, 'input_image')

    # Combine the average affine and the shape-update warp into one transform list.
    # The exact list layout is defined by MakeTransformListWithGradientWarps
    # (presumably the affine followed by repeated gradient warps -- confirm).
    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='MakeTransformListWithGradientWarps')
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    TemplateBuildSingleIterationWF.connect(
        AvgAffineTransform, 'affine_transform',
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'averageAffineTranform')
    TemplateBuildSingleIterationWF.connect(
        UpdateTemplateShape, 'output_image',
        ApplyInvAverageAndFourTimesGradientStepWarpImage, 'gradientStepWarp')

    # Reshape the averaged image with the combined transform list; the result
    # is the new template exposed on outputspec.template.
    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=WarpImageMultiTransform(),
        name='ReshapeAverageImageWithShapeUpdate')
    ReshapeAverageImageWithShapeUpdate.inputs.invert_affine = [1]
    ReshapeAverageImageWithShapeUpdate.inputs.out_postfix = '_Reshaped'
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps', ReshapeAverageImageWithShapeUpdate,
        'transformation_series')
    TemplateBuildSingleIterationWF.connect(ReshapeAverageImageWithShapeUpdate,
                                           'output_image', outputSpec,
                                           'template')

    ######
    ######
    ######  Process all the passive deformed images in a way similar to the main image used for registration
    ######
    ######
    ######
    ##############################################
    ## Now warp all the ListOfPassiveImagesDictionaries images
    # Flatten the nested passive-image structure into parallel lists (image,
    # transform list, name/type) so a MapNode can iterate over them.
    FlattenTransformAndImagesListNode = pe.Node(
        Function(function=FlattenTransformAndImagesList,
                 input_names=[
                     'ListOfPassiveImagesDictionaries', 'transformation_series'
                 ],
                 output_names=[
                     'flattened_images', 'flattened_transforms',
                     'flattened_image_nametypes'
                 ]),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")
    TemplateBuildSingleIterationWF.connect(inputSpec,
                                           'ListOfPassiveImagesDictionaries',
                                           FlattenTransformAndImagesListNode,
                                           'ListOfPassiveImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(MakeTransformsLists, 'out',
                                           FlattenTransformAndImagesListNode,
                                           'transformation_series')
    # Warp each flattened passive image with its transform list onto the
    # averaged deformed image grid.
    wimtPassivedeformed = pe.MapNode(
        interface=WarpImageMultiTransform(),
        iterfield=['transformation_series', 'input_image'],
        name='wimtPassivedeformed')
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           wimtPassivedeformed,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_images',
                                           wimtPassivedeformed, 'input_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_transforms',
                                           wimtPassivedeformed,
                                           'transformation_series')

    # Regroup the warped passive images (using the flattened name/type list)
    # into nested lists plus matching output file names and type suffixes.
    RenestDeformedPassiveImagesNode = pe.Node(
        Function(
            function=RenestDeformedPassiveImages,
            input_names=['deformedPassiveImages', 'flattened_image_nametypes'],
            output_names=[
                'nested_imagetype_list', 'outputAverageImageName_list',
                'image_type_list'
            ]),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    TemplateBuildSingleIterationWF.connect(wimtPassivedeformed, 'output_image',
                                           RenestDeformedPassiveImagesNode,
                                           'deformedPassiveImages')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_image_nametypes',
                                           RenestDeformedPassiveImagesNode,
                                           'flattened_image_nametypes')
    ## Now  Average All passive input_images deformed images together to create an updated template average
    # One AverageImages run per group; intensities are not normalized.
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "nested_imagetype_list",
                                           AvgDeformedPassiveImages, 'images')
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "outputAverageImageName_list",
                                           AvgDeformedPassiveImages,
                                           'output_average_image')

    ## -- Now reshape all the passive averages with the same shape update as the template
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=WarpImageMultiTransform(),
        iterfield=['input_image', 'reference_image', 'out_postfix'],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_affine = [1]
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, "image_type_list",
        ReshapeAveragePassiveImageWithShapeUpdate, 'out_postfix')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    # The transform list is not in iterfield, so every group gets the same one.
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps',
        ReshapeAveragePassiveImageWithShapeUpdate, 'transformation_series')
    TemplateBuildSingleIterationWF.connect(
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image', outputSpec,
        'passive_deformed_templates')

    return TemplateBuildSingleIterationWF
def baw_ants_registration_template_build_single_iteration_wf(
    iterationPhasePrefix, CLUSTER_QUEUE, CLUSTER_QUEUE_LONG
):
    """Build one iteration of the ``antsRegistration``-based template-building workflow.

    For each subject, the registration image(s) selected by ``get_moving_images``
    are registered to ``inputspec.fixed_image``.  The warped images are averaged.
    Each composite transform is split into affine and warp components, and each
    set of components is averaged.  The average warp is multiplied by
    ``-GradientStep`` (hard-coded 0.25) and resampled through the inverse of the
    average affine to form a shape update.  The averaged image is then reshaped
    with a transform list built from the average affine and that shape update,
    giving the new template.

    Passive (non-registration) images are warped with the same composite
    transforms, regrouped by ``renest_deformed_passive_images``, averaged per
    group and reshaped with the same transform list.

    Parameters
    ----------
    iterationPhasePrefix : object
        Iteration label; converted with ``str()`` and used in the workflow name
        and in several intermediate output file names.
    CLUSTER_QUEUE, CLUSTER_QUEUE_LONG : object
        Not used by this function.  They are kept for interface compatibility;
        per the note below, ``BeginANTS`` cluster options are set by the caller
        (template.py).

    Inputs::

           inputspec.ListOfImagesDictionaries : per-subject image dictionaries
               (presumably keyed by image type -- confirm with callers)
           inputspec.registrationImageTypes : image type(s) used for registration
           inputspec.interpolationMapping : passed to the helper functions that
               choose an interpolation mode per image
           inputspec.fixed_image : registration target (current template)

    Outputs::

           outputspec.template : reshaped average image (updated template)
           outputspec.transforms_list : declared but never connected here
           outputspec.passive_deformed_templates : reshaped averages of the
               passive images

    Returns
    -------
    pe.Workflow
        The assembled (not yet run) nipype workflow.
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name="antsRegistrationTemplateBuildSingleIterationWF_"
        + str(iterationPhasePrefix)
    )

    inputSpec = pe.Node(
        interface=util.IdentityInterface(
            fields=[
                "ListOfImagesDictionaries",
                "registrationImageTypes",
                # 'maskRegistrationImageType',
                "interpolationMapping",
                "fixed_image",
            ]
        ),
        run_without_submitting=True,
        name="inputspec",
    )
    ## HACK: INFO: We need to have the AVG_AIR.nii.gz be warped with a default voxel value of 1.0
    ##             (not implemented: every ApplyTransforms node below uses default_value 0).
    ## HACK: INFO: Need to move all local functions to a common utility file, or at the top of the file so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: INFO: REMOVE 'transforms_list' it is not used.  That will change all the hashes
    ## HACK: INFO: Need to run all python files through the code beautifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(
        interface=util.IdentityInterface(
            fields=["template", "transforms_list", "passive_deformed_templates"]
        ),
        run_without_submitting=True,
        name="outputspec",
    )

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    BeginANTS = pe.MapNode(
        interface=Registration(), name="BeginANTS", iterfield=["moving_image"]
    )
    # SEE template.py many_cpu_BeginANTS_options_dictionary = {'qsub_args': modify_qsub_args(CLUSTER_QUEUE,4,2,8), 'overwrite': True}
    ## This is set in the template.py file BeginANTS.plugin_args = BeginANTS_cpu_sge_options_dictionary
    common_ants_registration_settings(
        antsRegistrationNode=BeginANTS,
        registrationTypeDescription="SixStageAntsRegistrationT1Only",
        output_transform_prefix=str(iterationPhasePrefix) + "_tfm",
        output_warped_image="atlas2subject.nii.gz",
        output_inverse_warped_image="subject2atlas.nii.gz",
        save_state="SavedantsRegistrationNodeSyNState.h5",
        invert_initial_moving_transform=False,
        initial_moving_transform=None,
    )

    # Select the registration image (and its interpolation mode) for each subject.
    GetMovingImagesNode = pe.Node(
        interface=util.Function(
            function=get_moving_images,
            input_names=[
                "ListOfImagesDictionaries",
                "registrationImageTypes",
                "interpolationMapping",
            ],
            output_names=["moving_images", "moving_interpolation_type"],
        ),
        run_without_submitting=True,
        name="99_GetMovingImagesNode",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec,
        "ListOfImagesDictionaries",
        GetMovingImagesNode,
        "ListOfImagesDictionaries",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec,
        "registrationImageTypes",
        GetMovingImagesNode,
        "registrationImageTypes",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "interpolationMapping", GetMovingImagesNode, "interpolationMapping"
    )

    TemplateBuildSingleIterationWF.connect(
        GetMovingImagesNode, "moving_images", BeginANTS, "moving_image"
    )
    TemplateBuildSingleIterationWF.connect(
        GetMovingImagesNode, "moving_interpolation_type", BeginANTS, "interpolation"
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "fixed_image", BeginANTS, "fixed_image"
    )

    ## Now warp all the input_images images
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=["transforms", "input_image"],
        # iterfield=['transforms', 'invert_transform_flags', 'input_image'],
        name="wimtdeformed",
    )
    wimtdeformed.inputs.interpolation = "Linear"
    # Must go through ``.inputs``: assigning ``wimtdeformed.default_value`` only
    # creates an attribute on the Node object and never reaches the interface.
    wimtdeformed.inputs.default_value = 0
    # HACK: Should try using forward_composite_transform
    ## PREVIOUS: connect(BeginANTS, 'forward_transform', wimtdeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(
        BeginANTS, "composite_transform", wimtdeformed, "transforms"
    )
    ## PREVIOUS: connect(BeginANTS, 'forward_invert_flags', wimtdeformed, 'invert_transform_flags')
    ## NOTE: forward_invert_flags:: List of flags corresponding to the forward transforms
    # wimtdeformed.inputs.invert_transform_flags = [False,False,False,False,False]
    TemplateBuildSingleIterationWF.connect(
        GetMovingImagesNode, "moving_images", wimtdeformed, "input_image"
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "fixed_image", wimtdeformed, "reference_image"
    )

    ##  Shape Update Next =====
    ## Now  Average All input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(), name="AvgDeformedImages")
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = (
        str(iterationPhasePrefix) + ".nii.gz"
    )
    AvgDeformedImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(
        wimtdeformed, "output_image", AvgDeformedImages, "images"
    )

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(
        interface=AverageAffineTransform(), name="AvgAffineTransform"
    )
    AvgAffineTransform.inputs.dimension = 3
    # The misspelled 'Avererage_' prefix is part of the output file name; it is
    # kept so existing file names do not change.
    AvgAffineTransform.inputs.output_affine_transform = (
        "Avererage_" + str(iterationPhasePrefix) + "_Affine.h5"
    )

    # Split each subject's composite transform into affine and warp components
    # so they can be averaged separately.
    SplitCompositeTransform = pe.MapNode(
        interface=util.Function(
            function=split_composite_to_component_transform,
            input_names=["transformFilename"],
            output_names=["affine_component_list", "warp_component_list"],
        ),
        iterfield=["transformFilename"],
        run_without_submitting=True,
        name="99_SplitCompositeTransform",
    )
    TemplateBuildSingleIterationWF.connect(
        BeginANTS, "composite_transform", SplitCompositeTransform, "transformFilename"
    )
    ## PREVIOUS: connect(BeginANTS, 'forward_transforms', SplitCompositeTransform, 'transformFilename')
    TemplateBuildSingleIterationWF.connect(
        SplitCompositeTransform,
        "affine_component_list",
        AvgAffineTransform,
        "transforms",
    )

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name="AvgWarpImages")
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = (
        str(iterationPhasePrefix) + "warp.nii.gz"
    )
    AvgWarpImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(
        SplitCompositeTransform, "warp_component_list", AvgWarpImages, "images"
    )

    ## Now scale the average warp field by the (negative) gradient step
    ## INFO:  For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(
        interface=MultiplyImages(), name="GradientStepWarpImage"
    )
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # NOTE: the '0.25' in this file name is hard-coded and does not follow GradientStep.
    GradientStepWarpImage.inputs.output_product_image = (
        "GradientStep0.25_" + str(iterationPhasePrefix) + "_warp.nii.gz"
    )
    TemplateBuildSingleIterationWF.connect(
        AvgWarpImages, "output_average_image", GradientStepWarpImage, "first_input"
    )

    ## Now create the new template shape based on the average of all deformed images
    # Resample the scaled warp through the inverted average affine on the grid
    # of the averaged deformed image.
    UpdateTemplateShape = pe.Node(
        interface=ApplyTransforms(), name="UpdateTemplateShape"
    )
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = "Linear"
    UpdateTemplateShape.inputs.default_value = 0

    TemplateBuildSingleIterationWF.connect(
        AvgDeformedImages,
        "output_average_image",
        UpdateTemplateShape,
        "reference_image",
    )
    TemplateBuildSingleIterationWF.connect(
        [
            (
                AvgAffineTransform,
                UpdateTemplateShape,
                [(("affine_transform", make_list_of_one_element), "transforms")],
            )
        ]
    )
    TemplateBuildSingleIterationWF.connect(
        GradientStepWarpImage,
        "output_product_image",
        UpdateTemplateShape,
        "input_image",
    )

    # Combine the average affine and the shape-update warp into one transform
    # list.  Its layout is defined by make_transform_list_with_gradient_warps and
    # must match the 5-element invert_transform_flags used by the reshape nodes.
    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=make_transform_list_with_gradient_warps,
            input_names=["averageAffineTranform", "gradientStepWarp"],
            output_names=["TransformListWithGradientWarps"],
        ),
        run_without_submitting=True,
        name="99_MakeTransformListWithGradientWarps",
    )
    # ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    TemplateBuildSingleIterationWF.connect(
        AvgAffineTransform,
        "affine_transform",
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        "averageAffineTranform",
    )
    TemplateBuildSingleIterationWF.connect(
        UpdateTemplateShape,
        "output_image",
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        "gradientStepWarp",
    )

    # Reshape the averaged image with the combined transform list; the first
    # transform (the average affine) is inverted.
    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name="ReshapeAverageImageWithShapeUpdate"
    )
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True,
        False,
        False,
        False,
        False,
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = "Linear"
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = (
        "ReshapeAverageImageWithShapeUpdate.nii.gz"
    )
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedImages,
        "output_average_image",
        ReshapeAverageImageWithShapeUpdate,
        "input_image",
    )
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedImages,
        "output_average_image",
        ReshapeAverageImageWithShapeUpdate,
        "reference_image",
    )
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        "TransformListWithGradientWarps",
        ReshapeAverageImageWithShapeUpdate,
        "transforms",
    )
    TemplateBuildSingleIterationWF.connect(
        ReshapeAverageImageWithShapeUpdate, "output_image", outputSpec, "template"
    )

    ######
    ######
    ######  Process all the passive deformed images in a way similar to the main image used for registration
    ######
    ######
    ######
    ##############################################
    ## Now warp all the ListOfPassiveImagesDictionaries images
    # Flatten passive images into parallel lists (image, transforms, invert
    # flags, name/type, interpolation) so a MapNode can iterate over them.
    FlattenTransformAndImagesListNode = pe.Node(
        Function(
            function=flatten_transform_and_images_list,
            input_names=[
                "ListOfPassiveImagesDictionaries",
                "transforms",
                "interpolationMapping",
                "invert_transform_flags",
            ],
            output_names=[
                "flattened_images",
                "flattened_transforms",
                "flattened_invert_transform_flags",
                "flattened_image_nametypes",
                "flattened_interpolation_type",
            ],
        ),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList",
    )

    # Extract, per subject, the images that are not registration images.
    GetPassiveImagesNode = pe.Node(
        interface=util.Function(
            function=get_passive_images,
            input_names=["ListOfImagesDictionaries", "registrationImageTypes"],
            output_names=["ListOfPassiveImagesDictionaries"],
        ),
        run_without_submitting=True,
        name="99_GetPassiveImagesNode",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec,
        "ListOfImagesDictionaries",
        GetPassiveImagesNode,
        "ListOfImagesDictionaries",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec,
        "registrationImageTypes",
        GetPassiveImagesNode,
        "registrationImageTypes",
    )

    TemplateBuildSingleIterationWF.connect(
        GetPassiveImagesNode,
        "ListOfPassiveImagesDictionaries",
        FlattenTransformAndImagesListNode,
        "ListOfPassiveImagesDictionaries",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec,
        "interpolationMapping",
        FlattenTransformAndImagesListNode,
        "interpolationMapping",
    )
    TemplateBuildSingleIterationWF.connect(
        BeginANTS,
        "composite_transform",
        FlattenTransformAndImagesListNode,
        "transforms",
    )
    ## FlattenTransformAndImagesListNode.inputs.invert_transform_flags = [False,False,False,False,False,False]
    ## INFO: Please check whether invert_transform_flags has a fixed number of entries.
    ## PREVIOUS: connect(BeginANTS, 'forward_invert_flags', FlattenTransformAndImagesListNode, 'invert_transform_flags')
    wimtPassivedeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            "transforms",
            "invert_transform_flags",
            "input_image",
            "interpolation",
        ],
        name="wimtPassivedeformed",
    )
    wimtPassivedeformed.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedImages,
        "output_average_image",
        wimtPassivedeformed,
        "reference_image",
    )
    TemplateBuildSingleIterationWF.connect(
        FlattenTransformAndImagesListNode,
        "flattened_interpolation_type",
        wimtPassivedeformed,
        "interpolation",
    )
    TemplateBuildSingleIterationWF.connect(
        FlattenTransformAndImagesListNode,
        "flattened_images",
        wimtPassivedeformed,
        "input_image",
    )
    TemplateBuildSingleIterationWF.connect(
        FlattenTransformAndImagesListNode,
        "flattened_transforms",
        wimtPassivedeformed,
        "transforms",
    )
    TemplateBuildSingleIterationWF.connect(
        FlattenTransformAndImagesListNode,
        "flattened_invert_transform_flags",
        wimtPassivedeformed,
        "invert_transform_flags",
    )

    # Regroup the warped passive images into nested lists with matching output
    # file names, type suffixes and interpolation modes.
    RenestDeformedPassiveImagesNode = pe.Node(
        Function(
            function=renest_deformed_passive_images,
            input_names=[
                "deformedPassiveImages",
                "flattened_image_nametypes",
                "interpolationMapping",
            ],
            output_names=[
                "nested_imagetype_list",
                "outputAverageImageName_list",
                "image_type_list",
                "nested_interpolation_type",
            ],
        ),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec,
        "interpolationMapping",
        RenestDeformedPassiveImagesNode,
        "interpolationMapping",
    )
    TemplateBuildSingleIterationWF.connect(
        wimtPassivedeformed,
        "output_image",
        RenestDeformedPassiveImagesNode,
        "deformedPassiveImages",
    )
    TemplateBuildSingleIterationWF.connect(
        FlattenTransformAndImagesListNode,
        "flattened_image_nametypes",
        RenestDeformedPassiveImagesNode,
        "flattened_image_nametypes",
    )
    ## Now  Average All passive input_images deformed images together to create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=["images", "output_average_image"],
        name="AvgDeformedPassiveImages",
    )
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode,
        "nested_imagetype_list",
        AvgDeformedPassiveImages,
        "images",
    )
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode,
        "outputAverageImageName_list",
        AvgDeformedPassiveImages,
        "output_average_image",
    )

    ## -- INFO:  Now reshape all the passive averages with the same shape update as the template
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=["input_image", "reference_image", "output_image", "interpolation"],
        name="ReshapeAveragePassiveImageWithShapeUpdate",
    )
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True,
        False,
        False,
        False,
        False,
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode,
        "nested_interpolation_type",
        ReshapeAveragePassiveImageWithShapeUpdate,
        "interpolation",
    )
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode,
        "outputAverageImageName_list",
        ReshapeAveragePassiveImageWithShapeUpdate,
        "output_image",
    )
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages,
        "output_average_image",
        ReshapeAveragePassiveImageWithShapeUpdate,
        "input_image",
    )
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages,
        "output_average_image",
        ReshapeAveragePassiveImageWithShapeUpdate,
        "reference_image",
    )
    # Not in iterfield: every passive group uses the same transform list.
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        "TransformListWithGradientWarps",
        ReshapeAveragePassiveImageWithShapeUpdate,
        "transforms",
    )
    TemplateBuildSingleIterationWF.connect(
        ReshapeAveragePassiveImageWithShapeUpdate,
        "output_image",
        outputSpec,
        "passive_deformed_templates",
    )

    return TemplateBuildSingleIterationWF