def antsRegistrationTemplateBuildSingleIterationWF(iterationPhasePrefix=''):
    """Build one iteration of the ANTs template-construction workflow.

    Each moving image is registered to ``fixed_image`` (an Affine stage
    followed by a SyN stage), warped into that space, and the warped images
    are averaged into a new template.

    The affine components of all registrations are averaged, and so are the
    warp components. The averaged warp is scaled by ``-GradientStep``
    (-0.25), resampled through the inverse of the average affine, and
    combined with that affine by ``MakeTransformListWithGradientWarps``. The
    resulting transform list (first entry inverted) reshapes the averaged
    image into ``outputspec.template``.

    Passive (non-registration) images are pushed through the same
    per-subject transforms. They are re-grouped by
    ``RenestDeformedPassiveImages`` (presumably per image type), averaged,
    and reshaped with the same transform list.

    Args:
        iterationPhasePrefix: Suffix of the workflow name and prefix of the
            files written by this iteration.

    Inputs::

        inputspec.ListOfImagesDictionaries :
        inputspec.registrationImageTypes :
        inputspec.interpolationMapping :
        inputspec.fixed_image :

    Outputs::

        outputspec.template :
        outputspec.transforms_list : declared but never connected here
        outputspec.passive_deformed_templates :

    Returns:
        The assembled ``pe.Workflow``.
    """
    wf = pe.Workflow(name='antsRegistrationTemplateBuildSingleIterationWF_' +
                     str(iterationPhasePrefix))

    inputSpec = pe.Node(
        interface=util.IdentityInterface(fields=[
            'ListOfImagesDictionaries', 'registrationImageTypes',
            'interpolationMapping', 'fixed_image'
        ]),
        run_without_submitting=True,
        name='inputspec')
    ## HACK: TODO: Need to move all local functions to a common utility file,
    ## or at the top of the file so that they do not change due to
    ## re-indenting. Otherwise re-indenting for flow control will trigger
    ## their hash to change.
    ## HACK: TODO: REMOVE 'transforms_list' it is not used. That will change
    ## all the hashes.
    outputSpec = pe.Node(
        interface=util.IdentityInterface(
            fields=['template', 'transforms_list',
                    'passive_deformed_templates']),
        run_without_submitting=True,
        name='outputspec')

    ### NOTE MAP NODE! Warp each of the original images to the provided
    ### fixed_image as the template.
    BeginANTS = pe.MapNode(
        interface=Registration(), name='BeginANTS', iterfield=['moving_image'])
    BeginANTS.inputs.dimension = 3
    BeginANTS.inputs.output_transform_prefix = str(
        iterationPhasePrefix) + '_tfm'
    BeginANTS.inputs.transforms = ["Affine", "SyN"]
    BeginANTS.inputs.transform_parameters = [[0.9], [0.25, 3.0, 0.0]]
    BeginANTS.inputs.metric = ['Mattes', 'CC']
    BeginANTS.inputs.metric_weight = [1.0, 1.0]
    BeginANTS.inputs.radius_or_number_of_bins = [32, 5]
    BeginANTS.inputs.number_of_iterations = [[1000, 1000, 1000], [50, 35, 15]]
    BeginANTS.inputs.use_histogram_matching = [True, True]
    BeginANTS.inputs.use_estimate_learning_rate_once = [False, False]
    BeginANTS.inputs.shrink_factors = [[3, 2, 1], [3, 2, 1]]
    BeginANTS.inputs.smoothing_sigmas = [[3, 2, 0], [3, 2, 0]]

    GetMovingImagesNode = pe.Node(
        interface=util.Function(
            function=GetMovingImages,
            input_names=[
                'ListOfImagesDictionaries', 'registrationImageTypes',
                'interpolationMapping'
            ],
            output_names=['moving_images', 'moving_interpolation_type']),
        run_without_submitting=True,
        name='99_GetMovingImagesNode')
    wf.connect(inputSpec, 'ListOfImagesDictionaries', GetMovingImagesNode,
               'ListOfImagesDictionaries')
    wf.connect(inputSpec, 'registrationImageTypes', GetMovingImagesNode,
               'registrationImageTypes')
    wf.connect(inputSpec, 'interpolationMapping', GetMovingImagesNode,
               'interpolationMapping')

    wf.connect(GetMovingImagesNode, 'moving_images', BeginANTS, 'moving_image')
    wf.connect(GetMovingImagesNode, 'moving_interpolation_type', BeginANTS,
               'interpolation')
    wf.connect(inputSpec, 'fixed_image', BeginANTS, 'fixed_image')

    ## Now warp all the input_images images
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=['transforms', 'invert_transform_flags', 'input_image'],
        name='wimtdeformed')
    wimtdeformed.inputs.interpolation = 'Linear'
    # Must go through ``.inputs``: assigning ``node.default_value`` only adds
    # a plain Python attribute to the Node and never reaches ApplyTransforms.
    wimtdeformed.inputs.default_value = 0
    wf.connect(BeginANTS, 'forward_transforms', wimtdeformed, 'transforms')
    wf.connect(BeginANTS, 'forward_invert_flags', wimtdeformed,
               'invert_transform_flags')
    wf.connect(GetMovingImagesNode, 'moving_images', wimtdeformed,
               'input_image')
    wf.connect(inputSpec, 'fixed_image', wimtdeformed, 'reference_image')

    ## Shape Update Next =====
    ## Now Average All input_images deformed images together to create an
    ## updated template average
    AvgDeformedImages = pe.Node(
        interface=AverageImages(), name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(
        iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    wf.connect(wimtdeformed, "output_image", AvgDeformedImages, 'images')

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(
        interface=AverageAffineTransform(), name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    # NOTE: the 'Avererage_' spelling is kept on purpose; changing it renames
    # the produced file.
    AvgAffineTransform.inputs.output_affine_transform = 'Avererage_' + str(
        iterationPhasePrefix) + '_Affine.mat'

    SplitAffineAndWarpsNode = pe.Node(
        interface=util.Function(
            function=SplitAffineAndWarpComponents,
            input_names=['list_of_transforms_lists'],
            output_names=['affine_component_list', 'warp_component_list']),
        run_without_submitting=True,
        name='99_SplitAffineAndWarpsNode')
    wf.connect(BeginANTS, 'forward_transforms', SplitAffineAndWarpsNode,
               'list_of_transforms_lists')
    wf.connect(SplitAffineAndWarpsNode, 'affine_component_list',
               AvgAffineTransform, 'transforms')

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(
        iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    wf.connect(SplitAffineAndWarpsNode, 'warp_component_list', AvgWarpImages,
               'images')

    ## TODO: For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(
        interface=MultiplyImages(), name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # Derived from GradientStep so the file name cannot drift from the value
    # actually applied ('GradientStep0.25_<prefix>_warp.nii.gz' for 0.25).
    GradientStepWarpImage.inputs.output_product_image = (
        'GradientStep{0}_{1}_warp.nii.gz'.format(GradientStep,
                                                 iterationPhasePrefix))
    wf.connect(AvgWarpImages, 'output_average_image', GradientStepWarpImage,
               'first_input')

    ## Now create the new template shape based on the average of all deformed
    ## images
    UpdateTemplateShape = pe.Node(
        interface=ApplyTransforms(), name='UpdateTemplateShape')
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = 'Linear'
    UpdateTemplateShape.inputs.default_value = 0

    wf.connect(AvgDeformedImages, 'output_average_image', UpdateTemplateShape,
               'reference_image')
    wf.connect([
        (AvgAffineTransform, UpdateTemplateShape,
         [(('affine_transform', makeListOfOneElement), 'transforms')]),
    ])
    wf.connect(GradientStepWarpImage, 'output_product_image',
               UpdateTemplateShape, 'input_image')

    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='99_MakeTransformListWithGradientWarps')
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    wf.connect(AvgAffineTransform, 'affine_transform',
               ApplyInvAverageAndFourTimesGradientStepWarpImage,
               'averageAffineTranform')
    wf.connect(UpdateTemplateShape, 'output_image',
               ApplyInvAverageAndFourTimesGradientStepWarpImage,
               'gradientStepWarp')

    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name='ReshapeAverageImageWithShapeUpdate')
    # Five flags: only the first transform of the list is inverted.
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = 'Linear'
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = 'ReshapeAverageImageWithShapeUpdate.nii.gz'
    wf.connect(AvgDeformedImages, 'output_average_image',
               ReshapeAverageImageWithShapeUpdate, 'input_image')
    wf.connect(AvgDeformedImages, 'output_average_image',
               ReshapeAverageImageWithShapeUpdate, 'reference_image')
    wf.connect(ApplyInvAverageAndFourTimesGradientStepWarpImage,
               'TransformListWithGradientWarps',
               ReshapeAverageImageWithShapeUpdate, 'transforms')
    wf.connect(ReshapeAverageImageWithShapeUpdate, 'output_image', outputSpec,
               'template')

    ######
    ###### Process all the passive deformed images in a way similar to the
    ###### main image used for registration
    ######
    ## Now warp all the ListOfPassiveImagesDictionaries images
    FlattenTransformAndImagesListNode = pe.Node(
        Function(
            function=FlattenTransformAndImagesList,
            input_names=[
                'ListOfPassiveImagesDictionaries', 'transforms',
                'invert_transform_flags', 'interpolationMapping'
            ],
            output_names=[
                'flattened_images', 'flattened_transforms',
                'flattened_invert_transform_flags',
                'flattened_image_nametypes', 'flattened_interpolation_type'
            ]),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")

    GetPassiveImagesNode = pe.Node(
        interface=util.Function(
            function=GetPassiveImages,
            input_names=['ListOfImagesDictionaries', 'registrationImageTypes'],
            output_names=['ListOfPassiveImagesDictionaries']),
        run_without_submitting=True,
        name='99_GetPassiveImagesNode')
    wf.connect(inputSpec, 'ListOfImagesDictionaries', GetPassiveImagesNode,
               'ListOfImagesDictionaries')
    wf.connect(inputSpec, 'registrationImageTypes', GetPassiveImagesNode,
               'registrationImageTypes')

    wf.connect(GetPassiveImagesNode, 'ListOfPassiveImagesDictionaries',
               FlattenTransformAndImagesListNode,
               'ListOfPassiveImagesDictionaries')
    wf.connect(inputSpec, 'interpolationMapping',
               FlattenTransformAndImagesListNode, 'interpolationMapping')
    wf.connect(BeginANTS, 'forward_transforms',
               FlattenTransformAndImagesListNode, 'transforms')
    wf.connect(BeginANTS, 'forward_invert_flags',
               FlattenTransformAndImagesListNode, 'invert_transform_flags')

    wimtPassivedeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'transforms', 'invert_transform_flags', 'input_image',
            'interpolation'
        ],
        name='wimtPassivedeformed')
    wimtPassivedeformed.inputs.default_value = 0
    wf.connect(AvgDeformedImages, 'output_average_image', wimtPassivedeformed,
               'reference_image')
    wf.connect(FlattenTransformAndImagesListNode,
               'flattened_interpolation_type', wimtPassivedeformed,
               'interpolation')
    wf.connect(FlattenTransformAndImagesListNode, 'flattened_images',
               wimtPassivedeformed, 'input_image')
    wf.connect(FlattenTransformAndImagesListNode, 'flattened_transforms',
               wimtPassivedeformed, 'transforms')
    wf.connect(FlattenTransformAndImagesListNode,
               'flattened_invert_transform_flags', wimtPassivedeformed,
               'invert_transform_flags')

    RenestDeformedPassiveImagesNode = pe.Node(
        Function(
            function=RenestDeformedPassiveImages,
            input_names=[
                'deformedPassiveImages', 'flattened_image_nametypes',
                'interpolationMapping'
            ],
            output_names=[
                'nested_imagetype_list', 'outputAverageImageName_list',
                'image_type_list', 'nested_interpolation_type'
            ]),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    wf.connect(inputSpec, 'interpolationMapping',
               RenestDeformedPassiveImagesNode, 'interpolationMapping')
    wf.connect(wimtPassivedeformed, 'output_image',
               RenestDeformedPassiveImagesNode, 'deformedPassiveImages')
    wf.connect(FlattenTransformAndImagesListNode, 'flattened_image_nametypes',
               RenestDeformedPassiveImagesNode, 'flattened_image_nametypes')

    ## Now Average All passive input_images deformed images together to
    ## create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    wf.connect(RenestDeformedPassiveImagesNode, "nested_imagetype_list",
               AvgDeformedPassiveImages, 'images')
    wf.connect(RenestDeformedPassiveImagesNode, "outputAverageImageName_list",
               AvgDeformedPassiveImages, 'output_average_image')

    ## Reshape all the passive average images with the same shape update
    ## that was applied to the main template.
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'input_image', 'reference_image', 'output_image', 'interpolation'
        ],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0
    wf.connect(RenestDeformedPassiveImagesNode, 'nested_interpolation_type',
               ReshapeAveragePassiveImageWithShapeUpdate, 'interpolation')
    wf.connect(RenestDeformedPassiveImagesNode, 'outputAverageImageName_list',
               ReshapeAveragePassiveImageWithShapeUpdate, 'output_image')
    wf.connect(AvgDeformedPassiveImages, 'output_average_image',
               ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    wf.connect(AvgDeformedPassiveImages, 'output_average_image',
               ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    wf.connect(ApplyInvAverageAndFourTimesGradientStepWarpImage,
               'TransformListWithGradientWarps',
               ReshapeAveragePassiveImageWithShapeUpdate, 'transforms')
    wf.connect(ReshapeAveragePassiveImageWithShapeUpdate, 'output_image',
               outputSpec, 'passive_deformed_templates')

    return wf
def init_atropos_wf(name='atropos_wf',
                    use_random_seed=True,
                    omp_nthreads=None,
                    mem_gb=3.0,
                    padding=10,
                    in_segmentation_model=list(
                        ATROPOS_MODELS['T1w'].values())):
    """
    Refine a brain mask from an ATROPOS k-means tissue segmentation.

    Reproduces supersteps 6 and 7 of ``antsBrainExtraction.sh``. The mask
    previously obtained through spatial normalization to the template is
    refined using a k-means segmentation of the input images.

    **Parameters**

        use_random_seed : bool
            Whether ATROPOS should generate a random seed based on the
            system's clock
        omp_nthreads : int
            Number of threads (``n_procs``) granted to the ATROPOS node
        mem_gb : float
            Estimated peak memory consumption of the ATROPOS node
        padding : int
            Zero-padding applied with ``ImageMath PadImage`` before the
            morphological steps, and removed again before the outputs
        in_segmentation_model : tuple
            A k-means segmentation is run to find gray or white matter around
            the edge of the initial brain mask warped from the template.
            This produces a segmentation image with :math:`$K$` classes,
            ordered by mean intensity in increasing order.
            With this option, you can control :math:`$K$` and tell the
            script which classes represent CSF, gray and white matter.
            Format (K, csfLabel, gmLabel, wmLabel).
            Examples:
              - ``(3,1,2,3)`` for T1 with K=3, CSF=1, GM=2, WM=3 (default)
              - ``(3,3,2,1)`` for T2 with K=3, CSF=3, GM=2, WM=1
              - ``(3,1,3,2)`` for FLAIR with K=3, CSF=1 GM=3, WM=2
              - ``(4,4,2,3)`` uses K=4, CSF=4, GM=2, WM=3
        name : str, optional
            Workflow name (default: atropos_wf)

    **Inputs**

        in_files
            :abbr:`INU (intensity non-uniformity)`-corrected files.
        in_mask
            Brain mask calculated previously
        in_mask_dilated
            Mask handed to ATROPOS as ``mask_image`` (presumably a dilated
            version of ``in_mask`` -- confirm with the caller)

    **Outputs**

        out_mask
            Refined brain mask
        out_segm
            Output segmentation
        out_tpms
            Output :abbr:`TPMs (tissue probability maps)`, merged in
            CSF, GM, WM order

    """
    workflow = pe.Workflow(name)

    inputnode = pe.Node(
        niu.IdentityInterface(
            fields=['in_files', 'in_mask', 'in_mask_dilated']),
        name='inputnode')
    outputnode = pe.Node(
        niu.IdentityInterface(fields=['out_mask', 'out_segm', 'out_tpms']),
        name='outputnode')

    copy_xform = pe.Node(
        CopyXForm(fields=['out_mask', 'out_segm', 'out_tpms']),
        name='copy_xform',
        run_without_submitting=True,
        mem_gb=2.5)

    n_classes = in_segmentation_model[0]

    # Core node: k-means initialized ATROPOS segmentation.
    atropos = pe.Node(
        Atropos(
            dimension=3,
            initialization='KMeans',
            number_of_tissue_classes=n_classes,
            n_iterations=3,
            convergence_threshold=0.0,
            mrf_radius=[1, 1, 1],
            mrf_smoothing_factor=0.1,
            likelihood_model='Gaussian',
            use_random_seed=use_random_seed),
        name='01_atropos',
        n_procs=omp_nthreads,
        mem_gb=mem_gb)

    # Pad segmentation and prior mask so morphology does not hit the border.
    pad_arg = '%d' % padding
    seg_padded = pe.Node(
        ImageMath(operation='PadImage', op2=pad_arg), name='02_pad_segm')
    mask_padded = pe.Node(
        ImageMath(operation='PadImage', op2=pad_arg), name='03_pad_mask')

    # ---- Superstep 6 -------------------------------------------------------
    # Split the segmentation into binary masks. The labels are fed as
    # (wm, gm, csf), i.e. the model labels taken back-to-front.
    split_tissues = pe.Node(
        niu.Function(
            function=_select_labels,
            output_names=['out_wm', 'out_gm', 'out_csf']),
        name='04_sel_labels')
    split_tissues.inputs.labels = list(in_segmentation_model[:0:-1])

    # ImageMath ${DIMENSION} ${EXTRACTION_WM} GetLargestComponent ${EXTRACTION_WM}
    wm_largest = pe.Node(
        ImageMath(operation='GetLargestComponent'), name='05_get_wm')
    gm_largest = pe.Node(
        ImageMath(operation='GetLargestComponent'), name='06_get_gm')

    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} FillHoles ${EXTRACTION_GM} 2
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${EXTRACTION_TMP} ${EXTRACTION_GM}
    gm_filled = pe.Node(
        ImageMath(operation='FillHoles', op2='2'), name='07_fill_gm')
    gm_masked = pe.Node(
        MultiplyImages(
            dimension=3, output_product_image='08_mult_gm.nii.gz'),
        name='08_mult_gm')

    # MultiplyImages ${DIMENSION} ${EXTRACTION_WM} ${ATROPOS_WM_CLASS_LABEL} ${EXTRACTION_WM}
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} ME ${EXTRACTION_CSF} 10
    wm_relabeled = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-1],
            output_product_image='09_relabel_wm.nii.gz'),
        name='09_relabel_wm')
    csf_eroded = pe.Node(
        ImageMath(operation='ME', op2='10'), name='10_me_csf')

    # ImageMath ${DIMENSION} ${EXTRACTION_GM} addtozero ${EXTRACTION_GM} ${EXTRACTION_TMP}
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${ATROPOS_GM_CLASS_LABEL} ${EXTRACTION_GM}
    # ImageMath ${DIMENSION} ${EXTRACTION_SEGMENTATION} addtozero ${EXTRACTION_WM} ${EXTRACTION_GM}
    gm_plus_csf = pe.Node(ImageMath(operation='addtozero'), name='11_add_gm')
    gm_relabeled = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-2],
            output_product_image='12_relabel_gm.nii.gz'),
        name='12_relabel_gm')
    segmentation = pe.Node(
        ImageMath(operation='addtozero'), name='13_add_gm_wm')

    # ---- Superstep 7 -------------------------------------------------------
    # Split the refined segmentation into binary masks (gm, wm).
    split_brain = pe.Node(
        niu.Function(function=_select_labels, output_names=['out_gm', 'out_wm']),
        name='14_sel_labels2')
    split_brain.inputs.labels = in_segmentation_model[2:]

    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} ${EXTRACTION_TMP}
    brain_union = pe.Node(ImageMath(operation='addtozero'), name='15_add_7')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 2
    brain_erode = pe.Node(ImageMath(operation='ME', op2='2'), name='16_me_7')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} GetLargestComponent ${EXTRACTION_MASK}
    brain_largest = pe.Node(
        ImageMath(operation='GetLargestComponent'), name='17_comp_7')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 4
    brain_dilate = pe.Node(ImageMath(operation='MD', op2='4'), name='18_md_7')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} FillHoles ${EXTRACTION_MASK} 2
    brain_filled = pe.Node(
        ImageMath(operation='FillHoles', op2='2'), name='19_fill_7')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} \
    #     ${EXTRACTION_MASK_PRIOR_WARPED}
    brain_plus_prior = pe.Node(
        ImageMath(operation='addtozero'), name='20_add_7_2')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 5
    brain_dilate2 = pe.Node(
        ImageMath(operation='MD', op2='5'), name='21_md_7_2')
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 5
    brain_erode2 = pe.Node(
        ImageMath(operation='ME', op2='5'), name='22_me_7_2')

    # ---- Remove the padding ------------------------------------------------
    depad_arg = '-%d' % padding
    depad = {
        key: pe.Node(
            ImageMath(operation='PadImage', op2=depad_arg), name=node_name)
        for key, node_name in (
            ('mask', '23_depad_mask'),
            ('segm', '24_depad_segm'),
            ('gm', '25_depad_gm'),
            ('wm', '26_depad_wm'),
            ('csf', '27_depad_csf'),
        )
    }

    msk_conform = pe.Node(niu.Function(function=_conform_mask),
                          name='msk_conform')
    tpm_list = pe.Node(niu.Merge(n_classes), name='merge_tpms')

    # Connections are issued in three consecutive batches (superstep 6,
    # superstep 7, outputs) that keep the original ordering.
    workflow.connect([
        (inputnode, copy_xform, [(('in_files', _pop), 'hdr_file')]),
        (inputnode, mask_padded, [('in_mask', 'op1')]),
        (inputnode, atropos, [('in_files', 'intensity_images'),
                              ('in_mask_dilated', 'mask_image')]),
        (inputnode, msk_conform, [(('in_files', _pop), 'in_reference')]),
        (atropos, seg_padded, [('classified_image', 'op1')]),
        (seg_padded, split_tissues, [('output_image', 'in_segm')]),
        (split_tissues, wm_largest, [('out_wm', 'op1')]),
        (split_tissues, gm_largest, [('out_gm', 'op1')]),
        (gm_largest, gm_filled, [('output_image', 'op1')]),
        (gm_largest, gm_masked, [('output_image', 'first_input')]),
        (gm_filled, gm_masked, [('output_image', 'second_input')]),
        (wm_largest, wm_relabeled, [('output_image', 'first_input')]),
        (split_tissues, csf_eroded, [('out_csf', 'op1')]),
        (gm_masked, gm_plus_csf, [('output_product_image', 'op1')]),
        (csf_eroded, gm_plus_csf, [('output_image', 'op2')]),
        (gm_plus_csf, gm_relabeled, [('output_image', 'first_input')]),
        (wm_relabeled, segmentation, [('output_product_image', 'op1')]),
        (gm_relabeled, segmentation, [('output_product_image', 'op2')]),
    ])
    workflow.connect([
        (segmentation, split_brain, [('output_image', 'in_segm')]),
        (split_brain, brain_union, [('out_wm', 'op1'), ('out_gm', 'op2')]),
        (brain_union, brain_erode, [('output_image', 'op1')]),
        (brain_erode, brain_largest, [('output_image', 'op1')]),
        (brain_largest, brain_dilate, [('output_image', 'op1')]),
        (brain_dilate, brain_filled, [('output_image', 'op1')]),
        (brain_filled, brain_plus_prior, [('output_image', 'op1')]),
        (mask_padded, brain_plus_prior, [('output_image', 'op2')]),
        (brain_plus_prior, brain_dilate2, [('output_image', 'op1')]),
        (brain_dilate2, brain_erode2, [('output_image', 'op1')]),
    ])
    # NOTE(review): the GM/WM TPMs come from the relabel nodes, which
    # multiply the masks by their class label, so they are not 0/1 maps --
    # confirm this is what consumers of ``out_tpms`` expect.
    workflow.connect([
        (brain_erode2, depad['mask'], [('output_image', 'op1')]),
        (segmentation, depad['segm'], [('output_image', 'op1')]),
        (wm_relabeled, depad['wm'], [('output_product_image', 'op1')]),
        (gm_relabeled, depad['gm'], [('output_product_image', 'op1')]),
        (split_tissues, depad['csf'], [('out_csf', 'op1')]),
        (depad['csf'], tpm_list, [('output_image', 'in1')]),
        (depad['gm'], tpm_list, [('output_image', 'in2')]),
        (depad['wm'], tpm_list, [('output_image', 'in3')]),
        (depad['mask'], msk_conform, [('output_image', 'in_mask')]),
        (msk_conform, copy_xform, [('out', 'out_mask')]),
        (depad['segm'], copy_xform, [('output_image', 'out_segm')]),
        (tpm_list, copy_xform, [('out', 'out_tpms')]),
        (copy_xform, outputnode, [('out_mask', 'out_mask'),
                                  ('out_segm', 'out_segm'),
                                  ('out_tpms', 'out_tpms')]),
    ])
    return workflow
