コード例 #1
0
def create_mrtrix_tracking_flow(config):
    """Create the tractography sub-workflow of the `DiffusionStage` using MRtrix3.

    Depending on ``config.tracking_mode`` a single ``StreamlineTrack`` node is
    created:

    * ``'Deterministic'`` -> node ``mrtrix_deterministic_tracking`` using the
      ``SD_Stream`` model (``rk4`` enabled when ``config.curvature >= 1e-6``)
      with the gradient table connected from ``inputnode.grad``.
    * ``'Probabilistic'`` -> node ``mrtrix_probabilistic_tracking`` using
      ``iFOD2`` when ``config.SD`` is true, ``Tensor_Prob`` otherwise.

    Both modes share the same downstream wiring: ACT (``config.use_act``) or a
    plain WM mask, GM/WM-interface or WM-mask seeding
    (``config.seed_from_gmwmi``), optional SIFT filtering (``config.sift``) and
    conversion of the tractogram to TrackVis (``outputnode.track_file``).

    Parameters
    ----------
    config : MRtrix_tracking_config
        Sub-workflow configuration object

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built tractography sub-workflow. For any other ``tracking_mode`` only
        the WM erosion node is wired and ``track_file`` stays unconnected.
    """
    flow = pe.Workflow(name="tracking")
    # inputnode ('gm_registered' is kept for interface compatibility; no node
    # of this workflow consumes it)
    inputnode = pe.Node(interface=util.IdentityInterface(fields=[
        'DWI', 'wm_mask_resampled', 'gm_registered', 'act_5tt_registered',
        'gmwmi_registered', 'grad'
    ]),
                        name='inputnode')

    # outputnode
    outputnode = pe.Node(
        interface=util.IdentityInterface(fields=["track_file"]),
        name='outputnode')

    # Compute single fiber voxel mask
    # NOTE(review): the eroded mask is not consumed by any other node here.
    wm_erode = pe.Node(
        interface=Erode(out_filename="wm_mask_resampled.nii.gz"),
        name='wm_erode')
    wm_erode.inputs.number_of_passes = 3
    wm_erode.inputs.filtertype = 'erode'

    flow.connect([(inputnode, wm_erode, [("wm_mask_resampled", 'in_file')])])

    deterministic = config.tracking_mode == 'Deterministic'
    if deterministic:
        tracking_name = 'mrtrix_deterministic_tracking'
    elif config.tracking_mode == 'Probabilistic':
        tracking_name = 'mrtrix_probabilistic_tracking'
    else:
        # Unsupported mode: no tracking nodes are added.
        return flow

    # Parameters common to both tracking modes
    mrtrix_tracking = pe.Node(interface=StreamlineTrack(), name=tracking_name)
    mrtrix_tracking.inputs.desired_number_of_tracks = config.desired_number_of_tracks
    # NOTE: config.max_number_of_seeds is currently not forwarded.
    mrtrix_tracking.inputs.maximum_tract_length = config.max_length
    mrtrix_tracking.inputs.minimum_tract_length = config.min_length
    mrtrix_tracking.inputs.step_size = config.step_size
    mrtrix_tracking.inputs.angle = config.angle
    mrtrix_tracking.inputs.cutoff_value = config.cutoff_value

    if deterministic:
        # The model is SD_Stream in every case; a non-negligible curvature
        # only switches on the ``rk4`` option.
        if config.curvature >= 0.000001:
            mrtrix_tracking.inputs.rk4 = True
        mrtrix_tracking.inputs.inputmodel = 'SD_Stream'
        flow.connect([(inputnode, mrtrix_tracking,
                       [("grad", "gradient_encoding_file")])])

        # NOTE(review): the extracted matrix is not consumed by any node here.
        voxel2WorldMatrixExtracter = pe.Node(
            interface=ExtractHeaderVoxel2WorldMatrix(),
            name='voxel2WorldMatrixExtracter')
        flow.connect([(inputnode, voxel2WorldMatrixExtracter,
                       [("wm_mask_resampled", "in_file")])])
    else:
        # config.SD selects iFOD2; otherwise fall back to Tensor_Prob.
        mrtrix_tracking.inputs.inputmodel = 'iFOD2' if config.SD else 'Tensor_Prob'

    # Anatomical constraint: 5TT image (ACT) or plain WM mask
    if config.use_act:
        flow.connect([(inputnode, mrtrix_tracking,
                       [('act_5tt_registered', 'act_file')])])
        mrtrix_tracking.inputs.backtrack = config.backtrack
        mrtrix_tracking.inputs.crop_at_gmwmi = config.crop_at_gmwmi
    else:
        flow.connect([(inputnode, mrtrix_tracking,
                       [('wm_mask_resampled', 'mask_file')])])

    # Seeding: GM/WM interface image or whole WM mask
    if config.seed_from_gmwmi:
        flow.connect([(inputnode, mrtrix_tracking,
                       [('gmwmi_registered', 'seed_gmwmi')])])
    else:
        flow.connect([(inputnode, mrtrix_tracking,
                       [('wm_mask_resampled', 'seed_file')])])

    # TCK -> TRK conversion
    converter = pe.Node(interface=Tck2Trk(), name='trackvis')
    converter.inputs.out_tracks = 'converted.trk'

    if config.sift:
        # Optional SIFT filtering between tracking and conversion
        filter_tractogram = pe.Node(interface=FilterTractogram(),
                                    name='sift_node')
        filter_tractogram.inputs.out_file = 'sift-filtered_tractogram.tck'

        flow.connect([(mrtrix_tracking, filter_tractogram,
                       [('tracked', 'in_tracks')]),
                      (inputnode, filter_tractogram, [('DWI', 'in_fod')])])
        if config.use_act:
            flow.connect([(inputnode, filter_tractogram,
                           [('act_5tt_registered', 'act_file')])])
        flow.connect([(filter_tractogram, converter,
                       [('out_tracks', 'in_tracks')])])
    else:
        flow.connect([(mrtrix_tracking, converter, [('tracked', 'in_tracks')])])

    flow.connect([(inputnode, mrtrix_tracking, [('DWI', 'in_file')]),
                  (inputnode, converter, [('wm_mask_resampled', 'in_image')]),
                  (converter, outputnode, [('out_tracks', 'track_file')])])

    return flow