def BAWantsRegistrationTemplateBuildSingleIterationWF(iterationPhasePrefix,
                                                      CLUSTER_QUEUE,
                                                      CLUSTER_QUEUE_LONG):
    """Build one iteration of the BAW ANTs template-construction workflow.

    Same pipeline shape as the plain ANTs single-iteration workflow, with
    these differences:

    * ``BeginANTS`` is configured by ``CommonANTsRegistrationSettings``
      ("SixStageAntsRegistrationT1Only").
    * Each subject contributes one ``composite_transform`` instead of a list
      of forward transforms.
    * The composite transforms are split into affine and warp components by
      ``SplitCompositeToComponentTransforms`` (one call per subject).
    * The averaged affine is written as ``.h5``.

    Args:
        iterationPhasePrefix: Suffix of the workflow name and prefix of the
            files written by this iteration.
        CLUSTER_QUEUE: Not referenced in this body (only in a commented-out
            plugin-args line); kept for interface compatibility.
        CLUSTER_QUEUE_LONG: Not referenced in this body; kept for interface
            compatibility.

    Inputs::

        inputspec.ListOfImagesDictionaries :
        inputspec.registrationImageTypes :
        inputspec.interpolationMapping :
        inputspec.fixed_image :

    Outputs::

        outputspec.template :
        outputspec.transforms_list : declared but never connected here
        outputspec.passive_deformed_templates :

    Returns:
        The assembled ``pe.Workflow``.
    """
    wf = pe.Workflow(name='antsRegistrationTemplateBuildSingleIterationWF_' +
                     str(iterationPhasePrefix))

    inputSpec = pe.Node(
        interface=util.IdentityInterface(fields=[
            'ListOfImagesDictionaries',
            'registrationImageTypes',
            # 'maskRegistrationImageType',
            'interpolationMapping',
            'fixed_image'
        ]),
        run_without_submitting=True,
        name='inputspec')
    ## HACK: TODO: We need to have the AVG_AIR.nii.gz be warped with a
    ## default voxel value of 1.0
    ## HACK: TODO: Need to move all local functions to a common utility file,
    ## or at the top of the file so that they do not change due to
    ## re-indenting. Otherwise re-indenting for flow control will trigger
    ## their hash to change.
    ## HACK: TODO: REMOVE 'transforms_list' it is not used. That will change
    ## all the hashes.
    outputSpec = pe.Node(
        interface=util.IdentityInterface(
            fields=['template', 'transforms_list',
                    'passive_deformed_templates']),
        run_without_submitting=True,
        name='outputspec')

    ### NOTE MAP NODE! Warp each of the original images to the provided
    ### fixed_image as the template.
    BeginANTS = pe.MapNode(
        interface=Registration(), name='BeginANTS', iterfield=['moving_image'])
    # Plugin args (qsub options) for BeginANTS are set in template.py, e.g.
    #   {'qsub_args': modify_qsub_args(CLUSTER_QUEUE, 4, 2, 8),
    #    'overwrite': True}
    CommonANTsRegistrationSettings(
        antsRegistrationNode=BeginANTS,
        registrationTypeDescription="SixStageAntsRegistrationT1Only",
        output_transform_prefix=str(iterationPhasePrefix) + '_tfm',
        output_warped_image='atlas2subject.nii.gz',
        output_inverse_warped_image='subject2atlas.nii.gz',
        save_state='SavedantsRegistrationNodeSyNState.h5',
        invert_initial_moving_transform=False,
        initial_moving_transform=None)

    GetMovingImagesNode = pe.Node(
        interface=util.Function(
            function=GetMovingImages,
            input_names=[
                'ListOfImagesDictionaries', 'registrationImageTypes',
                'interpolationMapping'
            ],
            output_names=['moving_images', 'moving_interpolation_type']),
        run_without_submitting=True,
        name='99_GetMovingImagesNode')
    wf.connect(inputSpec, 'ListOfImagesDictionaries', GetMovingImagesNode,
               'ListOfImagesDictionaries')
    wf.connect(inputSpec, 'registrationImageTypes', GetMovingImagesNode,
               'registrationImageTypes')
    wf.connect(inputSpec, 'interpolationMapping', GetMovingImagesNode,
               'interpolationMapping')

    wf.connect(GetMovingImagesNode, 'moving_images', BeginANTS, 'moving_image')
    wf.connect(GetMovingImagesNode, 'moving_interpolation_type', BeginANTS,
               'interpolation')
    wf.connect(inputSpec, 'fixed_image', BeginANTS, 'fixed_image')

    ## Now warp all the input_images images.
    ## A single composite transform per subject replaces the earlier
    ## forward_transforms/forward_invert_flags pair, so no invert flags are
    ## mapped here.
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=['transforms', 'input_image'],
        name='wimtdeformed')
    wimtdeformed.inputs.interpolation = 'Linear'
    # Must go through ``.inputs``: assigning ``node.default_value`` only adds
    # a plain Python attribute to the Node and never reaches ApplyTransforms.
    wimtdeformed.inputs.default_value = 0
    wf.connect(BeginANTS, 'composite_transform', wimtdeformed, 'transforms')
    wf.connect(GetMovingImagesNode, 'moving_images', wimtdeformed,
               'input_image')
    wf.connect(inputSpec, 'fixed_image', wimtdeformed, 'reference_image')

    ## Shape Update Next =====
    ## Now Average All input_images deformed images together to create an
    ## updated template average
    AvgDeformedImages = pe.Node(
        interface=AverageImages(), name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(
        iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    wf.connect(wimtdeformed, "output_image", AvgDeformedImages, 'images')

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(
        interface=AverageAffineTransform(), name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    # NOTE: the 'Avererage_' spelling is kept on purpose; changing it renames
    # the produced file.
    AvgAffineTransform.inputs.output_affine_transform = 'Avererage_' + str(
        iterationPhasePrefix) + '_Affine.h5'

    SplitCompositeTransform = pe.MapNode(
        interface=util.Function(
            function=SplitCompositeToComponentTransforms,
            input_names=['transformFilename'],
            output_names=['affine_component_list', 'warp_component_list']),
        iterfield=['transformFilename'],
        run_without_submitting=True,
        name='99_SplitCompositeTransform')
    wf.connect(BeginANTS, 'composite_transform', SplitCompositeTransform,
               'transformFilename')
    wf.connect(SplitCompositeTransform, 'affine_component_list',
               AvgAffineTransform, 'transforms')

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(
        iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    wf.connect(SplitCompositeTransform, 'warp_component_list', AvgWarpImages,
               'images')

    ## TODO: For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(
        interface=MultiplyImages(), name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # Derived from GradientStep so the file name cannot drift from the value
    # actually applied ('GradientStep0.25_<prefix>_warp.nii.gz' for 0.25).
    GradientStepWarpImage.inputs.output_product_image = (
        'GradientStep{0}_{1}_warp.nii.gz'.format(GradientStep,
                                                 iterationPhasePrefix))
    wf.connect(AvgWarpImages, 'output_average_image', GradientStepWarpImage,
               'first_input')

    ## Now create the new template shape based on the average of all deformed
    ## images
    UpdateTemplateShape = pe.Node(
        interface=ApplyTransforms(), name='UpdateTemplateShape')
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = 'Linear'
    UpdateTemplateShape.inputs.default_value = 0

    wf.connect(AvgDeformedImages, 'output_average_image', UpdateTemplateShape,
               'reference_image')
    wf.connect([
        (AvgAffineTransform, UpdateTemplateShape,
         [(('affine_transform', makeListOfOneElement), 'transforms')]),
    ])
    wf.connect(GradientStepWarpImage, 'output_product_image',
               UpdateTemplateShape, 'input_image')

    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='99_MakeTransformListWithGradientWarps')
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    wf.connect(AvgAffineTransform, 'affine_transform',
               ApplyInvAverageAndFourTimesGradientStepWarpImage,
               'averageAffineTranform')
    wf.connect(UpdateTemplateShape, 'output_image',
               ApplyInvAverageAndFourTimesGradientStepWarpImage,
               'gradientStepWarp')

    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name='ReshapeAverageImageWithShapeUpdate')
    # Five flags: only the first transform of the list is inverted.
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = 'Linear'
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = 'ReshapeAverageImageWithShapeUpdate.nii.gz'
    wf.connect(AvgDeformedImages, 'output_average_image',
               ReshapeAverageImageWithShapeUpdate, 'input_image')
    wf.connect(AvgDeformedImages, 'output_average_image',
               ReshapeAverageImageWithShapeUpdate, 'reference_image')
    wf.connect(ApplyInvAverageAndFourTimesGradientStepWarpImage,
               'TransformListWithGradientWarps',
               ReshapeAverageImageWithShapeUpdate, 'transforms')
    wf.connect(ReshapeAverageImageWithShapeUpdate, 'output_image', outputSpec,
               'template')

    ######
    ###### Process all the passive deformed images in a way similar to the
    ###### main image used for registration
    ######
    ## Now warp all the ListOfPassiveImagesDictionaries images
    FlattenTransformAndImagesListNode = pe.Node(
        Function(
            function=FlattenTransformAndImagesList,
            input_names=[
                'ListOfPassiveImagesDictionaries', 'transforms',
                'interpolationMapping', 'invert_transform_flags'
            ],
            output_names=[
                'flattened_images', 'flattened_transforms',
                'flattened_invert_transform_flags',
                'flattened_image_nametypes', 'flattened_interpolation_type'
            ]),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")

    GetPassiveImagesNode = pe.Node(
        interface=util.Function(
            function=GetPassiveImages,
            input_names=['ListOfImagesDictionaries', 'registrationImageTypes'],
            output_names=['ListOfPassiveImagesDictionaries']),
        run_without_submitting=True,
        name='99_GetPassiveImagesNode')
    wf.connect(inputSpec, 'ListOfImagesDictionaries', GetPassiveImagesNode,
               'ListOfImagesDictionaries')
    wf.connect(inputSpec, 'registrationImageTypes', GetPassiveImagesNode,
               'registrationImageTypes')

    wf.connect(GetPassiveImagesNode, 'ListOfPassiveImagesDictionaries',
               FlattenTransformAndImagesListNode,
               'ListOfPassiveImagesDictionaries')
    wf.connect(inputSpec, 'interpolationMapping',
               FlattenTransformAndImagesListNode, 'interpolationMapping')
    wf.connect(BeginANTS, 'composite_transform',
               FlattenTransformAndImagesListNode, 'transforms')
    # NOTE(review): 'invert_transform_flags' is declared as an input of the
    # flatten node but is neither connected nor set here; the previous
    # 'forward_invert_flags' connection was removed when switching to
    # composite transforms. Unless FlattenTransformAndImagesList gives that
    # argument a default, this node will fail -- TODO: check whether a fixed
    # number of invert flags should be supplied.

    wimtPassivedeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'transforms', 'invert_transform_flags', 'input_image',
            'interpolation'
        ],
        name='wimtPassivedeformed')
    wimtPassivedeformed.inputs.default_value = 0
    wf.connect(AvgDeformedImages, 'output_average_image', wimtPassivedeformed,
               'reference_image')
    wf.connect(FlattenTransformAndImagesListNode,
               'flattened_interpolation_type', wimtPassivedeformed,
               'interpolation')
    wf.connect(FlattenTransformAndImagesListNode, 'flattened_images',
               wimtPassivedeformed, 'input_image')
    wf.connect(FlattenTransformAndImagesListNode, 'flattened_transforms',
               wimtPassivedeformed, 'transforms')
    wf.connect(FlattenTransformAndImagesListNode,
               'flattened_invert_transform_flags', wimtPassivedeformed,
               'invert_transform_flags')

    RenestDeformedPassiveImagesNode = pe.Node(
        Function(
            function=RenestDeformedPassiveImages,
            input_names=[
                'deformedPassiveImages', 'flattened_image_nametypes',
                'interpolationMapping'
            ],
            output_names=[
                'nested_imagetype_list', 'outputAverageImageName_list',
                'image_type_list', 'nested_interpolation_type'
            ]),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    wf.connect(inputSpec, 'interpolationMapping',
               RenestDeformedPassiveImagesNode, 'interpolationMapping')
    wf.connect(wimtPassivedeformed, 'output_image',
               RenestDeformedPassiveImagesNode, 'deformedPassiveImages')
    wf.connect(FlattenTransformAndImagesListNode, 'flattened_image_nametypes',
               RenestDeformedPassiveImagesNode, 'flattened_image_nametypes')

    ## Now Average All passive input_images deformed images together to
    ## create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    wf.connect(RenestDeformedPassiveImagesNode, "nested_imagetype_list",
               AvgDeformedPassiveImages, 'images')
    wf.connect(RenestDeformedPassiveImagesNode, "outputAverageImageName_list",
               AvgDeformedPassiveImages, 'output_average_image')

    ## Reshape all the passive average images with the same shape update
    ## that was applied to the main template.
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'input_image', 'reference_image', 'output_image', 'interpolation'
        ],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0
    wf.connect(RenestDeformedPassiveImagesNode, 'nested_interpolation_type',
               ReshapeAveragePassiveImageWithShapeUpdate, 'interpolation')
    wf.connect(RenestDeformedPassiveImagesNode, 'outputAverageImageName_list',
               ReshapeAveragePassiveImageWithShapeUpdate, 'output_image')
    wf.connect(AvgDeformedPassiveImages, 'output_average_image',
               ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    wf.connect(AvgDeformedPassiveImages, 'output_average_image',
               ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    wf.connect(ApplyInvAverageAndFourTimesGradientStepWarpImage,
               'TransformListWithGradientWarps',
               ReshapeAveragePassiveImageWithShapeUpdate, 'transforms')
    wf.connect(ReshapeAveragePassiveImageWithShapeUpdate, 'output_image',
               outputSpec, 'passive_deformed_templates')

    return wf
def init_atropos_wf(
    name="atropos_wf",
    use_random_seed=True,
    omp_nthreads=None,
    mem_gb=3.0,
    padding=10,
    in_segmentation_model=tuple(ATROPOS_MODELS["T1w"].values()),
    bspline_fitting_distance=200,
    wm_prior=False,
):
    """
    Create an ANTs' ATROPOS workflow for brain tissue segmentation.

    Re-interprets supersteps 6 and 7 of ``antsBrainExtraction.sh``,
    which refine the mask previously computed with the spatial
    normalization to the template.
    The workflow also executes steps 8 and 9 of the brain extraction
    workflow.

    Workflow Graph
        .. workflow::
            :graph2use: orig
            :simple_form: yes

            from niworkflows.anat.ants import init_atropos_wf
            wf = init_atropos_wf()

    Parameters
    ----------
    name : str, optional
        Workflow name (default: "atropos_wf").
    use_random_seed : bool
        Whether ATROPOS should generate a random seed based on the
        system's clock
    omp_nthreads : int
        Maximum number of threads an individual process may use
    mem_gb : float
        Estimated peak memory consumption of the most hungry nodes
        in the workflow
    padding : int
        Pad images with zeros before processing (the same amount is
        cropped back off the outputs before they are returned)
    in_segmentation_model : tuple
        A k-means segmentation is run to find gray or white matter
        around the edge of the initial brain mask warped from the
        template.
        This produces a segmentation image with :math:`K` classes,
        ordered by mean intensity in increasing order.
        With this option, you can control :math:`K` and tell the script which
        classes represent CSF, gray and white matter.
        Format (K, csfLabel, gmLabel, wmLabel).
        Examples:
        ``(3,1,2,3)`` for T1 with K=3, CSF=1, GM=2, WM=3 (default),
        ``(3,3,2,1)`` for T2 with K=3, CSF=3, GM=2, WM=1,
        ``(3,1,3,2)`` for FLAIR with K=3, CSF=1, GM=3, WM=2,
        ``(4,4,2,3)`` uses K=4, CSF=4, GM=2, WM=3.
    bspline_fitting_distance : float
        The size of the b-spline mesh grid elements, in mm (default: 200)
    wm_prior : :obj:`bool`
        Whether the WM posterior obtained with ATROPOS should be regularized with a prior
        map (typically, mapped from the template). When ``wm_prior`` is ``True`` the input
        field ``wm_prior`` of the input node must be connected.

    Inputs
    ------
    in_files : list
        The original anatomical images passed in to the brain-extraction workflow.
    in_corrected : list
        :abbr:`INU (intensity non-uniformity)`-corrected files.
    in_mask : str
        Brain mask calculated previously.
    wm_prior : :obj:`str`
        Path to the WM prior probability map, aligned with the individual data.

    Outputs
    -------
    out_file : :obj:`str`
        Path of the corrected and brain-extracted result, using the ATROPOS refinement.
    bias_corrected : :obj:`str`
        Path of the INU-corrected (not brain-extracted) result, using the ATROPOS
        refinement.
    bias_image : :obj:`str`
        Path of the estimated INU bias field, using the ATROPOS refinement.
    out_mask : str
        Refined brain mask
    out_segm : str
        Output segmentation
    out_tpms : str
        Output :abbr:`TPMs (tissue probability maps)`

    Returns
    -------
    wf : ``pe.Workflow``
        The assembled (not yet run) workflow.

    """
    wf = pe.Workflow(name)

    # Fields routed through ``copy_xform`` before reaching the outputnode.
    out_fields = [
        "bias_corrected",
        "bias_image",
        "out_mask",
        "out_segm",
        "out_tpms"
    ]

    inputnode = pe.Node(
        niu.IdentityInterface(
            fields=["in_files", "in_corrected", "in_mask", "wm_prior"]),
        name="inputnode",
    )
    outputnode = pe.Node(niu.IdentityInterface(fields=["out_file"] + out_fields),
                         name="outputnode")

    # Receives hdr_file = first of in_files (via _pop) and the refined
    # outputs; presumably copies the original header onto them.
    copy_xform = pe.Node(CopyXForm(fields=out_fields),
                         name="copy_xform",
                         run_without_submitting=True)

    # Morphological dilation, radius=2
    dil_brainmask = pe.Node(ImageMath(operation="MD", op2="2", copy_header=True),
                            name="dil_brainmask")

    # Get largest connected component
    get_brainmask = pe.Node(
        ImageMath(operation="GetLargestComponent", copy_header=True),
        name="get_brainmask",
    )

    # Run atropos (core node)
    # K-means initialised, K = in_segmentation_model[0] classes, restricted to
    # the dilated/largest-component mask; posteriors are kept for WM selection.
    atropos = pe.Node(
        Atropos(
            convergence_threshold=0.0,
            dimension=3,
            initialization="KMeans",
            likelihood_model="Gaussian",
            mrf_radius=[1, 1, 1],
            mrf_smoothing_factor=0.1,
            n_iterations=3,
            number_of_tissue_classes=in_segmentation_model[0],
            save_posteriors=True,
            use_random_seed=use_random_seed,
        ),
        name="01_atropos",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )

    # massage outputs
    # Pad the segmentation and the input mask by ``padding`` voxels so the
    # morphological operations below do not hit the image border; the
    # 23_..27_depad_* nodes crop the same amount back off.
    pad_segm = pe.Node(
        ImageMath(operation="PadImage", op2=f"{padding}", copy_header=False),
        name="02_pad_segm",
    )
    pad_mask = pe.Node(
        ImageMath(operation="PadImage", op2=f"{padding}", copy_header=False),
        name="03_pad_mask",
    )

    # Split segmentation in binary masks
    sel_labels = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=["out_wm", "out_gm", "out_csf"]),
        name="04_sel_labels",
    )
    # (csf, gm, wm) labels reversed -> (wm, gm, csf), matching output_names.
    sel_labels.inputs.labels = list(reversed(in_segmentation_model[1:]))

    # Select largest components (GM, WM)
    # ImageMath ${DIMENSION} ${EXTRACTION_WM} GetLargestComponent ${EXTRACTION_WM}
    get_wm = pe.Node(ImageMath(operation="GetLargestComponent"), name="05_get_wm")
    get_gm = pe.Node(ImageMath(operation="GetLargestComponent"), name="06_get_gm")

    # Fill holes and calculate intersection
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} FillHoles ${EXTRACTION_GM} 2
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${EXTRACTION_TMP} ${EXTRACTION_GM}
    fill_gm = pe.Node(ImageMath(operation="FillHoles", op2="2"), name="07_fill_gm")
    mult_gm = pe.Node(
        MultiplyImages(dimension=3, output_product_image="08_mult_gm.nii.gz"),
        name="08_mult_gm",
    )

    # MultiplyImages ${DIMENSION} ${EXTRACTION_WM} ${ATROPOS_WM_CLASS_LABEL} ${EXTRACTION_WM}
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} ME ${EXTRACTION_CSF} 10
    # Multiplying the binary WM mask by the WM label (last model entry)
    # turns it back into a labelled image.
    relabel_wm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-1],
            output_product_image="09_relabel_wm.nii.gz",
        ),
        name="09_relabel_wm",
    )
    me_csf = pe.Node(ImageMath(operation="ME", op2="10"), name="10_me_csf")

    # ImageMath ${DIMENSION} ${EXTRACTION_GM} addtozero ${EXTRACTION_GM} ${EXTRACTION_TMP}
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${ATROPOS_GM_CLASS_LABEL} ${EXTRACTION_GM}
    # ImageMath ${DIMENSION} ${EXTRACTION_SEGMENTATION} addtozero ${EXTRACTION_WM} ${EXTRACTION_GM}
    add_gm = pe.Node(ImageMath(operation="addtozero"), name="11_add_gm")
    # GM label is the second-to-last model entry.
    relabel_gm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-2],
            output_product_image="12_relabel_gm.nii.gz",
        ),
        name="12_relabel_gm",
    )
    add_gm_wm = pe.Node(ImageMath(operation="addtozero"), name="13_add_gm_wm")

    # Superstep 7
    # Split segmentation in binary masks
    sel_labels2 = pe.Node(
        niu.Function(function=_select_labels, output_names=["out_gm", "out_wm"]),
        name="14_sel_labels2",
    )
    # (gm, wm) labels, matching output_names order.
    sel_labels2.inputs.labels = in_segmentation_model[2:]

    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} ${EXTRACTION_TMP}
    add_7 = pe.Node(ImageMath(operation="addtozero"), name="15_add_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 2
    me_7 = pe.Node(ImageMath(operation="ME", op2="2"), name="16_me_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} GetLargestComponent ${EXTRACTION_MASK}
    comp_7 = pe.Node(ImageMath(operation="GetLargestComponent"), name="17_comp_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 4
    md_7 = pe.Node(ImageMath(operation="MD", op2="4"), name="18_md_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} FillHoles ${EXTRACTION_MASK} 2
    fill_7 = pe.Node(ImageMath(operation="FillHoles", op2="2"), name="19_fill_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} \
    # ${EXTRACTION_MASK_PRIOR_WARPED}
    add_7_2 = pe.Node(ImageMath(operation="addtozero"), name="20_add_7_2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 5
    md_7_2 = pe.Node(ImageMath(operation="MD", op2="5"), name="21_md_7_2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 5
    me_7_2 = pe.Node(ImageMath(operation="ME", op2="5"), name="22_me_7_2")

    # De-pad
    # A negative PadImage amount crops ``padding`` voxels, undoing 02/03 above.
    depad_mask = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                         name="23_depad_mask")
    depad_segm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                         name="24_depad_segm")
    depad_gm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                       name="25_depad_gm")
    depad_wm = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                       name="26_depad_wm")
    depad_csf = pe.Node(ImageMath(operation="PadImage", op2="-%d" % padding),
                        name="27_depad_csf")

    # Receives the de-padded mask plus in_reference = first of in_files;
    # presumably conforms the mask to the reference grid/header — TODO confirm.
    msk_conform = pe.Node(niu.Function(function=_conform_mask), name="msk_conform")
    # One input slot per class (K); only in1..in3 (CSF, GM, WM) are connected.
    merge_tpms = pe.Node(niu.Merge(in_segmentation_model[0]), name="merge_tpms")

    # Pick the WM posterior out of ATROPOS' posteriors. Without a prior the
    # index is fixed to the WM label (1-based label -> 0-based list index);
    # with ``wm_prior`` it is chosen at runtime by the overlap node below.
    sel_wm = pe.Node(niu.Select(), name="sel_wm", run_without_submitting=True)
    if not wm_prior:
        sel_wm.inputs.index = in_segmentation_model[-1] - 1

    copy_xform_wm = pe.Node(CopyXForm(fields=["wm_map"]),
                            name="copy_xform_wm", run_without_submitting=True)

    # Refine INU correction
    # The (header-copied) WM map is used as N4's weight image.
    inu_n4_final = pe.MapNode(
        N4BiasFieldCorrection(
            dimension=3,
            save_bias=True,
            copy_header=True,
            n_iterations=[50] * 5,
            convergence_threshold=1e-7,
            shrink_factor=4,
            bspline_fitting_distance=bspline_fitting_distance,
        ),
        n_procs=omp_nthreads,
        name="inu_n4_final",
        iterfield=["input_image"],
    )
    # Per the warning text, the interface rejects this input (ValueError)
    # when the detected ANTs version predates 2.1.0; degrade with a warning.
    try:
        inu_n4_final.inputs.rescale_intensities = True
    except ValueError:
        warn(
            "N4BiasFieldCorrection's --rescale-intensities option was added in ANTS 2.1.0 "
            f"({inu_n4_final.interface.version} found.) Please consider upgrading.",
            UserWarning,
        )

    # Apply mask
    apply_mask = pe.MapNode(ApplyMask(), iterfield=["in_file"], name="apply_mask")

    # fmt: off
    wf.connect([
        (inputnode, dil_brainmask, [("in_mask", "op1")]),
        (inputnode, copy_xform, [(("in_files", _pop), "hdr_file")]),
        (inputnode, copy_xform_wm, [(("in_files", _pop), "hdr_file")]),
        (inputnode, pad_mask, [("in_mask", "op1")]),
        (inputnode, atropos, [("in_corrected", "intensity_images")]),
        (inputnode, inu_n4_final, [("in_files", "input_image")]),
        (inputnode, msk_conform, [(("in_files", _pop), "in_reference")]),
        # Superstep 6: segment inside the dilated mask, then clean up GM/WM.
        (dil_brainmask, get_brainmask, [("output_image", "op1")]),
        (get_brainmask, atropos, [("output_image", "mask_image")]),
        (atropos, pad_segm, [("classified_image", "op1")]),
        (pad_segm, sel_labels, [("output_image", "in_segm")]),
        (sel_labels, get_wm, [("out_wm", "op1")]),
        (sel_labels, get_gm, [("out_gm", "op1")]),
        (get_gm, fill_gm, [("output_image", "op1")]),
        (get_gm, mult_gm, [("output_image", "first_input")]),
        (fill_gm, mult_gm, [("output_image", "second_input")]),
        (get_wm, relabel_wm, [("output_image", "first_input")]),
        (sel_labels, me_csf, [("out_csf", "op1")]),
        (mult_gm, add_gm, [("output_product_image", "op1")]),
        (me_csf, add_gm, [("output_image", "op2")]),
        (add_gm, relabel_gm, [("output_image", "first_input")]),
        (relabel_wm, add_gm_wm, [("output_product_image", "op1")]),
        (relabel_gm, add_gm_wm, [("output_product_image", "op2")]),
        # Superstep 7: build the refined brain mask from GM + WM.
        (add_gm_wm, sel_labels2, [("output_image", "in_segm")]),
        (sel_labels2, add_7, [("out_wm", "op1"), ("out_gm", "op2")]),
        (add_7, me_7, [("output_image", "op1")]),
        (me_7, comp_7, [("output_image", "op1")]),
        (comp_7, md_7, [("output_image", "op1")]),
        (md_7, fill_7, [("output_image", "op1")]),
        (fill_7, add_7_2, [("output_image", "op1")]),
        (pad_mask, add_7_2, [("output_image", "op2")]),
        (add_7_2, md_7_2, [("output_image", "op1")]),
        (md_7_2, me_7_2, [("output_image", "op1")]),
        # Crop padding off the mask, segmentation and per-tissue images.
        (me_7_2, depad_mask, [("output_image", "op1")]),
        (add_gm_wm, depad_segm, [("output_image", "op1")]),
        (relabel_wm, depad_wm, [("output_product_image", "op1")]),
        (relabel_gm, depad_gm, [("output_product_image", "op1")]),
        (sel_labels, depad_csf, [("out_csf", "op1")]),
        (depad_csf, merge_tpms, [("output_image", "in1")]),
        (depad_gm, merge_tpms, [("output_image", "in2")]),
        (depad_wm, merge_tpms, [("output_image", "in3")]),
        (depad_mask, msk_conform, [("output_image", "in_mask")]),
        (msk_conform, copy_xform, [("out", "out_mask")]),
        (depad_segm, copy_xform, [("output_image", "out_segm")]),
        (merge_tpms, copy_xform, [("out", "out_tpms")]),
        # WM posterior -> N4 weight image (rewired below when wm_prior is set).
        (atropos, sel_wm, [("posteriors", "inlist")]),
        (sel_wm, copy_xform_wm, [("out", "wm_map")]),
        (copy_xform_wm, inu_n4_final, [("wm_map", "weight_image")]),
        (inu_n4_final, copy_xform, [("output_image", "bias_corrected"),
                                    ("bias_image", "bias_image")]),
        (copy_xform, apply_mask, [("bias_corrected", "in_file"),
                                  ("out_mask", "in_mask")]),
        (apply_mask, outputnode, [("out_file", "out_file")]),
        (copy_xform, outputnode, [
            ("bias_corrected", "bias_corrected"),
            ("bias_image", "bias_image"),
            ("out_mask", "out_mask"),
            ("out_segm", "out_segm"),
            ("out_tpms", "out_tpms"),
        ]),
    ])
    # fmt: on

    if wm_prior:
        from nipype.algorithms.metrics import FuzzyOverlap

        # Used as a connection function: picks the index of the largest
        # value of FuzzyOverlap's ``class_fdi`` (presumably per-class fuzzy
        # Dice between the posteriors and the WM prior).
        def _argmax(in_dice):
            import numpy as np
            return np.argmax(in_dice)

        # NOTE(review): ``_matchlen`` is defined elsewhere; it is fed the WM
        # prior as ``value`` and the posteriors as ``reference``, presumably
        # to repeat the prior once per posterior — TODO confirm.
        match_wm = pe.Node(
            niu.Function(function=_matchlen),
            name="match_wm",
            run_without_submitting=True,
        )
        overlap = pe.Node(FuzzyOverlap(), name="overlap", run_without_submitting=True)

        # NOTE(review): ``_improd`` is defined elsewhere; it receives the WM
        # map (op1), the prior (op2) and the brain mask, presumably returning
        # their product within the mask — TODO confirm.
        apply_wm_prior = pe.Node(niu.Function(function=_improd), name="apply_wm_prior")

        # Replace the direct WM map -> N4 weight connection made above with
        # the prior-weighted map, and let ``overlap`` choose ``sel_wm.index``.
        # fmt: off
        wf.disconnect([
            (copy_xform_wm, inu_n4_final, [("wm_map", "weight_image")]),
        ])
        wf.connect([
            (inputnode, apply_wm_prior, [("in_mask", "in_mask"),
                                         ("wm_prior", "op2")]),
            (inputnode, match_wm, [("wm_prior", "value")]),
            (atropos, match_wm, [("posteriors", "reference")]),
            (atropos, overlap, [("posteriors", "in_ref")]),
            (match_wm, overlap, [("out", "in_tst")]),
            (overlap, sel_wm, [(("class_fdi", _argmax), "index")]),
            (copy_xform_wm, apply_wm_prior, [("wm_map", "op1")]),
            (apply_wm_prior, inu_n4_final, [("out", "weight_image")]),
        ])
        # fmt: on
    return wf
def BAWantsRegistrationTemplateBuildSingleIterationWF(iterationPhasePrefix=''):
    """Build one iteration of the BAW antsRegistration template-building loop.

    Each registration image is registered to ``fixed_image`` (the current
    template) with a Rigid -> Affine -> 3 x SyN ``antsRegistration`` schedule.
    The warped images are averaged into a new template. That average is then
    reshaped with the transform list produced by
    ``MakeTransformListWithGradientWarps``. The list is built from the average
    affine transform (applied inverted) and a -0.25 gradient step of the
    averaged warp field. Passive images are warped with the same composite
    transforms, averaged (presumably per image type), and reshaped the same way.

    Inputs::

           inputspec.ListOfImagesDictionaries :
           inputspec.registrationImageTypes :
           inputspec.interpolationMapping :
           inputspec.fixed_image :

    Outputs::

           outputspec.template :
           outputspec.transforms_list : declared but never connected
           outputspec.passive_deformed_templates :

    :param iterationPhasePrefix: converted with ``str()`` and used to make the
        workflow name and several output file names unique per iteration.
    :return: the assembled ``pe.Workflow``.
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name='antsRegistrationTemplateBuildSingleIterationWF_' +
        str(iterationPhasePrefix))

    inputSpec = pe.Node(
        interface=util.IdentityInterface(fields=[
            'ListOfImagesDictionaries',
            'registrationImageTypes',
            # 'maskRegistrationImageType',
            'interpolationMapping',
            'fixed_image'
        ]),
        run_without_submitting=True,
        name='inputspec')
    # HACK: TODO: We need to have the AVG_AIR.nii.gz be warped with a default
    #             voxel value of 1.0.
    # HACK: TODO: Move all local functions to a common utility file (or the top
    #             of the file) so re-indenting for flow control does not change
    #             their hash.
    # HACK: TODO: REMOVE 'transforms_list'; it is not used. That will change all
    #             the hashes.
    outputSpec = pe.Node(
        interface=util.IdentityInterface(
            fields=['template', 'transforms_list', 'passive_deformed_templates']),
        run_without_submitting=True,
        name='outputspec')

    # NOTE: MapNode! Register each of the original images to the provided
    # fixed_image, which acts as the current template.
    BeginANTS = pe.MapNode(
        interface=Registration(), name='BeginANTS', iterfield=['moving_image'])
    BeginANTS.inputs.dimension = 3
    # This is the recommended set of parameters from the ANTS developers:
    # Rigid and Affine stages (MI metric) followed by three SyN stages
    # (CC metric) at progressively finer resolution.
    BeginANTS.inputs.output_transform_prefix = str(iterationPhasePrefix) + '_tfm'
    BeginANTS.inputs.transforms = ["Rigid", "Affine", "SyN", "SyN", "SyN"]
    BeginANTS.inputs.transform_parameters = [[0.1], [0.1], [0.1, 3.0, 0.0],
                                             [0.1, 3.0, 0.0], [0.1, 3.0, 0.0]]
    BeginANTS.inputs.metric = ['MI', 'MI', 'CC', 'CC', 'CC']
    BeginANTS.inputs.sampling_strategy = [
        'Regular', 'Regular', None, None, None
    ]
    BeginANTS.inputs.sampling_percentage = [0.27, 0.27, 1.0, 1.0, 1.0]
    BeginANTS.inputs.metric_weight = [1.0, 1.0, 1.0, 1.0, 1.0]
    BeginANTS.inputs.radius_or_number_of_bins = [32, 32, 4, 4, 4]
    BeginANTS.inputs.number_of_iterations = [[1000, 1000, 1000, 1000],
                                             [1000, 1000, 1000, 1000],
                                             [1000, 250], [140], [25]]
    BeginANTS.inputs.convergence_threshold = [5e-8, 5e-8, 5e-7, 5e-6, 5e-5]
    BeginANTS.inputs.convergence_window_size = [10, 10, 10, 10, 10]
    BeginANTS.inputs.use_histogram_matching = [True, True, True, True, True]
    BeginANTS.inputs.shrink_factors = [[8, 4, 2, 1], [8, 4, 2, 1], [8, 4], [2],
                                       [1]]
    BeginANTS.inputs.smoothing_sigmas = [[3, 2, 1, 0], [3, 2, 1, 0], [3, 2],
                                         [1], [0]]
    BeginANTS.inputs.sigma_units = ["vox", "vox", "vox", "vox", "vox"]
    BeginANTS.inputs.use_estimate_learning_rate_once = [
        False, False, False, False, False
    ]
    # A single composite transform per subject is what the downstream nodes
    # (wimtdeformed, SplitCompositeTransform, FlattenTransformAndImagesList)
    # consume.
    BeginANTS.inputs.write_composite_transform = True
    BeginANTS.inputs.collapse_output_transforms = False
    BeginANTS.inputs.initialize_transforms_per_stage = True
    BeginANTS.inputs.winsorize_lower_quantile = 0.01
    BeginANTS.inputs.winsorize_upper_quantile = 0.99
    BeginANTS.inputs.output_warped_image = 'atlas2subject.nii.gz'
    BeginANTS.inputs.output_inverse_warped_image = 'subject2atlas.nii.gz'
    BeginANTS.inputs.save_state = 'SavedBeginANTSSyNState.h5'
    BeginANTS.inputs.float = True

    GetMovingImagesNode = pe.Node(
        interface=util.Function(
            function=GetMovingImages,
            input_names=[
                'ListOfImagesDictionaries', 'registrationImageTypes',
                'interpolationMapping'
            ],
            output_names=['moving_images', 'moving_interpolation_type']),
        run_without_submitting=True,
        name='99_GetMovingImagesNode')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'ListOfImagesDictionaries',
                                           GetMovingImagesNode,
                                           'ListOfImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'registrationImageTypes',
                                           GetMovingImagesNode,
                                           'registrationImageTypes')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           GetMovingImagesNode,
                                           'interpolationMapping')

    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode, 'moving_images',
                                           BeginANTS, 'moving_image')
    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode,
                                           'moving_interpolation_type',
                                           BeginANTS, 'interpolation')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image', BeginANTS,
                                           'fixed_image')

    # Now warp all the input images. Each image gets its single composite
    # transform (this replaced the older forward_transforms/forward_invert_flags
    # wiring), so 'invert_transform_flags' is intentionally not an iterfield.
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=['transforms', 'input_image'],
        name='wimtdeformed')
    wimtdeformed.inputs.interpolation = 'Linear'
    # Must be set on ``.inputs``: a plain attribute on the node object never
    # reaches the ApplyTransforms interface.
    wimtdeformed.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           wimtdeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(GetMovingImagesNode, 'moving_images',
                                           wimtdeformed, 'input_image')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image',
                                           wimtdeformed, 'reference_image')

    # ===== Shape update =====
    # Average all deformed input images together to create an updated
    # template average.
    AvgDeformedImages = pe.Node(
        interface=AverageImages(), name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(
        iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(wimtdeformed, "output_image",
                                           AvgDeformedImages, 'images')

    # Average all affine transforms together.
    AvgAffineTransform = pe.Node(
        interface=AverageAffineTransform(), name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    AvgAffineTransform.inputs.output_affine_transform = 'Avererage_' + str(
        iterationPhasePrefix) + '_Affine.h5'

    # Split each composite transform into its affine and warp components.
    SplitCompositeTransform = pe.MapNode(
        interface=util.Function(
            function=SplitCompositeToComponentTransforms,
            input_names=['composite_transform_as_list'],
            output_names=['affine_component_list', 'warp_component_list']),
        iterfield=['composite_transform_as_list'],
        run_without_submitting=True,
        name='99_SplitCompositeTransform')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           SplitCompositeTransform,
                                           'composite_transform_as_list')
    TemplateBuildSingleIterationWF.connect(SplitCompositeTransform,
                                           'affine_component_list',
                                           AvgAffineTransform, 'transforms')

    # Average the warp fields together.
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(
        iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(SplitCompositeTransform,
                                           'warp_component_list',
                                           AvgWarpImages, 'images')

    # Scale the average warp by -GradientStep.
    # TODO: GradientStep is a hard-coded default value for now.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(
        interface=MultiplyImages(), name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # Filename tracks GradientStep (str(0.25) == '0.25', unchanged name).
    GradientStepWarpImage.inputs.output_product_image = (
        'GradientStep' + str(GradientStep) + '_' + str(iterationPhasePrefix) +
        '_warp.nii.gz')
    TemplateBuildSingleIterationWF.connect(AvgWarpImages, 'output_average_image',
                                           GradientStepWarpImage, 'first_input')

    # Resample the gradient-step warp through the inverted average affine,
    # on the grid of the average deformed image.
    UpdateTemplateShape = pe.Node(
        interface=ApplyTransforms(), name='UpdateTemplateShape')
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = 'Linear'
    UpdateTemplateShape.inputs.default_value = 0

    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           UpdateTemplateShape,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect([
        (AvgAffineTransform, UpdateTemplateShape,
         [(('affine_transform', makeListOfOneElement), 'transforms')]),
    ])
    TemplateBuildSingleIterationWF.connect(GradientStepWarpImage,
                                           'output_product_image',
                                           UpdateTemplateShape, 'input_image')

    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='99_MakeTransformListWithGradientWarps')
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True

    TemplateBuildSingleIterationWF.connect(
        AvgAffineTransform, 'affine_transform',
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'averageAffineTranform')
    TemplateBuildSingleIterationWF.connect(
        UpdateTemplateShape, 'output_image',
        ApplyInvAverageAndFourTimesGradientStepWarpImage, 'gradientStepWarp')

    # Apply the shape update to the average image to get the new template.
    # Five flags: only the first transform in the list is inverted.
    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name='ReshapeAverageImageWithShapeUpdate')
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = 'Linear'
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = 'ReshapeAverageImageWithShapeUpdate.nii.gz'
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'input_image')
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps', ReshapeAverageImageWithShapeUpdate,
        'transforms')
    TemplateBuildSingleIterationWF.connect(ReshapeAverageImageWithShapeUpdate,
                                           'output_image', outputSpec,
                                           'template')

    # ------------------------------------------------------------------
    # Process all the passive images in a way similar to the main images
    # used for registration.
    # ------------------------------------------------------------------
    # Warp all the ListOfPassiveImagesDictionaries images.
    FlattenTransformAndImagesListNode = pe.Node(
        Function(
            function=FlattenTransformAndImagesList,
            input_names=[
                'ListOfPassiveImagesDictionaries', 'transforms',
                'interpolationMapping', 'invert_transform_flags'
            ],
            output_names=[
                'flattened_images', 'flattened_transforms',
                'flattened_invert_transform_flags', 'flattened_image_nametypes',
                'flattened_interpolation_type'
            ]),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")

    GetPassiveImagesNode = pe.Node(
        interface=util.Function(
            function=GetPassiveImages,
            input_names=['ListOfImagesDictionaries', 'registrationImageTypes'],
            output_names=['ListOfPassiveImagesDictionaries']),
        run_without_submitting=True,
        name='99_GetPassiveImagesNode')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'ListOfImagesDictionaries',
                                           GetPassiveImagesNode,
                                           'ListOfImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'registrationImageTypes',
                                           GetPassiveImagesNode,
                                           'registrationImageTypes')

    TemplateBuildSingleIterationWF.connect(GetPassiveImagesNode,
                                           'ListOfPassiveImagesDictionaries',
                                           FlattenTransformAndImagesListNode,
                                           'ListOfPassiveImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           FlattenTransformAndImagesListNode,
                                           'interpolationMapping')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'composite_transform',
                                           FlattenTransformAndImagesListNode,
                                           'transforms')
    # TODO: 'invert_transform_flags' is deliberately left unconnected (it used
    # to come from BeginANTS.forward_invert_flags); check whether it needs a
    # fixed number of entries.

    wimtPassivedeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'transforms', 'invert_transform_flags', 'input_image',
            'interpolation'
        ],
        name='wimtPassivedeformed')
    wimtPassivedeformed.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages,
                                           'output_average_image',
                                           wimtPassivedeformed,
                                           'reference_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_interpolation_type',
                                           wimtPassivedeformed, 'interpolation')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_images',
                                           wimtPassivedeformed, 'input_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_transforms',
                                           wimtPassivedeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_invert_transform_flags',
                                           wimtPassivedeformed,
                                           'invert_transform_flags')

    RenestDeformedPassiveImagesNode = pe.Node(
        Function(
            function=RenestDeformedPassiveImages,
            input_names=[
                'deformedPassiveImages', 'flattened_image_nametypes',
                'interpolationMapping'
            ],
            output_names=[
                'nested_imagetype_list', 'outputAverageImageName_list',
                'image_type_list', 'nested_interpolation_type'
            ]),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    TemplateBuildSingleIterationWF.connect(inputSpec, 'interpolationMapping',
                                           RenestDeformedPassiveImagesNode,
                                           'interpolationMapping')
    TemplateBuildSingleIterationWF.connect(wimtPassivedeformed, 'output_image',
                                           RenestDeformedPassiveImagesNode,
                                           'deformedPassiveImages')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_image_nametypes',
                                           RenestDeformedPassiveImagesNode,
                                           'flattened_image_nametypes')

    # Average all passive deformed images together (one average per nested
    # list) to create updated passive template averages.
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "nested_imagetype_list",
                                           AvgDeformedPassiveImages, 'images')
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "outputAverageImageName_list",
                                           AvgDeformedPassiveImages,
                                           'output_average_image')

    # Reshape all the passive averages with the same shape update as the
    # main template.
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            'input_image', 'reference_image', 'output_image', 'interpolation'
        ],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True, False, False, False, False
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, 'nested_interpolation_type',
        ReshapeAveragePassiveImageWithShapeUpdate, 'interpolation')
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, 'outputAverageImageName_list',
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage,
        'TransformListWithGradientWarps',
        ReshapeAveragePassiveImageWithShapeUpdate, 'transforms')
    TemplateBuildSingleIterationWF.connect(
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image', outputSpec,
        'passive_deformed_templates')

    return TemplateBuildSingleIterationWF
def ANTSTemplateBuildSingleIterationWF(iterationPhasePrefix=''):
    """Build one iteration of a template-building workflow using the legacy ``ANTS`` tool.

    Each image in ``inputspec.images`` is registered (``ANTS`` with a SyN
    transformation model and a CC metric) to ``inputspec.fixed_image``.  The
    warped images are averaged into a new template; the per-subject affine
    transforms and warp fields are averaged as well, and the average warp is
    scaled by ``-GradientStep``.  A transform list built by
    ``MakeTransformListWithGradientWarps`` from the average affine and that
    gradient-step warp is then applied to the averaged image to produce
    ``outputspec.template``.  Passive images listed in
    ``inputspec.ListOfPassiveImagesDictionaries`` are warped with the same
    per-subject transforms, averaged per image type and reshaped the same way.

    Parameters
    ----------
    iterationPhasePrefix : str, optional
        Appended to the workflow name and used as a prefix for the transform,
        average and warp files written by the nodes in this workflow.

    Inputs::

           inputspec.images :
           inputspec.fixed_image :
           inputspec.ListOfPassiveImagesDictionaries :

    Outputs::

           outputspec.template :
           outputspec.transforms_list : (declared, but not connected in this workflow)
           outputspec.passive_deformed_templates :

    Returns
    -------
    nipype.pipeline.engine.Workflow
        The single-iteration workflow.
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name='ANTSTemplateBuildSingleIterationWF_' + str(iterationPhasePrefix))

    inputSpec = pe.Node(interface=util.IdentityInterface(
        fields=['images', 'fixed_image', 'ListOfPassiveImagesDictionaries']),
        run_without_submitting=True,
        name='inputspec')
    ## HACK: TODO: Need to move all local functions to a common utility file, or at the top of the file so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: TODO: REMOVE 'transforms_list' it is not used.  That will change all the hashes
    ## HACK: TODO: Need to run all python files through the code beautifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(interface=util.IdentityInterface(
        fields=['template', 'transforms_list', 'passive_deformed_templates']),
        run_without_submitting=True,
        name='outputspec')

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    BeginANTS = pe.MapNode(interface=ANTS(), name='BeginANTS', iterfield=['moving_image'])
    BeginANTS.inputs.dimension = 3
    BeginANTS.inputs.output_transform_prefix = str(iterationPhasePrefix) + '_tfm'
    BeginANTS.inputs.metric = ['CC']
    BeginANTS.inputs.metric_weight = [1.0]
    BeginANTS.inputs.radius = [5]
    BeginANTS.inputs.transformation_model = 'SyN'
    BeginANTS.inputs.gradient_step_length = 0.25
    BeginANTS.inputs.number_of_iterations = [50, 35, 15]
    BeginANTS.inputs.number_of_affine_iterations = [10000, 10000, 10000, 10000, 10000]
    BeginANTS.inputs.use_histogram_matching = True
    BeginANTS.inputs.mi_option = [32, 16000]
    BeginANTS.inputs.regularization = 'Gauss'
    BeginANTS.inputs.regularization_gradient_field_sigma = 3
    BeginANTS.inputs.regularization_deformation_field_sigma = 0
    TemplateBuildSingleIterationWF.connect(inputSpec, 'images', BeginANTS, 'moving_image')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'fixed_image', BeginANTS, 'fixed_image')

    # Pair up each subject's warp and affine transforms into one list per subject.
    MakeTransformsLists = pe.Node(interface=util.Function(
        function=MakeListsOfTransformLists,
        input_names=['warpTransformList', 'AffineTransformList'],
        output_names=['out']),
        run_without_submitting=True,
        name='MakeTransformsLists')
    MakeTransformsLists.inputs.ignore_exception = True
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'warp_transform',
                                           MakeTransformsLists, 'warpTransformList')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'affine_transform',
                                           MakeTransformsLists, 'AffineTransformList')

    ## Now warp all the input_images images
    # NOTE(review): no reference_image is connected here, unlike the passive-image
    # warping below; confirm the images are meant to be resampled on their own grid.
    wimtdeformed = pe.MapNode(
        interface=WarpImageMultiTransform(),
        iterfield=['transformation_series', 'input_image'],
        name='wimtdeformed')
    TemplateBuildSingleIterationWF.connect(inputSpec, 'images', wimtdeformed, 'input_image')
    TemplateBuildSingleIterationWF.connect(MakeTransformsLists, 'out',
                                           wimtdeformed, 'transformation_series')

    ## Shape Update Next =====
    ## Now Average All input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(), name='AvgDeformedImages')
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = str(iterationPhasePrefix) + '.nii.gz'
    AvgDeformedImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(wimtdeformed, "output_image", AvgDeformedImages, 'images')

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(interface=AverageAffineTransform(), name='AvgAffineTransform')
    AvgAffineTransform.inputs.dimension = 3
    AvgAffineTransform.inputs.output_affine_transform = (
        'Avererage_' + str(iterationPhasePrefix) + '_Affine.mat')
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'affine_transform',
                                           AvgAffineTransform, 'transforms')

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name='AvgWarpImages')
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = str(iterationPhasePrefix) + 'warp.nii.gz'
    AvgWarpImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(BeginANTS, 'warp_transform', AvgWarpImages, 'images')

    ## Now average the images together
    ## TODO: For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(interface=MultiplyImages(), name='GradientStepWarpImage')
    GradientStepWarpImage.inputs.dimension = 3
    # Negative scale: the template is moved against the average subject warp.
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # Derived from GradientStep so the filename cannot disagree with the value used.
    GradientStepWarpImage.inputs.output_product_image = (
        'GradientStep' + str(GradientStep) + '_' + str(iterationPhasePrefix) + '_warp.nii.gz')
    TemplateBuildSingleIterationWF.connect(AvgWarpImages, 'output_average_image',
                                           GradientStepWarpImage, 'first_input')

    ## Now create the new template shape based on the average of all deformed images
    UpdateTemplateShape = pe.Node(interface=WarpImageMultiTransform(), name='UpdateTemplateShape')
    UpdateTemplateShape.inputs.invert_affine = [1]
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages, 'output_average_image',
                                           UpdateTemplateShape, 'reference_image')
    TemplateBuildSingleIterationWF.connect(AvgAffineTransform, 'affine_transform',
                                           UpdateTemplateShape, 'transformation_series')
    TemplateBuildSingleIterationWF.connect(GradientStepWarpImage, 'output_product_image',
                                           UpdateTemplateShape, 'input_image')

    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=MakeTransformListWithGradientWarps,
            input_names=['averageAffineTranform', 'gradientStepWarp'],
            output_names=['TransformListWithGradientWarps']),
        run_without_submitting=True,
        name='MakeTransformListWithGradientWarps')
    ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True
    TemplateBuildSingleIterationWF.connect(
        AvgAffineTransform, 'affine_transform',
        ApplyInvAverageAndFourTimesGradientStepWarpImage, 'averageAffineTranform')
    TemplateBuildSingleIterationWF.connect(
        UpdateTemplateShape, 'output_image',
        ApplyInvAverageAndFourTimesGradientStepWarpImage, 'gradientStepWarp')

    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=WarpImageMultiTransform(),
        name='ReshapeAverageImageWithShapeUpdate')
    ReshapeAverageImageWithShapeUpdate.inputs.invert_affine = [1]
    ReshapeAverageImageWithShapeUpdate.inputs.out_postfix = '_Reshaped'
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages, 'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate, 'input_image')
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages, 'output_average_image',
                                           ReshapeAverageImageWithShapeUpdate, 'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage, 'TransformListWithGradientWarps',
        ReshapeAverageImageWithShapeUpdate, 'transformation_series')
    TemplateBuildSingleIterationWF.connect(ReshapeAverageImageWithShapeUpdate, 'output_image',
                                           outputSpec, 'template')

    ######
    ######
    ###### Process all the passive deformed images in a way similar to the main image used for registration
    ######
    ######
    ######
    ##############################################
    ## Now warp all the ListOfPassiveImagesDictionaries images
    FlattenTransformAndImagesListNode = pe.Node(
        Function(function=FlattenTransformAndImagesList,
                 input_names=['ListOfPassiveImagesDictionaries', 'transformation_series'],
                 output_names=['flattened_images', 'flattened_transforms',
                               'flattened_image_nametypes']),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList")
    TemplateBuildSingleIterationWF.connect(inputSpec, 'ListOfPassiveImagesDictionaries',
                                           FlattenTransformAndImagesListNode,
                                           'ListOfPassiveImagesDictionaries')
    TemplateBuildSingleIterationWF.connect(MakeTransformsLists, 'out',
                                           FlattenTransformAndImagesListNode,
                                           'transformation_series')
    wimtPassivedeformed = pe.MapNode(
        interface=WarpImageMultiTransform(),
        iterfield=['transformation_series', 'input_image'],
        name='wimtPassivedeformed')
    TemplateBuildSingleIterationWF.connect(AvgDeformedImages, 'output_average_image',
                                           wimtPassivedeformed, 'reference_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode, 'flattened_images',
                                           wimtPassivedeformed, 'input_image')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode, 'flattened_transforms',
                                           wimtPassivedeformed, 'transformation_series')

    # Regroup the flat list of warped passive images by image type.
    RenestDeformedPassiveImagesNode = pe.Node(
        Function(function=RenestDeformedPassiveImages,
                 input_names=['deformedPassiveImages', 'flattened_image_nametypes'],
                 output_names=['nested_imagetype_list', 'outputAverageImageName_list',
                               'image_type_list']),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages")
    TemplateBuildSingleIterationWF.connect(wimtPassivedeformed, 'output_image',
                                           RenestDeformedPassiveImagesNode, 'deformedPassiveImages')
    TemplateBuildSingleIterationWF.connect(FlattenTransformAndImagesListNode,
                                           'flattened_image_nametypes',
                                           RenestDeformedPassiveImagesNode,
                                           'flattened_image_nametypes')
    ## Now Average All passive input_images deformed images together to create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=['images', 'output_average_image'],
        name='AvgDeformedPassiveImages')
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode, "nested_imagetype_list",
                                           AvgDeformedPassiveImages, 'images')
    TemplateBuildSingleIterationWF.connect(RenestDeformedPassiveImagesNode,
                                           "outputAverageImageName_list",
                                           AvgDeformedPassiveImages, 'output_average_image')

    ## -- TODO: Now need to reshape all the passive images as well
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=WarpImageMultiTransform(),
        iterfield=['input_image', 'reference_image', 'out_postfix'],
        name='ReshapeAveragePassiveImageWithShapeUpdate')
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_affine = [1]
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, "image_type_list",
        ReshapeAveragePassiveImageWithShapeUpdate, 'out_postfix')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'input_image')
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, 'output_average_image',
        ReshapeAveragePassiveImageWithShapeUpdate, 'reference_image')
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage, 'TransformListWithGradientWarps',
        ReshapeAveragePassiveImageWithShapeUpdate, 'transformation_series')
    TemplateBuildSingleIterationWF.connect(
        ReshapeAveragePassiveImageWithShapeUpdate, 'output_image',
        outputSpec, 'passive_deformed_templates')

    return TemplateBuildSingleIterationWF
def baw_ants_registration_template_build_single_iteration_wf(
    iterationPhasePrefix, CLUSTER_QUEUE, CLUSTER_QUEUE_LONG
):
    """Build one template-building iteration using ``antsRegistration`` (BAW variant).

    The moving images selected from ``inputspec.ListOfImagesDictionaries`` by
    ``get_moving_images`` are registered to ``inputspec.fixed_image`` with a
    ``Registration`` MapNode configured by
    ``common_ants_registration_settings``.  The composite transforms are used
    to warp the moving images onto the fixed image, which are then averaged.
    Each composite transform is split into its affine and warp components.
    These are averaged, and the average warp is scaled by ``-GradientStep``.
    The average image is then reshaped with the transform list built by
    ``make_transform_list_with_gradient_warps``.  Passive images (from
    ``get_passive_images``) are warped with the same composite transforms,
    averaged per image type and reshaped the same way.

    Parameters
    ----------
    iterationPhasePrefix : str
        Appended to the workflow name and used as a prefix for transform,
        average and warp filenames.
    CLUSTER_QUEUE, CLUSTER_QUEUE_LONG
        Not referenced in this function body; kept for signature compatibility.
        The ``BeginANTS`` plugin arguments come from the module-level
        ``BeginANTS_cpu_sge_options_dictionary`` instead.

    Inputs::

           inputspec.ListOfImagesDictionaries :
           inputspec.registrationImageTypes :
           inputspec.interpolationMapping :
           inputspec.fixed_image :

    Outputs::

           outputspec.template :
           outputspec.transforms_list : (declared, but not connected in this workflow)
           outputspec.passive_deformed_templates :

    Returns
    -------
    nipype.pipeline.engine.Workflow
        The single-iteration workflow.
    """
    TemplateBuildSingleIterationWF = pe.Workflow(
        name="antsRegistrationTemplateBuildSingleIterationWF_" + str(iterationPhasePrefix)
    )

    inputSpec = pe.Node(
        interface=util.IdentityInterface(
            fields=[
                "ListOfImagesDictionaries",
                "registrationImageTypes",
                # 'maskRegistrationImageType',
                "interpolationMapping",
                "fixed_image",
            ]
        ),
        run_without_submitting=True,
        name="inputspec",
    )
    ## HACK: INFO: We need to have the AVG_AIR.nii.gz be warped with a default voxel value of 1.0
    ## HACK: INFO: Need to move all local functions to a common utility file, or at the top of the file so that
    ##             they do not change due to re-indenting.  Otherwise re-indenting for flow control will trigger
    ##             their hash to change.
    ## HACK: INFO: REMOVE 'transforms_list' it is not used.  That will change all the hashes
    ## HACK: INFO: Need to run all python files through the code beautifiers.  It has gotten pretty ugly.
    outputSpec = pe.Node(
        interface=util.IdentityInterface(
            fields=["template", "transforms_list", "passive_deformed_templates"]
        ),
        run_without_submitting=True,
        name="outputspec",
    )

    ### NOTE MAP NODE! warp each of the original images to the provided fixed_image as the template
    BeginANTS = pe.MapNode(
        interface=Registration(), name="BeginANTS", iterfield=["moving_image"]
    )
    # SEE template.py many_cpu_BeginANTS_options_dictionary = {'qsub_args': modify_qsub_args(CLUSTER_QUEUE,4,2,8), 'overwrite': True}
    ## This is set in the template.py file
    BeginANTS.plugin_args = BeginANTS_cpu_sge_options_dictionary
    common_ants_registration_settings(
        antsRegistrationNode=BeginANTS,
        registrationTypeDescription="SixStageAntsRegistrationT1Only",
        output_transform_prefix=str(iterationPhasePrefix) + "_tfm",
        output_warped_image="atlas2subject.nii.gz",
        output_inverse_warped_image="subject2atlas.nii.gz",
        save_state="SavedantsRegistrationNodeSyNState.h5",
        invert_initial_moving_transform=False,
        initial_moving_transform=None,
    )

    GetMovingImagesNode = pe.Node(
        interface=util.Function(
            function=get_moving_images,
            input_names=[
                "ListOfImagesDictionaries",
                "registrationImageTypes",
                "interpolationMapping",
            ],
            output_names=["moving_images", "moving_interpolation_type"],
        ),
        run_without_submitting=True,
        name="99_GetMovingImagesNode",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "ListOfImagesDictionaries",
        GetMovingImagesNode, "ListOfImagesDictionaries",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "registrationImageTypes",
        GetMovingImagesNode, "registrationImageTypes",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "interpolationMapping", GetMovingImagesNode, "interpolationMapping"
    )
    TemplateBuildSingleIterationWF.connect(
        GetMovingImagesNode, "moving_images", BeginANTS, "moving_image"
    )
    TemplateBuildSingleIterationWF.connect(
        GetMovingImagesNode, "moving_interpolation_type", BeginANTS, "interpolation"
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "fixed_image", BeginANTS, "fixed_image"
    )

    ## Now warp all the input_images images
    wimtdeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=["transforms", "input_image"],
        # iterfield=['transforms', 'invert_transform_flags', 'input_image'],
        name="wimtdeformed",
    )
    wimtdeformed.inputs.interpolation = "Linear"
    # Must go through .inputs: assigning node.default_value only sets an unused
    # attribute on the Node object and never reaches ApplyTransforms.
    wimtdeformed.inputs.default_value = 0
    # HACK: Should try using forward_composite_transform
    ##PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_transform', wimtdeformed, 'transforms')
    TemplateBuildSingleIterationWF.connect(
        BeginANTS, "composite_transform", wimtdeformed, "transforms"
    )
    ##PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_invert_flags', wimtdeformed, 'invert_transform_flags')
    ## NOTE: forward_invert_flags:: List of flags corresponding to the forward transforms
    # wimtdeformed.inputs.invert_transform_flags = [False,False,False,False,False]
    TemplateBuildSingleIterationWF.connect(
        GetMovingImagesNode, "moving_images", wimtdeformed, "input_image"
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "fixed_image", wimtdeformed, "reference_image"
    )

    ## Shape Update Next =====
    ## Now Average All input_images deformed images together to create an updated template average
    AvgDeformedImages = pe.Node(interface=AverageImages(), name="AvgDeformedImages")
    AvgDeformedImages.inputs.dimension = 3
    AvgDeformedImages.inputs.output_average_image = (
        str(iterationPhasePrefix) + ".nii.gz"
    )
    AvgDeformedImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(
        wimtdeformed, "output_image", AvgDeformedImages, "images"
    )

    ## Now average all affine transforms together
    AvgAffineTransform = pe.Node(
        interface=AverageAffineTransform(), name="AvgAffineTransform"
    )
    AvgAffineTransform.inputs.dimension = 3
    AvgAffineTransform.inputs.output_affine_transform = (
        "Avererage_" + str(iterationPhasePrefix) + "_Affine.h5"
    )

    # Split each subject's composite transform into its affine and warp parts.
    SplitCompositeTransform = pe.MapNode(
        interface=util.Function(
            function=split_composite_to_component_transform,
            input_names=["transformFilename"],
            output_names=["affine_component_list", "warp_component_list"],
        ),
        iterfield=["transformFilename"],
        run_without_submitting=True,
        name="99_SplitCompositeTransform",
    )
    TemplateBuildSingleIterationWF.connect(
        BeginANTS, "composite_transform", SplitCompositeTransform, "transformFilename"
    )
    ## PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_transforms', SplitCompositeTransform, 'transformFilename')
    TemplateBuildSingleIterationWF.connect(
        SplitCompositeTransform, "affine_component_list",
        AvgAffineTransform, "transforms",
    )

    ## Now average the warp fields together
    AvgWarpImages = pe.Node(interface=AverageImages(), name="AvgWarpImages")
    AvgWarpImages.inputs.dimension = 3
    AvgWarpImages.inputs.output_average_image = (
        str(iterationPhasePrefix) + "warp.nii.gz"
    )
    AvgWarpImages.inputs.normalize = True
    TemplateBuildSingleIterationWF.connect(
        SplitCompositeTransform, "warp_component_list", AvgWarpImages, "images"
    )

    ## Now average the images together
    ## INFO: For now GradientStep is set to 0.25 as a hard coded default value.
    GradientStep = 0.25
    GradientStepWarpImage = pe.Node(
        interface=MultiplyImages(), name="GradientStepWarpImage"
    )
    GradientStepWarpImage.inputs.dimension = 3
    # Negative scale: the template is moved against the average subject warp.
    GradientStepWarpImage.inputs.second_input = -1.0 * GradientStep
    # Derived from GradientStep so the filename cannot disagree with the value used.
    GradientStepWarpImage.inputs.output_product_image = (
        "GradientStep" + str(GradientStep) + "_" + str(iterationPhasePrefix) + "_warp.nii.gz"
    )
    TemplateBuildSingleIterationWF.connect(
        AvgWarpImages, "output_average_image", GradientStepWarpImage, "first_input"
    )

    ## Now create the new template shape based on the average of all deformed images
    UpdateTemplateShape = pe.Node(
        interface=ApplyTransforms(), name="UpdateTemplateShape"
    )
    UpdateTemplateShape.inputs.invert_transform_flags = [True]
    UpdateTemplateShape.inputs.interpolation = "Linear"
    UpdateTemplateShape.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedImages, "output_average_image",
        UpdateTemplateShape, "reference_image",
    )
    TemplateBuildSingleIterationWF.connect(
        [
            (
                AvgAffineTransform,
                UpdateTemplateShape,
                [(("affine_transform", make_list_of_one_element), "transforms")],
            )
        ]
    )
    TemplateBuildSingleIterationWF.connect(
        GradientStepWarpImage, "output_product_image",
        UpdateTemplateShape, "input_image",
    )

    ApplyInvAverageAndFourTimesGradientStepWarpImage = pe.Node(
        interface=util.Function(
            function=make_transform_list_with_gradient_warps,
            input_names=["averageAffineTranform", "gradientStepWarp"],
            output_names=["TransformListWithGradientWarps"],
        ),
        run_without_submitting=True,
        name="99_MakeTransformListWithGradientWarps",
    )
    # ApplyInvAverageAndFourTimesGradientStepWarpImage.inputs.ignore_exception = True
    TemplateBuildSingleIterationWF.connect(
        AvgAffineTransform, "affine_transform",
        ApplyInvAverageAndFourTimesGradientStepWarpImage, "averageAffineTranform",
    )
    TemplateBuildSingleIterationWF.connect(
        UpdateTemplateShape, "output_image",
        ApplyInvAverageAndFourTimesGradientStepWarpImage, "gradientStepWarp",
    )

    ReshapeAverageImageWithShapeUpdate = pe.Node(
        interface=ApplyTransforms(), name="ReshapeAverageImageWithShapeUpdate"
    )
    ReshapeAverageImageWithShapeUpdate.inputs.invert_transform_flags = [
        True,
        False,
        False,
        False,
        False,
    ]
    ReshapeAverageImageWithShapeUpdate.inputs.interpolation = "Linear"
    ReshapeAverageImageWithShapeUpdate.inputs.default_value = 0
    ReshapeAverageImageWithShapeUpdate.inputs.output_image = (
        "ReshapeAverageImageWithShapeUpdate.nii.gz"
    )
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedImages, "output_average_image",
        ReshapeAverageImageWithShapeUpdate, "input_image",
    )
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedImages, "output_average_image",
        ReshapeAverageImageWithShapeUpdate, "reference_image",
    )
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage, "TransformListWithGradientWarps",
        ReshapeAverageImageWithShapeUpdate, "transforms",
    )
    TemplateBuildSingleIterationWF.connect(
        ReshapeAverageImageWithShapeUpdate, "output_image", outputSpec, "template"
    )

    ######
    ######
    ###### Process all the passive deformed images in a way similar to the main image used for registration
    ######
    ######
    ######
    ##############################################
    ## Now warp all the ListOfPassiveImagesDictionaries images
    # NOTE(review): "invert_transform_flags" is declared as an input but is never
    # connected or set here; confirm flatten_transform_and_images_list defaults it.
    FlattenTransformAndImagesListNode = pe.Node(
        Function(
            function=flatten_transform_and_images_list,
            input_names=[
                "ListOfPassiveImagesDictionaries",
                "transforms",
                "interpolationMapping",
                "invert_transform_flags",
            ],
            output_names=[
                "flattened_images",
                "flattened_transforms",
                "flattened_invert_transform_flags",
                "flattened_image_nametypes",
                "flattened_interpolation_type",
            ],
        ),
        run_without_submitting=True,
        name="99_FlattenTransformAndImagesList",
    )

    GetPassiveImagesNode = pe.Node(
        interface=util.Function(
            function=get_passive_images,
            input_names=["ListOfImagesDictionaries", "registrationImageTypes"],
            output_names=["ListOfPassiveImagesDictionaries"],
        ),
        run_without_submitting=True,
        name="99_GetPassiveImagesNode",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "ListOfImagesDictionaries",
        GetPassiveImagesNode, "ListOfImagesDictionaries",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "registrationImageTypes",
        GetPassiveImagesNode, "registrationImageTypes",
    )
    TemplateBuildSingleIterationWF.connect(
        GetPassiveImagesNode, "ListOfPassiveImagesDictionaries",
        FlattenTransformAndImagesListNode, "ListOfPassiveImagesDictionaries",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "interpolationMapping",
        FlattenTransformAndImagesListNode, "interpolationMapping",
    )
    TemplateBuildSingleIterationWF.connect(
        BeginANTS, "composite_transform",
        FlattenTransformAndImagesListNode, "transforms",
    )
    ## FlattenTransformAndImagesListNode.inputs.invert_transform_flags = [False,False,False,False,False,False]
    ## INFO: Please check if invert_transform_flags has a fixed number.
    ## PREVIOUS TemplateBuildSingleIterationWF.connect(BeginANTS, 'forward_invert_flags', FlattenTransformAndImagesListNode, 'invert_transform_flags')
    wimtPassivedeformed = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=[
            "transforms",
            "invert_transform_flags",
            "input_image",
            "interpolation",
        ],
        name="wimtPassivedeformed",
    )
    wimtPassivedeformed.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedImages, "output_average_image",
        wimtPassivedeformed, "reference_image",
    )
    TemplateBuildSingleIterationWF.connect(
        FlattenTransformAndImagesListNode, "flattened_interpolation_type",
        wimtPassivedeformed, "interpolation",
    )
    TemplateBuildSingleIterationWF.connect(
        FlattenTransformAndImagesListNode, "flattened_images",
        wimtPassivedeformed, "input_image",
    )
    TemplateBuildSingleIterationWF.connect(
        FlattenTransformAndImagesListNode, "flattened_transforms",
        wimtPassivedeformed, "transforms",
    )
    TemplateBuildSingleIterationWF.connect(
        FlattenTransformAndImagesListNode, "flattened_invert_transform_flags",
        wimtPassivedeformed, "invert_transform_flags",
    )

    # Regroup the flat list of warped passive images by image type.
    RenestDeformedPassiveImagesNode = pe.Node(
        Function(
            function=renest_deformed_passive_images,
            input_names=[
                "deformedPassiveImages",
                "flattened_image_nametypes",
                "interpolationMapping",
            ],
            output_names=[
                "nested_imagetype_list",
                "outputAverageImageName_list",
                "image_type_list",
                "nested_interpolation_type",
            ],
        ),
        run_without_submitting=True,
        name="99_RenestDeformedPassiveImages",
    )
    TemplateBuildSingleIterationWF.connect(
        inputSpec, "interpolationMapping",
        RenestDeformedPassiveImagesNode, "interpolationMapping",
    )
    TemplateBuildSingleIterationWF.connect(
        wimtPassivedeformed, "output_image",
        RenestDeformedPassiveImagesNode, "deformedPassiveImages",
    )
    TemplateBuildSingleIterationWF.connect(
        FlattenTransformAndImagesListNode, "flattened_image_nametypes",
        RenestDeformedPassiveImagesNode, "flattened_image_nametypes",
    )

    ## Now Average All passive input_images deformed images together to create an updated template average
    AvgDeformedPassiveImages = pe.MapNode(
        interface=AverageImages(),
        iterfield=["images", "output_average_image"],
        name="AvgDeformedPassiveImages",
    )
    AvgDeformedPassiveImages.inputs.dimension = 3
    AvgDeformedPassiveImages.inputs.normalize = False
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, "nested_imagetype_list",
        AvgDeformedPassiveImages, "images",
    )
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, "outputAverageImageName_list",
        AvgDeformedPassiveImages, "output_average_image",
    )

    ## -- INFO: Now need to reshape all the passive images as well
    ReshapeAveragePassiveImageWithShapeUpdate = pe.MapNode(
        interface=ApplyTransforms(),
        iterfield=["input_image", "reference_image", "output_image", "interpolation"],
        name="ReshapeAveragePassiveImageWithShapeUpdate",
    )
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.invert_transform_flags = [
        True,
        False,
        False,
        False,
        False,
    ]
    ReshapeAveragePassiveImageWithShapeUpdate.inputs.default_value = 0
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, "nested_interpolation_type",
        ReshapeAveragePassiveImageWithShapeUpdate, "interpolation",
    )
    TemplateBuildSingleIterationWF.connect(
        RenestDeformedPassiveImagesNode, "outputAverageImageName_list",
        ReshapeAveragePassiveImageWithShapeUpdate, "output_image",
    )
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, "output_average_image",
        ReshapeAveragePassiveImageWithShapeUpdate, "input_image",
    )
    TemplateBuildSingleIterationWF.connect(
        AvgDeformedPassiveImages, "output_average_image",
        ReshapeAveragePassiveImageWithShapeUpdate, "reference_image",
    )
    TemplateBuildSingleIterationWF.connect(
        ApplyInvAverageAndFourTimesGradientStepWarpImage, "TransformListWithGradientWarps",
        ReshapeAveragePassiveImageWithShapeUpdate, "transforms",
    )
    TemplateBuildSingleIterationWF.connect(
        ReshapeAveragePassiveImageWithShapeUpdate, "output_image",
        outputSpec, "passive_deformed_templates",
    )

    return TemplateBuildSingleIterationWF