コード例 #2
0
def create_mrtrix_tracking_flow(config):
    """Create the tractography sub-workflow of the `DiffusionStage` using MRtrix3.

    Depending on ``config.tracking_mode`` a single ``StreamlineTrack`` node is
    created:

    * ``"Deterministic"`` -> node ``mrtrix_deterministic_tracking`` using the
      ``SD_Stream`` model (``rk4`` enabled when ``config.curvature >= 1e-6``)
      with the gradient table connected from ``inputnode.grad``.
    * ``"Probabilistic"`` -> node ``mrtrix_probabilistic_tracking`` using
      ``iFOD2`` when ``config.SD`` is true, ``Tensor_Prob`` otherwise.

    Both modes share the same downstream wiring: ACT (``config.use_act``) or a
    plain WM mask, GM/WM-interface or WM-mask seeding
    (``config.seed_from_gmwmi``), optional SIFT filtering (``config.sift``) and
    conversion of the tractogram to TrackVis (``outputnode.track_file``).

    Parameters
    ----------
    config : MRtrixTrackingConfig
        Sub-workflow configuration object

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built tractography sub-workflow. For any other ``tracking_mode`` only
        the WM erosion node is wired and ``track_file`` stays unconnected.
    """
    flow = pe.Workflow(name="tracking")
    # inputnode ("gm_registered" is kept for interface compatibility; no node
    # of this workflow consumes it)
    inputnode = pe.Node(
        interface=util.IdentityInterface(fields=[
            "DWI",
            "wm_mask_resampled",
            "gm_registered",
            "act_5tt_registered",
            "gmwmi_registered",
            "grad",
        ]),
        name="inputnode",
    )

    # outputnode
    outputnode = pe.Node(
        interface=util.IdentityInterface(fields=["track_file"]),
        name="outputnode")

    # Compute single fiber voxel mask
    # NOTE(review): the eroded mask is not consumed by any other node here.
    wm_erode = pe.Node(
        interface=Erode(out_filename="wm_mask_resampled.nii.gz"),
        name="wm_erode")
    wm_erode.inputs.number_of_passes = 3
    wm_erode.inputs.filtertype = "erode"

    flow.connect([(inputnode, wm_erode, [("wm_mask_resampled", "in_file")])])

    deterministic = config.tracking_mode == "Deterministic"
    if deterministic:
        tracking_name = "mrtrix_deterministic_tracking"
    elif config.tracking_mode == "Probabilistic":
        tracking_name = "mrtrix_probabilistic_tracking"
    else:
        # Unsupported mode: no tracking nodes are added.
        return flow

    # Parameters common to both tracking modes
    mrtrix_tracking = pe.Node(interface=StreamlineTrack(), name=tracking_name)
    mrtrix_tracking.inputs.desired_number_of_tracks = (
        config.desired_number_of_tracks)
    # NOTE: config.max_number_of_seeds is currently not forwarded.
    mrtrix_tracking.inputs.maximum_tract_length = config.max_length
    mrtrix_tracking.inputs.minimum_tract_length = config.min_length
    mrtrix_tracking.inputs.step_size = config.step_size
    mrtrix_tracking.inputs.angle = config.angle
    mrtrix_tracking.inputs.cutoff_value = config.cutoff_value

    if deterministic:
        # The model is SD_Stream in every case; a non-negligible curvature
        # only switches on the ``rk4`` option.
        if config.curvature >= 0.000001:
            mrtrix_tracking.inputs.rk4 = True
        mrtrix_tracking.inputs.inputmodel = "SD_Stream"
        # fmt:off
        flow.connect([(inputnode, mrtrix_tracking,
                       [("grad", "gradient_encoding_file")])])
        # fmt:on

        # NOTE(review): the extracted matrix is not consumed by any node here.
        voxel2WorldMatrixExtracter = pe.Node(
            interface=ExtractHeaderVoxel2WorldMatrix(),
            name="voxel2WorldMatrixExtracter",
        )
        # fmt:off
        flow.connect([(inputnode, voxel2WorldMatrixExtracter,
                       [("wm_mask_resampled", "in_file")])])
        # fmt:on
    else:
        # config.SD selects iFOD2; otherwise fall back to Tensor_Prob.
        mrtrix_tracking.inputs.inputmodel = (
            "iFOD2" if config.SD else "Tensor_Prob")

    # Anatomical constraint: 5TT image (ACT) or plain WM mask
    if config.use_act:
        # fmt:off
        flow.connect([(inputnode, mrtrix_tracking,
                       [("act_5tt_registered", "act_file")])])
        # fmt:on
        mrtrix_tracking.inputs.backtrack = config.backtrack
        mrtrix_tracking.inputs.crop_at_gmwmi = config.crop_at_gmwmi
    else:
        # fmt:off
        flow.connect([(inputnode, mrtrix_tracking,
                       [("wm_mask_resampled", "mask_file")])])
        # fmt:on

    # Seeding: GM/WM interface image or whole WM mask
    if config.seed_from_gmwmi:
        # fmt:off
        flow.connect([(inputnode, mrtrix_tracking,
                       [("gmwmi_registered", "seed_gmwmi")])])
        # fmt:on
    else:
        # fmt:off
        flow.connect([(inputnode, mrtrix_tracking,
                       [("wm_mask_resampled", "seed_file")])])
        # fmt:on

    # TCK -> TRK conversion
    converter = pe.Node(interface=Tck2Trk(), name="trackvis")
    converter.inputs.out_tracks = "converted.trk"

    if config.sift:
        # Optional SIFT filtering between tracking and conversion
        filter_tractogram = pe.Node(interface=FilterTractogram(),
                                    name="sift_node")
        filter_tractogram.inputs.out_file = "sift-filtered_tractogram.tck"
        # fmt:off
        flow.connect([
            (mrtrix_tracking, filter_tractogram, [("tracked", "in_tracks")]),
            (inputnode, filter_tractogram, [("DWI", "in_fod")]),
        ])
        # fmt:on
        if config.use_act:
            # fmt:off
            flow.connect([(inputnode, filter_tractogram,
                           [("act_5tt_registered", "act_file")])])
            # fmt:on
        # fmt:off
        flow.connect([(filter_tractogram, converter,
                       [("out_tracks", "in_tracks")])])
        # fmt:on
    else:
        # fmt:off
        flow.connect([(mrtrix_tracking, converter, [("tracked", "in_tracks")])])
        # fmt:on

    # fmt:off
    flow.connect([
        (inputnode, mrtrix_tracking, [("DWI", "in_file")]),
        (inputnode, converter, [("wm_mask_resampled", "in_image")]),
        (converter, outputnode, [("out_tracks", "track_file")]),
    ])
    # fmt:on

    return flow
コード例 #3
0
ファイル: tracking.py プロジェクト: rcruces/connectomemapper3
def create_mrtrix_tracking_flow(config):
    """Create the MRtrix tractography sub-workflow.

    Depending on ``config.tracking_mode`` a single ``StreamlineTrack`` node is
    created:

    * ``'Deterministic'`` -> node ``mrtrix_deterministic_tracking`` using the
      ``SD_Stream`` model (``rk4`` enabled when ``config.curvature >= 1e-6``)
      with the gradient table connected from ``inputnode.grad``.
    * ``'Probabilistic'`` -> node ``mrtrix_probabilistic_tracking`` using
      ``iFOD2`` when ``config.SD`` is true, ``Tensor_Prob`` otherwise.

    Both modes share the same downstream wiring: ACT (``config.use_act``) or a
    plain WM mask, GM/WM-interface or WM-mask seeding
    (``config.seed_from_gmwmi``), and conversion of the tractogram to TrackVis
    (``outputnode.track_file``).

    Parameters
    ----------
    config : MRtrix_tracking_config
        Sub-workflow configuration object

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built tractography sub-workflow. For any other ``tracking_mode`` only
        the WM erosion node is wired and ``track_file`` stays unconnected.
    """
    flow = pe.Workflow(name="tracking")
    # inputnode ('gm_registered' is kept for interface compatibility; no node
    # of this workflow consumes it)
    inputnode = pe.Node(interface=util.IdentityInterface(fields=[
        'DWI', 'wm_mask_resampled', 'gm_registered', 'act_5tt_registered',
        'gmwmi_registered', 'grad'
    ]),
                        name='inputnode')

    # outputnode
    outputnode = pe.Node(
        interface=util.IdentityInterface(fields=["track_file"]),
        name='outputnode')

    # Compute single fiber voxel mask
    # NOTE(review): the eroded mask is not consumed by any other node here.
    wm_erode = pe.Node(
        interface=Erode(out_filename="wm_mask_resampled.nii.gz"),
        name='wm_erode')
    wm_erode.inputs.number_of_passes = 3
    wm_erode.inputs.filtertype = 'erode'

    flow.connect([(inputnode, wm_erode, [("wm_mask_resampled", 'in_file')])])

    deterministic = config.tracking_mode == 'Deterministic'
    if deterministic:
        tracking_name = 'mrtrix_deterministic_tracking'
    elif config.tracking_mode == 'Probabilistic':
        tracking_name = 'mrtrix_probabilistic_tracking'
    else:
        # Unsupported mode: no tracking nodes are added.
        return flow

    # Parameters common to both tracking modes
    mrtrix_tracking = pe.Node(interface=StreamlineTrack(), name=tracking_name)
    mrtrix_tracking.inputs.desired_number_of_tracks = config.desired_number_of_tracks
    # NOTE: config.max_number_of_seeds is currently not forwarded.
    mrtrix_tracking.inputs.maximum_tract_length = config.max_length
    mrtrix_tracking.inputs.minimum_tract_length = config.min_length
    mrtrix_tracking.inputs.step_size = config.step_size
    mrtrix_tracking.inputs.angle = config.angle
    mrtrix_tracking.inputs.cutoff_value = config.cutoff_value

    if deterministic:
        # The model is SD_Stream in every case; a non-negligible curvature
        # only switches on the ``rk4`` option.
        if config.curvature >= 0.000001:
            mrtrix_tracking.inputs.rk4 = True
        mrtrix_tracking.inputs.inputmodel = 'SD_Stream'
        flow.connect([(inputnode, mrtrix_tracking,
                       [("grad", "gradient_encoding_file")])])

        # NOTE(review): the extracted matrix is not consumed by any node here.
        voxel2WorldMatrixExtracter = pe.Node(
            interface=extractHeaderVoxel2WorldMatrix(),
            name='voxel2WorldMatrixExtracter')
        flow.connect([(inputnode, voxel2WorldMatrixExtracter,
                       [("wm_mask_resampled", "in_file")])])
    else:
        # config.SD selects iFOD2; otherwise fall back to Tensor_Prob.
        mrtrix_tracking.inputs.inputmodel = 'iFOD2' if config.SD else 'Tensor_Prob'

    # Anatomical constraint: 5TT image (ACT) or plain WM mask
    if config.use_act:
        flow.connect([(inputnode, mrtrix_tracking,
                       [('act_5tt_registered', 'act_file')])])
        mrtrix_tracking.inputs.backtrack = config.backtrack
        mrtrix_tracking.inputs.crop_at_gmwmi = config.crop_at_gmwmi
    else:
        flow.connect([(inputnode, mrtrix_tracking,
                       [('wm_mask_resampled', 'mask_file')])])

    # Seeding: GM/WM interface image or whole WM mask
    if config.seed_from_gmwmi:
        flow.connect([(inputnode, mrtrix_tracking,
                       [('gmwmi_registered', 'seed_gmwmi')])])
    else:
        flow.connect([(inputnode, mrtrix_tracking,
                       [('wm_mask_resampled', 'seed_file')])])

    # TCK -> TRK conversion
    converter = pe.Node(interface=Tck2Trk(), name='trackvis')
    converter.inputs.out_tracks = 'converted.trk'

    flow.connect([
        (inputnode, mrtrix_tracking, [('DWI', 'in_file')]),
        (mrtrix_tracking, converter, [('tracked', 'in_tracks')]),
        (inputnode, converter, [('wm_mask_resampled', 'in_image')]),
        (converter, outputnode, [('out_tracks', 'track_file')])
    ])

    return flow