def init_atropos_wf(
    name="atropos_wf",
    use_random_seed=True,
    omp_nthreads=None,
    mem_gb=3.0,
    padding=10,
    in_segmentation_model=tuple(ATROPOS_MODELS["T1w"].values()),
):
    """
    Create an ANTs' ATROPOS workflow for brain tissue segmentation.

    Implements supersteps 6 and 7 of ``antsBrainExtraction.sh``,
    which refine the mask previously computed with the spatial
    normalization to the template.

    Workflow Graph
        .. workflow::
            :graph2use: orig
            :simple_form: yes

            from niworkflows.anat.ants import init_atropos_wf
            wf = init_atropos_wf()

    Parameters
    ----------
    use_random_seed : bool
        Whether ATROPOS should generate a random seed based on the
        system's clock
    omp_nthreads : int
        Maximum number of threads an individual process may use
    mem_gb : float
        Estimated peak memory consumption of the most hungry nodes
        in the workflow
    padding : int
        Pad images with zeros before processing (the padding is removed
        again before the outputs are produced)
    in_segmentation_model : tuple
        A k-means segmentation is run to find gray or white matter
        around the edge of the initial brain mask warped from the
        template.
        This produces a segmentation image with :math:`$K$` classes,
        ordered by mean intensity in increasing order.
        With this option, you can control  :math:`$K$` and tell
        the script which classes represent CSF, gray and white matter.
        Format (K, csfLabel, gmLabel, wmLabel).
        Examples:
        ``(3,1,2,3)`` for T1 with K=3, CSF=1, GM=2, WM=3 (default),
        ``(3,3,2,1)`` for T2 with K=3, CSF=3, GM=2, WM=1,
        ``(3,1,3,2)`` for FLAIR with K=3, CSF=1 GM=3, WM=2,
        ``(4,4,2,3)`` uses K=4, CSF=4, GM=2, WM=3.
    name : str, optional
        Workflow name (default: "atropos_wf").

    Inputs
    ------
    in_files : list
        :abbr:`INU (intensity non-uniformity)`-corrected files.
    in_mask : str
        Brain mask calculated previously.
    in_mask_dilated : str
        Mask passed to ATROPOS as ``mask_image`` (expected to be a dilated
        version of ``in_mask``).

    Outputs
    -------
    out_mask : str
        Refined brain mask
    out_segm : str
        Output segmentation
    out_tpms : list
        Output :abbr:`TPMs (tissue probability maps)`, merged in the order
        CSF, GM, WM.

    """
    wf = pe.Workflow(name)

    inputnode = pe.Node(
        niu.IdentityInterface(
            fields=["in_files", "in_mask", "in_mask_dilated"]),
        name="inputnode",
    )
    outputnode = pe.Node(
        niu.IdentityInterface(fields=["out_mask", "out_segm", "out_tpms"]),
        name="outputnode",
    )

    copy_xform = pe.Node(
        CopyXForm(fields=["out_mask", "out_segm", "out_tpms"]),
        name="copy_xform",
        run_without_submitting=True,
    )

    def _pad_image_node(node_name, op2):
        """Return an ``ImageMath PadImage`` node; *op2* is the pad-width string (negative de-pads)."""
        return pe.Node(ImageMath(operation="PadImage", op2=op2), name=node_name)

    # Kept as the exact strings used before, so node inputs (and hashes) are unchanged.
    pad_width = "%d" % padding
    depad_width = "-%d" % padding

    # Run atropos (core node)
    atropos = pe.Node(
        Atropos(
            dimension=3,
            initialization="KMeans",
            number_of_tissue_classes=in_segmentation_model[0],
            n_iterations=3,
            convergence_threshold=0.0,
            mrf_radius=[1, 1, 1],
            mrf_smoothing_factor=0.1,
            likelihood_model="Gaussian",
            use_random_seed=use_random_seed,
        ),
        name="01_atropos",
        n_procs=omp_nthreads,
        mem_gb=mem_gb,
    )

    # massage outputs
    pad_segm = _pad_image_node("02_pad_segm", pad_width)
    pad_mask = _pad_image_node("03_pad_mask", pad_width)

    # Split segmentation in binary masks
    sel_labels = pe.Node(
        niu.Function(function=_select_labels,
                     output_names=["out_wm", "out_gm", "out_csf"]),
        name="04_sel_labels",
    )
    # (csf, gm, wm) reversed -> (wm, gm, csf), matching the output_names order above.
    sel_labels.inputs.labels = list(reversed(in_segmentation_model[1:]))

    # Select largest components (GM, WM)
    # ImageMath ${DIMENSION} ${EXTRACTION_WM} GetLargestComponent ${EXTRACTION_WM}
    get_wm = pe.Node(ImageMath(operation="GetLargestComponent"), name="05_get_wm")
    get_gm = pe.Node(ImageMath(operation="GetLargestComponent"), name="06_get_gm")

    # Fill holes and calculate intersection
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} FillHoles ${EXTRACTION_GM} 2
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${EXTRACTION_TMP} ${EXTRACTION_GM}
    fill_gm = pe.Node(ImageMath(operation="FillHoles", op2="2"), name="07_fill_gm")
    mult_gm = pe.Node(
        MultiplyImages(dimension=3, output_product_image="08_mult_gm.nii.gz"),
        name="08_mult_gm",
    )

    # MultiplyImages ${DIMENSION} ${EXTRACTION_WM} ${ATROPOS_WM_CLASS_LABEL} ${EXTRACTION_WM}
    # ImageMath ${DIMENSION} ${EXTRACTION_TMP} ME ${EXTRACTION_CSF} 10
    relabel_wm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-1],
            output_product_image="09_relabel_wm.nii.gz",
        ),
        name="09_relabel_wm",
    )
    me_csf = pe.Node(ImageMath(operation="ME", op2="10"), name="10_me_csf")

    # ImageMath ${DIMENSION} ${EXTRACTION_GM} addtozero ${EXTRACTION_GM} ${EXTRACTION_TMP}
    # MultiplyImages ${DIMENSION} ${EXTRACTION_GM} ${ATROPOS_GM_CLASS_LABEL} ${EXTRACTION_GM}
    # ImageMath ${DIMENSION} ${EXTRACTION_SEGMENTATION} addtozero ${EXTRACTION_WM} ${EXTRACTION_GM}
    add_gm = pe.Node(ImageMath(operation="addtozero"), name="11_add_gm")
    relabel_gm = pe.Node(
        MultiplyImages(
            dimension=3,
            second_input=in_segmentation_model[-2],
            output_product_image="12_relabel_gm.nii.gz",
        ),
        name="12_relabel_gm",
    )
    add_gm_wm = pe.Node(ImageMath(operation="addtozero"), name="13_add_gm_wm")

    # Superstep 7
    # Split segmentation in binary masks
    sel_labels2 = pe.Node(
        niu.Function(function=_select_labels, output_names=["out_gm", "out_wm"]),
        name="14_sel_labels2",
    )
    # (gm, wm), matching the output_names order above.
    sel_labels2.inputs.labels = in_segmentation_model[2:]

    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} ${EXTRACTION_TMP}
    add_7 = pe.Node(ImageMath(operation="addtozero"), name="15_add_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 2
    me_7 = pe.Node(ImageMath(operation="ME", op2="2"), name="16_me_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} GetLargestComponent ${EXTRACTION_MASK}
    comp_7 = pe.Node(ImageMath(operation="GetLargestComponent"), name="17_comp_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 4
    md_7 = pe.Node(ImageMath(operation="MD", op2="4"), name="18_md_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} FillHoles ${EXTRACTION_MASK} 2
    fill_7 = pe.Node(ImageMath(operation="FillHoles", op2="2"), name="19_fill_7")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} addtozero ${EXTRACTION_MASK} \
    # ${EXTRACTION_MASK_PRIOR_WARPED}
    add_7_2 = pe.Node(ImageMath(operation="addtozero"), name="20_add_7_2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} MD ${EXTRACTION_MASK} 5
    md_7_2 = pe.Node(ImageMath(operation="MD", op2="5"), name="21_md_7_2")
    # ImageMath ${DIMENSION} ${EXTRACTION_MASK} ME ${EXTRACTION_MASK} 5
    me_7_2 = pe.Node(ImageMath(operation="ME", op2="5"), name="22_me_7_2")

    # De-pad
    depad_mask = _pad_image_node("23_depad_mask", depad_width)
    depad_segm = _pad_image_node("24_depad_segm", depad_width)
    depad_gm = _pad_image_node("25_depad_gm", depad_width)
    depad_wm = _pad_image_node("26_depad_wm", depad_width)
    depad_csf = _pad_image_node("27_depad_csf", depad_width)

    msk_conform = pe.Node(niu.Function(function=_conform_mask), name="msk_conform")
    # NOTE(review): the Merge has K inputs but only in1..in3 (CSF, GM, WM) are
    # connected below; confirm this is intended for K > 3 models.
    merge_tpms = pe.Node(niu.Merge(in_segmentation_model[0]), name="merge_tpms")

    # fmt: off
    wf.connect([
        (inputnode, copy_xform, [(("in_files", _pop), "hdr_file")]),
        (inputnode, pad_mask, [("in_mask", "op1")]),
        (inputnode, atropos, [
            ("in_files", "intensity_images"),
            ("in_mask_dilated", "mask_image"),
        ]),
        (inputnode, msk_conform, [(("in_files", _pop), "in_reference")]),
        (atropos, pad_segm, [("classified_image", "op1")]),
        (pad_segm, sel_labels, [("output_image", "in_segm")]),
        (sel_labels, get_wm, [("out_wm", "op1")]),
        (sel_labels, get_gm, [("out_gm", "op1")]),
        (get_gm, fill_gm, [("output_image", "op1")]),
        (get_gm, mult_gm, [("output_image", "first_input")]),
        (fill_gm, mult_gm, [("output_image", "second_input")]),
        (get_wm, relabel_wm, [("output_image", "first_input")]),
        (sel_labels, me_csf, [("out_csf", "op1")]),
        (mult_gm, add_gm, [("output_product_image", "op1")]),
        (me_csf, add_gm, [("output_image", "op2")]),
        (add_gm, relabel_gm, [("output_image", "first_input")]),
        (relabel_wm, add_gm_wm, [("output_product_image", "op1")]),
        (relabel_gm, add_gm_wm, [("output_product_image", "op2")]),
        (add_gm_wm, sel_labels2, [("output_image", "in_segm")]),
        (sel_labels2, add_7, [("out_wm", "op1"), ("out_gm", "op2")]),
        (add_7, me_7, [("output_image", "op1")]),
        (me_7, comp_7, [("output_image", "op1")]),
        (comp_7, md_7, [("output_image", "op1")]),
        (md_7, fill_7, [("output_image", "op1")]),
        (fill_7, add_7_2, [("output_image", "op1")]),
        (pad_mask, add_7_2, [("output_image", "op2")]),
        (add_7_2, md_7_2, [("output_image", "op1")]),
        (md_7_2, me_7_2, [("output_image", "op1")]),
        (me_7_2, depad_mask, [("output_image", "op1")]),
        (add_gm_wm, depad_segm, [("output_image", "op1")]),
        (relabel_wm, depad_wm, [("output_product_image", "op1")]),
        (relabel_gm, depad_gm, [("output_product_image", "op1")]),
        (sel_labels, depad_csf, [("out_csf", "op1")]),
        (depad_csf, merge_tpms, [("output_image", "in1")]),
        (depad_gm, merge_tpms, [("output_image", "in2")]),
        (depad_wm, merge_tpms, [("output_image", "in3")]),
        (depad_mask, msk_conform, [("output_image", "in_mask")]),
        (msk_conform, copy_xform, [("out", "out_mask")]),
        (depad_segm, copy_xform, [("output_image", "out_segm")]),
        (merge_tpms, copy_xform, [("out", "out_tpms")]),
        (copy_xform, outputnode, [
            ("out_mask", "out_mask"),
            ("out_segm", "out_segm"),
            ("out_tpms", "out_tpms"),
        ]),
    ])
    # fmt: on
    return wf
def ANTs_cortical_thickness(subject_list, directory,
                            template_dir='/imaging/jb07/Atlases/OASIS/OASIS-30_Atropos_template',
                            bids_dir='/imaging/jb07/CALM/CALM_BIDS',
                            plugin='PBSGraph'):
    """Build and run an ANTs cortical-thickness workflow over a list of subjects.

    For each subject the T1w image is rigidly aligned to the OASIS-30 template
    (``own_nipype.ants_QuickSyN``), then processed with
    ``antsCorticalThickness``.  Derived outputs are:

    * an RGB tiled mosaic of the cortical-thickness map for quality control,
    * a GM density image (``GM_DENSITY``) resampled into template space with
      the full subject-to-template transform (warp + affine),
    * the template-space cortical thickness multiplied by the
      subject-to-template log-Jacobian.

    Parameters
    ----------
    subject_list : list of str
        Subject identifiers; used as the iterable of the ``infosource`` node.
    directory : str
        Used as SelectFiles' ``base_directory`` and as the workflow's
        ``base_dir`` (working directory).
    template_dir : str, optional
        Directory holding the OASIS-30 Atropos template files
        (``T_template0*.nii.gz`` and ``Priors2/priors{1..6}.nii.gz``).
        Defaults to the original hard-coded location.
    bids_dir : str, optional
        Root of the BIDS dataset; the T1w image is looked up at
        ``{bids_dir}/{subject_id}/anat/{subject_id}_T1w.nii.gz``.
        Defaults to the original hard-coded location.
    plugin : str, optional
        Nipype execution plugin passed to ``Workflow.run``
        (default ``'PBSGraph'``, as before).

    Returns
    -------
    None
        The workflow graph is written and the workflow is run as side effects.
    """
    #==============================================================
    # Loading required packages
    import os

    import nipype.interfaces.utility as util
    import nipype.pipeline.engine as pe
    from nipype import SelectFiles
    from nipype.interfaces.ants import ApplyTransforms, MultiplyImages
    from nipype.interfaces.ants.segmentation import antsCorticalThickness
    from nipype.interfaces.ants.visualization import (ConvertScalarImageToRGB,
                                                      CreateTiledMosaic)
    from nipype.interfaces.utility import Function, Select

    import own_nipype
    from own_nipype import GM_DENSITY

    def _atlas(*parts):
        """Return the path of a file inside the template directory."""
        return os.path.join(template_dir, *parts)

    template_head = _atlas('T_template0.nii.gz')
    template_brain = _atlas('T_template0_BrainCerebellum.nii.gz')

    #====================================
    # Defining the nodes for the workflow

    # Getting the subject ID
    infosource = pe.Node(
        interface=util.IdentityInterface(fields=['subject_id']),
        name='infosource')
    infosource.iterables = ('subject_id', subject_list)

    # Getting the subject's T1-weighted anatomical image.
    # NOTE: the template is an absolute path, so SelectFiles effectively
    # ignores base_directory (kept for backward compatibility).
    templates = dict(
        T1=os.path.join(bids_dir, '{subject_id}', 'anat',
                        '{subject_id}_T1w.nii.gz'))
    selectfiles = pe.Node(SelectFiles(templates), name="selectfiles")
    selectfiles.inputs.base_directory = os.path.abspath(directory)

    # Rigid alignment with the template space
    T1_rigid_quickSyN = pe.Node(
        interface=own_nipype.ants_QuickSyN(image_dimensions=3,
                                           transform_type='r'),
        name='T1_rigid_quickSyN')
    T1_rigid_quickSyN.inputs.fixed_image = template_head

    # Cortical thickness calculation
    corticalthickness = pe.Node(interface=antsCorticalThickness(),
                                name='corticalthickness')
    corticalthickness.inputs.brain_probability_mask = _atlas(
        'T_template0_BrainCerebellumProbabilityMask.nii.gz')
    corticalthickness.inputs.brain_template = template_head
    corticalthickness.inputs.segmentation_priors = [
        _atlas('Priors2', 'priors%d.nii.gz' % idx) for idx in range(1, 7)
    ]
    corticalthickness.inputs.extraction_registration_mask = _atlas(
        'T_template0_BrainCerebellumExtractionMask.nii.gz')
    corticalthickness.inputs.t1_registration_template = template_brain

    # Creating visualisations for quality control
    converter = pe.Node(interface=ConvertScalarImageToRGB(), name='converter')
    converter.inputs.dimension = 3
    converter.inputs.colormap = 'cool'
    converter.inputs.minimum_input = 0
    converter.inputs.maximum_input = 5

    mosaic_slicer = pe.Node(interface=CreateTiledMosaic(),
                            name='mosaic_slicer')
    mosaic_slicer.inputs.pad_or_crop = 'mask'
    mosaic_slicer.inputs.slices = '[4 ,mask , mask]'
    mosaic_slicer.inputs.direction = 1
    mosaic_slicer.inputs.alpha_value = 0.5

    # Getting GM density images; index 1 selects the second posterior
    # (presumably GM for the OASIS priors ordering -- confirm).
    gm_density = pe.Node(interface=GM_DENSITY(), name='gm_density')
    sl = pe.Node(interface=Select(index=1), name='sl')

    # The full subject->template mapping is the SyN warp followed by the
    # affine.  antsApplyTransforms applies its transform list last-to-first,
    # so the list must be [warp, affine]; the warp alone is not enough.
    merge_transforms = pe.Node(interface=util.Merge(2),
                               name='merge_transforms')

    # Applying transformation
    at = pe.Node(interface=ApplyTransforms(), name='at')
    at.inputs.dimension = 3
    at.inputs.reference_image = template_brain
    at.inputs.interpolation = 'Linear'
    at.inputs.default_value = 0
    # One flag per transform: ApplyTransforms rejects a length mismatch.
    at.inputs.invert_transform_flags = [False, False]

    # Multiplying the normalized image with Jacobian.
    # NOTE(review): the *log*-Jacobian is used as the multiplier, not the
    # Jacobian itself -- confirm this is the intended modulation.
    multiply_images = pe.Node(interface=MultiplyImages(dimension=3),
                              name='multiply_images')

    # Naming the output of multiply_images
    def _multiplied_filename(subject_id):
        return subject_id + '_multiplied.nii.gz'

    generate_filename = pe.Node(interface=Function(
        input_names=["subject_id"],
        output_names=["out_filename"],
        function=_multiplied_filename),
        name='generate_filename')

    #====================================
    # Setting up the workflow
    antsthickness = pe.Workflow(name='antsthickness')

    antsthickness.connect(infosource, 'subject_id', selectfiles, 'subject_id')
    antsthickness.connect(selectfiles, 'T1', T1_rigid_quickSyN, 'moving_image')
    antsthickness.connect(infosource, 'subject_id', T1_rigid_quickSyN,
                          'output_prefix')
    antsthickness.connect(T1_rigid_quickSyN, 'warped_image',
                          corticalthickness, 'anatomical_image')
    antsthickness.connect(infosource, 'subject_id', corticalthickness,
                          'out_prefix')
    antsthickness.connect(corticalthickness, 'CorticalThickness', converter,
                          'input_image')
    antsthickness.connect(converter, 'output_image', mosaic_slicer,
                          'rgb_image')
    antsthickness.connect(corticalthickness, 'BrainSegmentationN4',
                          mosaic_slicer, 'input_image')
    antsthickness.connect(corticalthickness, 'BrainExtractionMask',
                          mosaic_slicer, 'mask_image')

    antsthickness.connect(corticalthickness, 'BrainSegmentationN4',
                          gm_density, 'in_file')
    antsthickness.connect(corticalthickness, 'BrainSegmentationPosteriors',
                          sl, 'inlist')
    antsthickness.connect(sl, 'out', gm_density, 'mask_file')

    antsthickness.connect(corticalthickness, 'SubjectToTemplate1Warp',
                          merge_transforms, 'in1')
    antsthickness.connect(corticalthickness, 'SubjectToTemplate0GenericAffine',
                          merge_transforms, 'in2')
    antsthickness.connect(merge_transforms, 'out', at, 'transforms')
    antsthickness.connect(gm_density, 'out_file', at, 'input_image')

    antsthickness.connect(corticalthickness, 'SubjectToTemplateLogJacobian',
                          multiply_images, 'second_input')
    antsthickness.connect(corticalthickness,
                          'CorticalThicknessNormedToTemplate',
                          multiply_images, 'first_input')
    antsthickness.connect(infosource, 'subject_id', generate_filename,
                          'subject_id')
    antsthickness.connect(generate_filename, 'out_filename', multiply_images,
                          'output_product_image')

    #====================================
    # Running the workflow
    antsthickness.base_dir = os.path.abspath(directory)
    antsthickness.write_graph()
    antsthickness.run(plugin)