コード例 #4
0
def create_mrtrix_recon_flow(config):
    '''Create the diffusion reconstruction workflow.

    It estimates the tensors or the fiber orientation distribution functions.

    The gradient table is flipped along ``config.flip_table_axis`` and used to
    fit a tensor model. The tensor, FA, ADC and eigenvector images are derived
    from that fit. If ``config.local_model`` is true, the workflow estimates a
    single-fiber response function within the eroded, FA-thresholded WM mask.
    It then runs constrained spherical deconvolution (CSD), and the resulting
    spherical-harmonics image becomes ``outputnode.DWI``. Otherwise
    ``inputnode.diffusion_resampled`` is forwarded as ``outputnode.DWI``.

    Parameters
    ----------
    config : object
        Reconstruction configuration (presumably ``MRtrix_recon_config``).
        The workflow always reads ``flip_table_axis`` and ``local_model``.
        When ``local_model`` is true, it also reads ``single_fib_thr`` and
        ``lmax_order``. An ``lmax_order`` of ``'Auto'`` leaves the maximum
        harmonic order unset, so the interfaces' default applies.

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built reconstruction sub-workflow.
    '''

    # TODO: Add AD and RD maps
    flow = pe.Workflow(name="reconstruction")
    inputnode = pe.Node(
        interface=util.IdentityInterface(
            fields=["diffusion", "diffusion_resampled", "wm_mask_resampled", "grad"]),
        name="inputnode")
    outputnode = pe.Node(interface=util.IdentityInterface(fields=["DWI", "FA", "ADC", "tensor", "eigVec", "RF", "grad"],
                                                          mandatory_inputs=True), name="outputnode")

    # Flip gradient table
    flip_table = pe.Node(interface=flipTable(), name='flip_table')

    flip_table.inputs.flipping_axis = config.flip_table_axis
    flip_table.inputs.delimiter = ' '
    flip_table.inputs.header_lines = 0
    flip_table.inputs.orientation = 'v'
    flow.connect([
        (inputnode, flip_table, [("grad", "table")]),
        (flip_table, outputnode, [("table", "grad")])
    ])

    # Tensor
    mrtrix_tensor = pe.Node(interface=DWI2Tensor(), name='mrtrix_make_tensor')

    flow.connect([
        (inputnode, mrtrix_tensor, [('diffusion_resampled', 'in_file')]),
        (flip_table, mrtrix_tensor, [("table", "encoding_file")]),
    ])

    # Tensor -> FA / ADC maps (computed in .mif, then converted to NIfTI)
    mrtrix_tensor_metrics = pe.Node(interface=TensorMetrics(out_fa='FA.mif', out_adc='ADC.mif'),
                                    name='mrtrix_tensor_metrics')
    convert_Tensor = pe.Node(interface=MRConvert(
        out_filename="dwi_tensor.nii.gz"), name='convert_tensor')
    convert_FA = pe.Node(interface=MRConvert(
        out_filename="FA.nii.gz"), name='convert_FA')
    convert_ADC = pe.Node(interface=MRConvert(
        out_filename="ADC.nii.gz"), name='convert_ADC')

    flow.connect([
        (mrtrix_tensor, convert_Tensor, [('tensor', 'in_file')]),
        (mrtrix_tensor, mrtrix_tensor_metrics, [('tensor', 'in_file')]),
        (mrtrix_tensor_metrics, convert_FA, [('out_fa', 'in_file')]),
        (mrtrix_tensor_metrics, convert_ADC, [('out_adc', 'in_file')]),
        (convert_Tensor, outputnode, [("converted", "tensor")]),
        (convert_FA, outputnode, [("converted", "FA")]),
        (convert_ADC, outputnode, [("converted", "ADC")])
    ])

    # Tensor -> Eigenvectors
    mrtrix_eigVectors = pe.Node(
        interface=Tensor2Vector(), name='mrtrix_eigenvectors')

    flow.connect([
        (mrtrix_tensor, mrtrix_eigVectors, [('tensor', 'in_file')]),
        (mrtrix_eigVectors, outputnode, [('vector', 'eigVec')])
    ])

    # Constrained Spherical Deconvolution
    if config.local_model:
        # 'Auto' cannot go through int(). In that case the harmonic order is
        # left unset on the response/CSD nodes instead of crashing here.
        lmax_order = (None if config.lmax_order == 'Auto'
                      else int(config.lmax_order))

        # Single-fiber voxel mask: eroded WM mask * FA, thresholded at
        # config.single_fib_thr
        mrtrix_erode = pe.Node(interface=Erode(
            out_filename='wm_mask_res_eroded.nii.gz'), name='mrtrix_erode')
        mrtrix_erode.inputs.number_of_passes = 1
        mrtrix_erode.inputs.filtertype = 'erode'
        mrtrix_mul_eroded_FA = pe.Node(
            interface=MRtrix_mul(), name='mrtrix_mul_eroded_FA')
        mrtrix_mul_eroded_FA.inputs.out_filename = "diffusion_resampled_tensor_FA_masked.mif"
        mrtrix_thr_FA = pe.Node(interface=MRThreshold(
            out_file='FA_th.mif'), name='mrtrix_thr')
        mrtrix_thr_FA.inputs.abs_value = config.single_fib_thr

        flow.connect([
            (inputnode, mrtrix_erode, [("wm_mask_resampled", 'in_file')]),
            (mrtrix_erode, mrtrix_mul_eroded_FA, [('out_file', 'input2')]),
            (mrtrix_tensor_metrics, mrtrix_mul_eroded_FA,
             [('out_fa', 'input1')]),
            (mrtrix_mul_eroded_FA, mrtrix_thr_FA, [('out_file', 'in_file')])
        ])

        # Single-fiber response function
        mrtrix_rf = pe.Node(
            interface=EstimateResponseForSH(), name='mrtrix_rf')
        if lmax_order is not None:
            mrtrix_rf.inputs.maximum_harmonic_order = lmax_order
        mrtrix_rf.inputs.algorithm = 'tournier'
        flow.connect([
            (inputnode, mrtrix_rf, [("diffusion_resampled", "in_file")]),
            (mrtrix_thr_FA, mrtrix_rf, [("thresholded", "mask_image")]),
            (flip_table, mrtrix_rf, [("table", "encoding_file")]),
        ])

        # Spherical deconvolution within the (non-eroded) WM mask
        mrtrix_CSD = pe.Node(
            interface=ConstrainedSphericalDeconvolution(), name='mrtrix_CSD')
        mrtrix_CSD.inputs.algorithm = 'csd'
        if lmax_order is not None:
            mrtrix_CSD.inputs.maximum_harmonic_order = lmax_order

        convert_CSD = pe.Node(interface=MRConvert(
            out_filename="spherical_harmonics_image.nii.gz"), name='convert_CSD')

        flow.connect([
            (inputnode, mrtrix_CSD, [('diffusion_resampled', 'in_file')]),
            (mrtrix_rf, mrtrix_CSD, [('response', 'response_file')]),
            (mrtrix_rf, outputnode, [('response', 'RF')]),
            (inputnode, mrtrix_CSD, [("wm_mask_resampled", 'mask_image')]),
            (flip_table, mrtrix_CSD, [("table", "encoding_file")]),
            (mrtrix_CSD, convert_CSD, [('spherical_harmonics_image', 'in_file')]),
            (convert_CSD, outputnode, [("converted", "DWI")])
        ])
    else:
        # No local model: pass the resampled diffusion image through.
        flow.connect([
            (inputnode, outputnode, [('diffusion_resampled', 'DWI')])
        ])

    return flow
コード例 #5
0
def create_dipy_recon_flow(config):
    """Create the Dipy reconstruction sub-workflow (named ``reconstruction``).

    The b-vectors are flipped along ``config.flip_table_axis`` and the
    resampled white-matter mask is eroded with a single pass. Then:

    * for any imaging model other than ``"DSI"``, a tensor fit
      (``dipy_tensor``) provides the response function (``RF``) and the FA,
      AD, MD and RD maps. If ``config.local_model`` is false, the tensor
      model is exported as ``model`` and the resampled diffusion image as
      ``DWI``. Otherwise a CSD fit (``dipy_CSD``) supplies ``model``, and its
      SH-coefficient image becomes ``DWI`` unless Dipy is the tracking tool
      (in which case the resampled diffusion image is used);
    * for ``"DSI"``, a SHORE fit (``dipy_SHORE``) supplies ``model``,
      ``fod`` and ``FA`` (its GFA map), its GFA/MSD/RTOP maps are merged
      into ``shore_maps``, and the resampled diffusion image becomes ``DWI``;
    * if ``config.mapmri`` is set, a MAP-MRI fit is added and its eight maps
      are merged into ``mapmri_maps``.

    Parameters
    ----------
    config : object
        Dipy reconstruction configuration. Attributes read include
        ``flip_table_axis``, ``imaging_model``, ``single_fib_thr``,
        ``local_model``, ``tracking_processing_tool``, ``lmax_order``, the
        ``shore_*`` parameters and, for MAP-MRI, ``mapmri``,
        ``laplacian_regularization``, ``laplacian_weighting``,
        ``positivity_constraint``, ``radial_order``, ``small_delta`` and
        ``big_delta``.

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built reconstruction sub-workflow.
    """

    def set_inputs(node, **values):
        # Assign each keyword to the same-named input of ``node``, in order.
        for input_name, input_value in values.items():
            setattr(node.inputs, input_name, input_value)

    # Tracking-tool name from the config -> value expected by the Dipy nodes.
    # Tools missing from this table leave the node input untouched.
    tool_labels = {'MRtrix': 'mrtrix', 'Dipy': 'dipy'}

    wf = pe.Workflow(name="reconstruction")
    in_node = pe.Node(
        interface=util.IdentityInterface(fields=[
            "diffusion", "diffusion_resampled", "brain_mask_resampled",
            "wm_mask_resampled", "bvals", "bvecs"]),
        name="inputnode")
    out_node = pe.Node(
        interface=util.IdentityInterface(
            fields=["DWI", "FA", "AD", "MD", "RD", "fod", "model", "eigVec",
                    "RF", "grad", "bvecs", "shore_maps", "mapmri_maps"],
            mandatory_inputs=True),
        name="outputnode")

    # Gradient directions: flip along the configured axis (horizontal,
    # space-delimited layout without header lines).
    bvec_flipper = pe.Node(interface=flipBvec(), name='flip_bvecs')
    set_inputs(bvec_flipper,
               flipping_axis=config.flip_table_axis,
               delimiter=' ',
               header_lines=0,
               orientation='h')

    # Single-fiber voxel mask: erode the white-matter mask with one pass.
    wm_eroder = pe.Node(
        interface=Erode(out_filename="wm_mask_resampled.nii.gz"),
        name='dipy_erode')
    set_inputs(wm_eroder, number_of_passes=1, filtertype='erode')

    wf.connect([
        (in_node, bvec_flipper, [("bvecs", "bvecs")]),
        (bvec_flipper, out_node, [("bvecs_flipped", "bvecs")]),
        (in_node, wm_eroder, [("wm_mask_resampled", 'in_file')]),
    ])

    if config.imaging_model == 'DSI':
        # SHORE reconstruction (DSI data)
        shore_fit = pe.Node(interface=SHORE(), name='dipy_SHORE')
        label = tool_labels.get(config.tracking_processing_tool)
        if label is not None:
            shore_fit.inputs.tracking_processing_tool = label
        set_inputs(shore_fit,
                   radial_order=int(config.shore_radial_order),
                   zeta=config.shore_zeta,
                   lambda_n=config.shore_lambda_n,
                   lambda_l=config.shore_lambda_l,
                   tau=config.shore_tau,
                   constrain_e0=config.shore_constrain_e0,
                   positive_constraint=config.shore_positive_constraint)

        shore_merge = pe.Node(interface=util.Merge(3),
                              name='merge_shore_maps')
        wf.connect([
            (in_node, shore_fit, [('diffusion_resampled', 'in_file'),
                                  ('bvals', 'in_bval'),
                                  ("brain_mask_resampled", 'in_mask')]),
            (bvec_flipper, shore_fit, [('bvecs_flipped', 'in_bvec')]),
            (shore_fit, out_node, [('model', 'model'),
                                   ('fodf', 'fod'),
                                   ('GFA', 'FA')]),
            (shore_fit, shore_merge, [('GFA', 'in1'),
                                      ('MSD', 'in2'),
                                      ('RTOP', 'in3')]),
            (shore_merge, out_node, [('out', 'shore_maps')]),
            (in_node, out_node, [('diffusion_resampled', 'DWI')]),
        ])
    else:
        # Tensor fit -> response function and FA / AD / MD / RD maps
        tensor_fit = pe.Node(interface=DTIEstimateResponseSH(),
                             name='dipy_tensor')
        set_inputs(tensor_fit,
                   auto=True,
                   roi_radius=10,
                   fa_thresh=config.single_fib_thr)
        wf.connect([
            (in_node, tensor_fit, [('diffusion_resampled', 'in_file'),
                                   ('bvals', 'in_bval')]),
            (bvec_flipper, tensor_fit, [('bvecs_flipped', 'in_bvec')]),
            (wm_eroder, tensor_fit, [('out_file', 'in_mask')]),
            (tensor_fit, out_node, [("response", "RF"),
                                    ("fa_file", "FA"),
                                    ("ad_file", "AD"),
                                    ("md_file", "MD"),
                                    ("rd_file", "RD")]),
        ])

        if config.local_model:
            # Constrained spherical deconvolution
            csd_fit = pe.Node(interface=CSD(), name='dipy_CSD')
            set_inputs(csd_fit,
                       save_shm_coeff=True,
                       out_shm_coeff='diffusion_shm_coeff.nii.gz')
            label = tool_labels.get(config.tracking_processing_tool)
            if label is not None:
                csd_fit.inputs.tracking_processing_tool = label
            if config.lmax_order != 'Auto':
                csd_fit.inputs.sh_order = config.lmax_order
            csd_fit.inputs.fa_thresh = config.single_fib_thr

            # Dipy tracking consumes the resampled diffusion image; any other
            # tool gets the CSD spherical-harmonic coefficients as DWI.
            if config.tracking_processing_tool == 'Dipy':
                dwi_link = (in_node, out_node,
                            [('diffusion_resampled', 'DWI')])
            else:
                dwi_link = (csd_fit, out_node, [('out_shm_coeff', 'DWI')])

            wf.connect([
                (in_node, csd_fit, [('diffusion_resampled', 'in_file'),
                                    ('bvals', 'in_bval'),
                                    ("brain_mask_resampled", 'in_mask')]),
                (bvec_flipper, csd_fit, [('bvecs_flipped', 'in_bvec')]),
                (csd_fit, out_node, [('model', 'model')]),
                dwi_link,
            ])
        else:
            wf.connect([
                (in_node, out_node, [('diffusion_resampled', 'DWI')]),
                (tensor_fit, out_node, [("dti_model", "model")]),
            ])

    if config.mapmri:
        mapmri_fit = pe.Node(interface=MAPMRI(), name='dipy_mapmri')
        set_inputs(mapmri_fit,
                   laplacian_regularization=config.laplacian_regularization,
                   laplacian_weighting=config.laplacian_weighting,
                   positivity_constraint=config.positivity_constraint,
                   radial_order=config.radial_order,
                   small_delta=config.small_delta,
                   big_delta=config.big_delta)

        mapmri_merge = pe.Node(interface=util.Merge(8),
                               name='merge_mapmri_maps')
        wf.connect([
            (in_node, mapmri_fit, [('diffusion_resampled', 'in_file'),
                                   ('bvals', 'in_bval')]),
            (bvec_flipper, mapmri_fit, [('bvecs_flipped', 'in_bvec')]),
            (mapmri_fit, mapmri_merge, [('rtop_file', 'in1'),
                                        ('rtap_file', 'in2'),
                                        ('rtpp_file', 'in3'),
                                        ('msd_file', 'in4'),
                                        ('qiv_file', 'in5'),
                                        ('ng_file', 'in6'),
                                        ('ng_perp_file', 'in7'),
                                        ('ng_para_file', 'in8')]),
            (mapmri_merge, out_node, [('out', 'mapmri_maps')]),
        ])

    return wf
コード例 #6
0
def create_mrtrix_recon_flow(config):
    """Create the reconstruction sub-workflow of the `DiffusionStage` using MRtrix3.

    The gradient table is flipped along ``config.flip_table_axis`` and a
    diffusion tensor is fitted. The tensor, FA, ADC and eigenvector images
    are exported. When ``config.local_model`` is true, three more steps run:

    * a single-fiber response function is estimated inside the eroded
      white-matter mask multiplied by FA and thresholded at
      ``config.single_fib_thr``;
    * constrained spherical deconvolution (CSD) is run;
    * the CSD spherical-harmonics image (converted to NIfTI) becomes the
      ``DWI`` output, and the response function becomes ``RF``.

    Otherwise the resampled diffusion image is passed through as ``DWI``.

    Parameters
    ----------
    config : object
        MRtrix reconstruction configuration. Reads ``flip_table_axis`` and
        ``local_model``. When ``local_model`` is true, it also reads
        ``single_fib_thr`` and ``lmax_order``. ``lmax_order`` may be
        ``"Auto"``, in which case no maximum harmonic order is set on the
        response-function and CSD nodes. Any other value must be
        convertible with ``int()``.

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built reconstruction sub-workflow
    """

    # TODO: Add AD and RD maps
    flow = pe.Workflow(name="reconstruction")
    inputnode = pe.Node(
        interface=util.IdentityInterface(fields=[
            "diffusion", "diffusion_resampled", "wm_mask_resampled", "grad"
        ]),
        name="inputnode",
    )
    outputnode = pe.Node(
        interface=util.IdentityInterface(
            fields=["DWI", "FA", "ADC", "tensor", "eigVec", "RF", "grad"],
            mandatory_inputs=True,
        ),
        name="outputnode",
    )

    # Flip gradient table
    flip_table = pe.Node(interface=FlipTable(), name="flip_table")

    flip_table.inputs.flipping_axis = config.flip_table_axis
    flip_table.inputs.delimiter = " "
    flip_table.inputs.header_lines = 0
    flip_table.inputs.orientation = "v"
    # fmt:off
    flow.connect([
        (inputnode, flip_table, [("grad", "table")]),
        (flip_table, outputnode, [("table", "grad")]),
    ])
    # fmt:on

    # Tensor
    mrtrix_tensor = pe.Node(interface=DWI2Tensor(), name="mrtrix_make_tensor")
    # fmt:off
    flow.connect([
        (inputnode, mrtrix_tensor, [("diffusion_resampled", "in_file")]),
        (flip_table, mrtrix_tensor, [("table", "encoding_file")]),
    ])
    # fmt:on

    # Tensor -> FA / ADC maps (computed as .mif, exported as NIfTI)
    mrtrix_tensor_metrics = pe.Node(
        interface=TensorMetrics(out_fa="FA.mif", out_adc="ADC.mif"),
        name="mrtrix_tensor_metrics",
    )
    convert_Tensor = pe.Node(
        interface=MRConvert(out_filename="dwi_tensor.nii.gz"),
        name="convert_tensor")
    convert_FA = pe.Node(interface=MRConvert(out_filename="FA.nii.gz"),
                         name="convert_FA")
    convert_ADC = pe.Node(interface=MRConvert(out_filename="ADC.nii.gz"),
                          name="convert_ADC")
    # fmt:off
    flow.connect([
        (mrtrix_tensor, convert_Tensor, [("tensor", "in_file")]),
        (mrtrix_tensor, mrtrix_tensor_metrics, [("tensor", "in_file")]),
        (mrtrix_tensor_metrics, convert_FA, [("out_fa", "in_file")]),
        (mrtrix_tensor_metrics, convert_ADC, [("out_adc", "in_file")]),
        (convert_Tensor, outputnode, [("converted", "tensor")]),
        (convert_FA, outputnode, [("converted", "FA")]),
        (convert_ADC, outputnode, [("converted", "ADC")]),
    ])
    # fmt:on

    # Tensor -> Eigenvectors
    mrtrix_eigVectors = pe.Node(interface=Tensor2Vector(),
                                name="mrtrix_eigenvectors")
    # fmt:off
    flow.connect([
        (mrtrix_tensor, mrtrix_eigVectors, [("tensor", "in_file")]),
        (mrtrix_eigVectors, outputnode, [("vector", "eigVec")]),
    ])
    # fmt:on

    # Constrained Spherical Deconvolution
    if config.local_model:
        # Compute single fiber voxel mask: eroded WM mask multiplied by FA,
        # then thresholded at single_fib_thr.
        mrtrix_erode = pe.Node(
            interface=Erode(out_filename="wm_mask_res_eroded.nii.gz"),
            name="mrtrix_erode",
        )
        mrtrix_erode.inputs.number_of_passes = 1
        mrtrix_erode.inputs.filtertype = "erode"
        mrtrix_mul_eroded_FA = pe.Node(interface=MRtrix_mul(),
                                       name="mrtrix_mul_eroded_FA")
        mrtrix_mul_eroded_FA.inputs.out_filename = "diffusion_resampled_tensor_FA_masked.mif"
        mrtrix_thr_FA = pe.Node(interface=MRThreshold(out_file="FA_th.mif"),
                                name="mrtrix_thr")
        mrtrix_thr_FA.inputs.abs_value = config.single_fib_thr
        # fmt:off
        flow.connect([
            (inputnode, mrtrix_erode, [("wm_mask_resampled", "in_file")]),
            (mrtrix_erode, mrtrix_mul_eroded_FA, [("out_file", "input2")]),
            (mrtrix_tensor_metrics, mrtrix_mul_eroded_FA, [("out_fa", "input1")]),
            (mrtrix_mul_eroded_FA, mrtrix_thr_FA, [("out_file", "in_file")]),
        ])
        # fmt:on

        # "Auto" means: leave maximum_harmonic_order unset so the interface
        # default applies. Calling int("Auto") would otherwise raise
        # ValueError while the workflow is being built.
        lmax_order = (None if config.lmax_order == "Auto"
                      else int(config.lmax_order))

        # Compute single fiber response function
        mrtrix_rf = pe.Node(interface=EstimateResponseForSH(),
                            name="mrtrix_rf")
        if lmax_order is not None:
            mrtrix_rf.inputs.maximum_harmonic_order = lmax_order
        mrtrix_rf.inputs.algorithm = "tournier"
        # mrtrix_rf.inputs.normalise = config.normalize_to_B0
        # fmt:off
        flow.connect([
            (inputnode, mrtrix_rf, [("diffusion_resampled", "in_file")]),
            (mrtrix_thr_FA, mrtrix_rf, [("thresholded", "mask_image")]),
            (flip_table, mrtrix_rf, [("table", "encoding_file")]),
        ])
        # fmt:on

        # Perform spherical deconvolution
        mrtrix_CSD = pe.Node(interface=ConstrainedSphericalDeconvolution(),
                             name="mrtrix_CSD")
        mrtrix_CSD.inputs.algorithm = "csd"
        if lmax_order is not None:
            mrtrix_CSD.inputs.maximum_harmonic_order = lmax_order
        # mrtrix_CSD.inputs.normalise = config.normalize_to_B0

        convert_CSD = pe.Node(
            interface=MRConvert(
                out_filename="spherical_harmonics_image.nii.gz"),
            name="convert_CSD",
        )
        # fmt:off
        flow.connect([
            (inputnode, mrtrix_CSD, [("diffusion_resampled", "in_file")]),
            (mrtrix_rf, mrtrix_CSD, [("response", "response_file")]),
            (mrtrix_rf, outputnode, [("response", "RF")]),
            (inputnode, mrtrix_CSD, [("wm_mask_resampled", "mask_image")]),
            (flip_table, mrtrix_CSD, [("table", "encoding_file")]),
            (mrtrix_CSD, convert_CSD, [("spherical_harmonics_image", "in_file")]),
            (convert_CSD, outputnode, [("converted", "DWI")]),
        ])
        # fmt:on
    else:
        # Without a local model, the resampled diffusion image is the output.
        # fmt:off
        flow.connect([
            (inputnode, outputnode, [("diffusion_resampled", "DWI")]),
        ])
        # fmt:on
    return flow
コード例 #7
0
def create_dipy_recon_flow(config):
    """Create the reconstruction sub-workflow of the `DiffusionStage` using Dipy.

    Steps wired into the workflow:

    1. Flip the b-vectors along ``config.flip_table_axis`` and erode the
       resampled white-matter mask with one pass.
    2. If ``config.imaging_model`` is not ``"DSI"``, fit a tensor model
       (``dipy_tensor``) that yields the response function (``RF``) and the
       ``FA``/``AD``/``MD``/``RD`` maps. Then take one of two paths:

       * if ``config.local_model`` is false, export the tensor model and
         pass the resampled diffusion image through as ``DWI``;
       * otherwise fit CSD (``dipy_CSD``). Its SH-coefficient image becomes
         ``DWI`` unless Dipy is the tracking tool.

    3. If ``config.imaging_model`` is ``"DSI"``, fit SHORE (``dipy_SHORE``).
       Export its model, fODF and GFA (as ``FA``), merge its GFA/MSD/RTOP
       maps into ``shore_maps``, and pass the resampled diffusion image
       through as ``DWI``.
    4. If ``config.mapmri`` is set, fit MAP-MRI and merge its eight maps into
       ``mapmri_maps``.

    Parameters
    ----------
    config : DipyReconConfig
        Workflow configuration

    Returns
    -------
    flow : nipype.pipeline.engine.Workflow
        Built reconstruction sub-workflow
    """
    recon = pe.Workflow(name="reconstruction")
    inputnode = pe.Node(
        interface=util.IdentityInterface(fields=[
            "diffusion", "diffusion_resampled", "brain_mask_resampled",
            "wm_mask_resampled", "bvals", "bvecs",
        ]),
        name="inputnode",
    )
    outputnode = pe.Node(
        interface=util.IdentityInterface(
            fields=["DWI", "FA", "AD", "MD", "RD", "fod", "model", "eigVec",
                    "RF", "grad", "bvecs", "shore_maps", "mapmri_maps"],
            mandatory_inputs=True,
        ),
        name="outputnode",
    )

    # All connections are collected here and handed to ``recon.connect``
    # once, at the very end, in the order they were appended.
    links = []

    # Tracking-tool names from the config, translated to the values used by
    # the Dipy reconstruction interfaces. Other tools are not forwarded.
    tool_aliases = {"MRtrix": "mrtrix", "Dipy": "dipy"}

    # Gradient table: flip the b-vectors (horizontal, space-delimited,
    # no header lines).
    flipper = pe.Node(interface=FlipBvec(), name="flip_bvecs")
    flip_in = flipper.inputs
    flip_in.flipping_axis = config.flip_table_axis
    flip_in.delimiter = " "
    flip_in.header_lines = 0
    flip_in.orientation = "h"

    # Single-fiber voxel mask: erode the white-matter mask once.
    eroder = pe.Node(
        interface=Erode(out_filename="wm_mask_resampled.nii.gz"),
        name="dipy_erode")
    erode_in = eroder.inputs
    erode_in.number_of_passes = 1
    erode_in.filtertype = "erode"

    links.extend([
        (inputnode, flipper, [("bvecs", "bvecs")]),
        (flipper, outputnode, [("bvecs_flipped", "bvecs")]),
        (inputnode, eroder, [("wm_mask_resampled", "in_file")]),
    ])

    if config.imaging_model == "DSI":
        # SHORE reconstruction (DSI)
        shore_fit = pe.Node(interface=SHORE(), name="dipy_SHORE")
        shore_in = shore_fit.inputs
        tool = tool_aliases.get(config.tracking_processing_tool)
        if tool is not None:
            shore_in.tracking_processing_tool = tool
        shore_in.radial_order = int(config.shore_radial_order)
        shore_in.zeta = config.shore_zeta
        shore_in.lambda_n = config.shore_lambda_n
        shore_in.lambda_l = config.shore_lambda_l
        shore_in.tau = config.shore_tau
        shore_in.constrain_e0 = config.shore_constrain_e0
        shore_in.positive_constraint = config.shore_positive_constraint

        shore_merge = pe.Node(interface=util.Merge(3),
                              name="merge_shore_maps")
        links.extend([
            (inputnode, shore_fit, [("diffusion_resampled", "in_file"),
                                    ("bvals", "in_bval"),
                                    ("brain_mask_resampled", "in_mask")]),
            (flipper, shore_fit, [("bvecs_flipped", "in_bvec")]),
            (shore_fit, outputnode, [("model", "model"),
                                     ("fodf", "fod"),
                                     ("GFA", "FA")]),
            (shore_fit, shore_merge, [("GFA", "in1"),
                                      ("MSD", "in2"),
                                      ("RTOP", "in3")]),
            (shore_merge, outputnode, [("out", "shore_maps")]),
            (inputnode, outputnode, [("diffusion_resampled", "DWI")]),
        ])
    else:
        # Tensor fit -> response function and FA / AD / MD / RD maps
        tensor_fit = pe.Node(interface=DTIEstimateResponseSH(),
                             name="dipy_tensor")
        tensor_in = tensor_fit.inputs
        tensor_in.auto = True
        tensor_in.roi_radius = 10
        tensor_in.fa_thresh = config.single_fib_thr
        links.extend([
            (inputnode, tensor_fit, [("diffusion_resampled", "in_file"),
                                     ("bvals", "in_bval")]),
            (flipper, tensor_fit, [("bvecs_flipped", "in_bvec")]),
            (eroder, tensor_fit, [("out_file", "in_mask")]),
            (tensor_fit, outputnode, [("response", "RF"),
                                      ("fa_file", "FA"),
                                      ("ad_file", "AD"),
                                      ("md_file", "MD"),
                                      ("rd_file", "RD")]),
        ])

        if config.local_model:
            # Constrained spherical deconvolution
            csd_fit = pe.Node(interface=CSD(), name="dipy_CSD")
            csd_in = csd_fit.inputs
            csd_in.save_shm_coeff = True
            csd_in.out_shm_coeff = "diffusion_shm_coeff.nii.gz"
            tool = tool_aliases.get(config.tracking_processing_tool)
            if tool is not None:
                csd_in.tracking_processing_tool = tool
            if config.lmax_order != "Auto":
                csd_in.sh_order = config.lmax_order
            csd_in.fa_thresh = config.single_fib_thr

            links.extend([
                (inputnode, csd_fit, [("diffusion_resampled", "in_file"),
                                      ("bvals", "in_bval"),
                                      ("brain_mask_resampled", "in_mask")]),
                (flipper, csd_fit, [("bvecs_flipped", "in_bvec")]),
                (csd_fit, outputnode, [("model", "model")]),
            ])
            # Dipy tracking reads the resampled diffusion image; any other
            # tool receives the CSD spherical-harmonic coefficients.
            if config.tracking_processing_tool == "Dipy":
                links.append((inputnode, outputnode,
                              [("diffusion_resampled", "DWI")]))
            else:
                links.append((csd_fit, outputnode,
                              [("out_shm_coeff", "DWI")]))
        else:
            links.extend([
                (inputnode, outputnode, [("diffusion_resampled", "DWI")]),
                (tensor_fit, outputnode, [("dti_model", "model")]),
            ])

    if config.mapmri:
        mapmri_fit = pe.Node(interface=MAPMRI(), name="dipy_mapmri")
        mapmri_in = mapmri_fit.inputs
        mapmri_in.laplacian_regularization = config.laplacian_regularization
        mapmri_in.laplacian_weighting = config.laplacian_weighting
        mapmri_in.positivity_constraint = config.positivity_constraint
        mapmri_in.radial_order = config.radial_order
        mapmri_in.small_delta = config.small_delta
        mapmri_in.big_delta = config.big_delta

        mapmri_merge = pe.Node(interface=util.Merge(8),
                               name="merge_mapmri_maps")
        links.extend([
            (inputnode, mapmri_fit, [("diffusion_resampled", "in_file"),
                                     ("bvals", "in_bval")]),
            (flipper, mapmri_fit, [("bvecs_flipped", "in_bvec")]),
            (mapmri_fit, mapmri_merge, [("rtop_file", "in1"),
                                        ("rtap_file", "in2"),
                                        ("rtpp_file", "in3"),
                                        ("msd_file", "in4"),
                                        ("qiv_file", "in5"),
                                        ("ng_file", "in6"),
                                        ("ng_perp_file", "in7"),
                                        ("ng_para_file", "in8")]),
            (mapmri_merge, outputnode, [("out", "mapmri_maps")]),
        ])

    recon.connect(links)
    return recon