Example #1
 def __init__(self, inlist=[0], splits=[0], **options):
     from nipype.interfaces.utility import Split
     # build and run a Split interface: inlist is cut into consecutive
     # sub-lists whose lengths are given by splits
     sp = Split()
     sp.inputs.inlist = inlist
     sp.inputs.splits = splits
     # forward any extra keyword arguments (e.g. squeeze) to the interface
     for ef, value in options.items():
         setattr(sp.inputs, ef, value)
     self.res = sp.run()
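As a quick orientation to the interface used throughout these examples, here is a minimal standalone sketch (values are made up): Split slices `inlist` into consecutive sub-lists whose lengths are given by `splits` (which must sum to `len(inlist)`), exposed as numbered outputs.

from nipype.interfaces.utility import Split

sp = Split()
sp.inputs.inlist = [10, 20, 30]   # hypothetical values
sp.inputs.splits = [1, 2]         # sub-list lengths; must sum to len(inlist)
res = sp.run()
# res.outputs.out1 -> [10]
# res.outputs.out2 -> [20, 30]
# with sp.inputs.squeeze = True, out1 would be the bare value 10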
Example #2
def test_Split_outputs():
    # Split's outputs (out1..outN) are generated dynamically from `splits`,
    # so there is no static output metadata and the map below is empty.
    output_map = dict()
    outputs = Split.output_spec()

    for key, metadata in output_map.items():
        for metakey, value in metadata.items():
            yield assert_equal, getattr(outputs.traits()[key], metakey), value
Example #3
 def pipeline2(self):
     pipeline = self.pipeline(
         name='pipeline2',
         inputs=[
             FilesetSpec('ones', text_format),
             FilesetSpec('twos', text_format)
         ],
         outputs=[FieldSpec('threes', float),
                  FieldSpec('fours', float)],
         desc=("A pipeline that tests loading of requirements in "
               "map nodes"),
         references=[],
     )
     # Merge the two input filesets into a single list
     merge = pipeline.create_node(Merge(2), "merge")
     maths = pipeline.create_map_node(TestMathWithReq(),
                                      "maths",
                                      iterfield='x',
                                      requirements=[(notinstalled1_req,
                                                     notinstalled2_req,
                                                     first_req),
                                                    second_req])
     split = pipeline.create_node(Split(), 'split')
     split.inputs.splits = [1, 1]
     split.inputs.squeeze = True
     maths.inputs.op = 'add'
     maths.inputs.y = 2
     pipeline.connect_input('ones', merge, 'in1')
     pipeline.connect_input('twos', merge, 'in2')
     pipeline.connect(merge, 'out', maths, 'x')
     pipeline.connect(maths, 'z', split, 'inlist')
     pipeline.connect_output('threes', split, 'out1')
     pipeline.connect_output('fours', split, 'out2')
     return pipeline
Example #4
    def segmentation_pipeline(self, img_type=2, **name_maps):
        pipeline = self.new_pipeline(
            name='FAST_segmentation',
            name_maps=name_maps,
            inputs=[FilesetSpec('brain', nifti_gz_format)],
            outputs=[FilesetSpec('wm_seg', nifti_gz_format)],
            desc="White matter segmentation of the reference image",
            references=[fsl_cite])

        fast = pipeline.add('fast',
                            fsl.FAST(img_type=img_type,
                                     segments=True,
                                     out_basename='Reference_segmentation'),
                            inputs={'in_files': ('brain', nifti_gz_format)},
                            requirements=[fsl_req.v('5.0.9')])

        # Determine output field of split to use
        if img_type == 1:
            split_output = 'out3'
        elif img_type == 2:
            split_output = 'out2'
        else:
            raise ArcanaUsageError(
                "'img_type' parameter can either be 1 or 2 (not {})".format(
                    img_type))

        pipeline.add('split',
                     Split(splits=[1, 1, 1], squeeze=True),
                     connect={'inlist': (fast, 'tissue_class_files')},
                     outputs={split_output: ('wm_seg', nifti_gz_format)})

        return pipeline
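The `img_type` branch above can also be read as a small lookup table. A sketch, assuming (as the code above does) that FSL FAST lists white matter third in `tissue_class_files` for T1-weighted input (`img_type=1`) and second for T2-weighted input (`img_type=2`):

WM_SPLIT_OUTPUT = {1: 'out3', 2: 'out2'}  # img_type -> Split output carrying WM

def wm_split_output(img_type):
    if img_type not in WM_SPLIT_OUTPUT:
        raise ValueError(
            "'img_type' parameter can be either 1 or 2 (not {})".format(img_type))
    return WM_SPLIT_OUTPUT[img_type]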
Example #5
 def pipeline2(self, **name_maps):
     pipeline = self.new_pipeline(
         name='pipeline2',
         desc=("A pipeline that tests loading of requirements in "
               "map nodes"),
         name_maps=name_maps)
     # Merge the two inputs into a single list
     merge = pipeline.add("merge", Merge(2))
     maths = pipeline.add(
         "maths",
         TestMathWithReq(),
         iterfield='x',
         requirements=[first_req.v('0.15.9'),
                       second_req.v('1.0.2')])
     split = pipeline.add('split', Split())
     split.inputs.splits = [1, 1]
     split.inputs.squeeze = True
     maths.inputs.op = 'add'
     maths.inputs.y = 2
     pipeline.connect_input('ones', merge, 'in1', text_format)
     pipeline.connect_input('twos', merge, 'in2', text_format)
     pipeline.connect(merge, 'out', maths, 'x')
     pipeline.connect(maths, 'z', split, 'inlist')
     pipeline.connect_output('threes', split, 'out1', text_format)
     pipeline.connect_output('fours', split, 'out2', text_format)
     return pipeline
Example #6
def datasink_base(datasink, datasource, workflow, sessions):

    # NOTE: `sequences` is assumed to come from the enclosing module in the
    # original source; it is not an argument of this function.
    split_ds_nodes = []
    for i in range(len(sequences)):
        split_ds = nipype.Node(interface=Split(), name='split_ds{}'.format(i))
        split_ds.inputs.splits = [1] * len(sessions)
        split_ds_nodes.append(split_ds)

    for i, node in enumerate(split_ds_nodes):
        if len(sessions) > 1:
            workflow.connect(datasource, sequences[i], node,
                             'inlist')
            for j, sess in enumerate(sessions):
                workflow.connect(node, 'out{}'.format(j+1),
                                 datasink, 'results.subid.{0}.@{1}'
                                 .format(sess, sequences[i]))
        else:
            workflow.connect(datasource, sequences[i], datasink,
                             'results.subid.{0}.@{1}'.format(sessions[0],
                                                             sequences[i]))
    workflow.connect(datasource, 'reference', datasink,
                     'results.subid.REF.@ref_ct')

    workflow.connect(datasource, 't1_0', datasink,
                     'results.subid.T10.@ref_t1')
    return workflow
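The `splits = [1]*len(sessions)` pattern is what lets a flat, per-session list be fanned back out: each session gets its own single-item output. A sketch with made-up values:

from nipype.interfaces.utility import Split

sessions = ['tp01', 'tp02', 'tp03']            # hypothetical session names
sp = Split(splits=[1] * len(sessions))         # -> [1, 1, 1]
sp.inputs.inlist = ['a.nii', 'b.nii', 'c.nii']
res = sp.run()
# out1 -> ['a.nii'], out2 -> ['b.nii'], out3 -> ['c.nii'],
# so out{j+1} above can be wired to the datasink folder of sessions[j]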
Example #7
    def segmentation_pipeline(self, img_type=2, **kwargs):
        pipeline = self.create_pipeline(
            name='FAST_segmentation',
            inputs=[DatasetSpec('brain', nifti_gz_format)],
            outputs=[DatasetSpec('wm_seg', nifti_gz_format)],
            desc="White matter segmentation of the reference image",
            version=1,
            citations=[fsl_cite],
            **kwargs)

        fast = pipeline.create_node(fsl.FAST(),
                                    name='fast',
                                    requirements=[fsl509_req])
        fast.inputs.img_type = img_type
        fast.inputs.segments = True
        fast.inputs.out_basename = 'Reference_segmentation'
        pipeline.connect_input('brain', fast, 'in_files')
        split = pipeline.create_node(Split(), name='split')
        split.inputs.splits = [1, 1, 1]
        split.inputs.squeeze = True
        pipeline.connect(fast, 'tissue_class_files', split, 'inlist')
        if img_type == 1:
            pipeline.connect_output('wm_seg', split, 'out3')
        elif img_type == 2:
            pipeline.connect_output('wm_seg', split, 'out2')
        else:
            raise ArcanaUsageError(
                "'img_type' parameter can either be 1 or 2 (not {})".format(
                    img_type))

        return pipeline
Example #8
 def pipeline2(self, **name_maps):
     pipeline = self.new_pipeline('pipeline2',
                                  desc="",
                                  citations=[],
                                  name_maps=name_maps)
     split = pipeline.add('split',
                          Split(splits=[1, 1, 1], squeeze=True),
                          inputs={'inlist': ('derived_field1', float)})
     math1 = pipeline.add('math1',
                          TestMath(op='add', as_file=True),
                          inputs={
                              'y': ('acquired_fileset3', text_format),
                              'x': (split, 'out3')
                          },
                          requirements=[a_req.v('1.0')])
     math2 = pipeline.add('math2',
                          TestMath(op='add', as_file=True),
                          inputs={
                              'y': ('acquired_field1', float),
                              'x': (math1, 'z')
                          },
                          outputs={'derived_fileset1': ('z', text_format)},
                          requirements=[c_req.v(0.1)])
     pipeline.add('math3',
                  TestMath(op='sub', as_file=False, y=-1),
                  inputs={'x': (math2, 'z')},
                  outputs={'derived_field4': ('z', float)})
     return pipeline
Example #9
def test_Split_inputs():
    input_map = dict(
        ignore_exception=dict(
            nohash=True,
            usedefault=True,
        ),
        inlist=dict(mandatory=True, ),
        splits=dict(mandatory=True, ),
    )
    inputs = Split.input_spec()

    for key, metadata in input_map.items():
        for metakey, value in metadata.items():
            yield assert_equal, getattr(inputs.traits()[key], metakey), value
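These yield-based checks are nose-style generator tests, which modern pytest no longer collects. One way to port the same assertions, as a hedged sketch (the metadata comes from the map above and may differ across nipype versions, e.g. `ignore_exception` was later removed):

import pytest
from nipype.interfaces.utility import Split

@pytest.mark.parametrize("key, metakey, value", [
    ("ignore_exception", "nohash", True),
    ("ignore_exception", "usedefault", True),
    ("inlist", "mandatory", True),
    ("splits", "mandatory", True),
])
def test_split_input_trait_metadata(key, metakey, value):
    inputs = Split.input_spec()
    assert getattr(inputs.traits()[key], metakey) == value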
Example #10
    def datasink(self, workflow, workflow_datasink):

        datasource = self.data_source
        sequences1 = [
            x for x in datasource.inputs.field_template.keys()
            if x not in {'t1_0', 'reference', 'rt', 'rt_dose', 'doses',
                         'rts_dcm', 'rtstruct', 'physical', 'rbe', 'rtct',
                         'rtct_nifti'}
        ]
        rt = [x for x in datasource.inputs.field_template.keys() if x == 'rt']

        split_ds_nodes = []
        for i in range(len(sequences1)):
            sessions_with_seq = [
                x for y in self.sessions for x in glob.glob(
                    os.path.join(self.base_dir, self.sub_id, y,
                                 sequences1[i].upper() + '.nii.gz'))
            ]
            split_ds = nipype.Node(interface=Split(),
                                   name='split_ds{}'.format(i))
            split_ds.inputs.splits = [1] * len(sessions_with_seq)
            split_ds_nodes.append(split_ds)

            if len(sessions_with_seq) > 1:
                workflow.connect(datasource, sequences1[i], split_ds, 'inlist')
                for j, sess in enumerate(sessions_with_seq):
                    sess_name = sess.split('/')[-2]
                    workflow.connect(
                        split_ds, 'out{}'.format(j + 1), workflow_datasink,
                        'results.subid.{0}.@{1}'.format(
                            sess_name, sequences1[i]))
            elif len(sessions_with_seq) == 1:
                workflow.connect(
                    datasource, sequences1[i], workflow_datasink,
                    'results.subid.{0}.@{1}'.format(
                        sessions_with_seq[0].split('/')[-2], sequences1[i]))
        if self.reference:
            workflow.connect(datasource, 'reference', workflow_datasink,
                             'results.subid.REF.@ref_ct')
        if self.t10:
            workflow.connect(datasource, 't1_0', workflow_datasink,
                             'results.subid.T10.@ref_t1')
        if rt:
            workflow.connect(datasource, 'rt', workflow_datasink,
                             'results.subid.@rt')
        return workflow
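The `sessions_with_seq` comprehension above is just an existence filter: a session is kept only if the sequence file is actually on disk. The same idea in isolation (paths are hypothetical):

import glob
import os

base_dir, sub_id, seq = '/data', 'sub001', 't2'   # hypothetical
sessions = ['tp01', 'tp02']
sessions_with_seq = [
    path for sess in sessions
    for path in glob.glob(os.path.join(base_dir, sub_id, sess,
                                       seq.upper() + '.nii.gz'))
]
# each match is a full path; path.split('/')[-2] recovers the session
# folder name, as done above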
Example #11
        datasource.inputs.sub_id = sub.split('/')[-1]
        datasource.inputs.sessions = sessions
        datasource.inputs.ref_tp = ref_tp

        rs_ref = nipype.Node(interface=ResampleImage(), name='rs_ref')
        rs_ref.inputs.new_size = '1x1x1'
        rs_ref.inputs.mode = 0
        rs_ref.inputs.interpolation = 0
        rs_ref.inputs.dimensions = 3

        merge_1 = nipype.MapNode(interface=Merge(2),
                                 iterfield=['in1', 'in2'],
                                 name='merge_1')
        merge_1.inputs.ravel_inputs = True

        split_1 = nipype.MapNode(interface=Split(),
                                 iterfield=['inlist'],
                                 name='split_1')
        split_1.inputs.squeeze = True
        split_1.inputs.splits = [1, 2]

        fast_1 = nipype.MapNode(interface=fsl.FAST(),
                                iterfield=['in_files'],
                                name='fast_1')
        fast_1.inputs.img_type = 1
        fast_1.inputs.segments = True

        fast_ref = nipype.Node(interface=fsl.FAST(), name='fast_ref')
        fast_ref.inputs.img_type = 1
        fast_ref.inputs.segments = True
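Wrapping `Split` in a `MapNode` with `iterfield=['inlist']`, as above, splits each element of a list of lists independently. A sketch with made-up values:

import nipype
from nipype.interfaces.utility import Split

split_1 = nipype.MapNode(interface=Split(),
                         iterfield=['inlist'],
                         name='split_1')
split_1.inputs.squeeze = True
split_1.inputs.splits = [1, 2]
split_1.inputs.inlist = [[1, 2, 3], [4, 5, 6]]  # hypothetical
res = split_1.run()
# per-iteration results are collected into lists:
# res.outputs.out1 -> [1, 4]
# res.outputs.out2 -> [[2, 3], [5, 6]]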
Example #12
def longitudinal_registration(sub_id,
                              datasource,
                              sessions,
                              reference,
                              result_dir,
                              nipype_cache,
                              bet_workflow=None):
    """
    This is a workflow to register multi-modalities MR (T2, T1KM, FLAIR) to their 
    reference T1 image, in multiple time-points cohort. In particular, for each 
    subject, this workflow will register the MR images in each time-point (tp)
    to the corresponding T1, then it will register all the T1 images to a reference T1
    (the one that is the closest in time to the radiotherapy session), and finally the
    reference T1 to the BPLCT. At the end, all the MR images will be saved both in T1 space
    (for each tp) and in CT space.
    """
    reg2T1 = nipype.MapNode(interface=AntsRegSyn(),
                            iterfield=['input_file'],
                            name='reg2T1')
    reg2T1.inputs.transformation = 's'
    reg2T1.inputs.num_dimensions = 3
    reg2T1.inputs.num_threads = 6

    if reference:
        regT12CT = nipype.MapNode(interface=AntsRegSyn(),
                                  iterfield=['input_file'],
                                  name='regT12CT')
        regT12CT.inputs.transformation = 'r'
        regT12CT.inputs.num_dimensions = 3
        regT12CT.inputs.num_threads = 4

    reg_nodes = []
    for i in range(3):
        reg = nipype.MapNode(interface=AntsRegSyn(),
                             iterfield=['input_file', 'ref_file'],
                             name='ants_reg{}'.format(i))
        reg.inputs.transformation = 'r'
        reg.inputs.num_dimensions = 3
        reg.inputs.num_threads = 4
        reg.inputs.interpolation = 'BSpline'
        reg_nodes.append(reg)

    apply_mask_nodes = []
    for i in range(3):
        masking = nipype.MapNode(interface=ApplyMask(),
                                 iterfield=['in_file', 'mask_file'],
                                 name='masking{}'.format(i))
        apply_mask_nodes.append(masking)

    apply_ts_nodes = []
    for i in range(3):
        apply_ts = nipype.MapNode(interface=ApplyTransforms(),
                                  iterfield=['input_image', 'transforms'],
                                  name='apply_ts{}'.format(i))
        apply_ts_nodes.append(apply_ts)
    # Apply ts nodes for T1_ref normalization
    apply_ts_nodes1 = []
    for i in range(3):
        apply_ts = nipype.MapNode(interface=ApplyTransforms(),
                                  iterfield=['input_image', 'transforms'],
                                  name='apply_ts1{}'.format(i))
        apply_ts_nodes1.append(apply_ts)

    split_ds_nodes = []
    for i in range(4):
        split_ds = nipype.Node(interface=Split(), name='split_ds{}'.format(i))
        split_ds.inputs.splits = [1] * len(sessions)
        split_ds_nodes.append(split_ds)

    apply_ts_t1 = nipype.MapNode(interface=ApplyTransforms(),
                                 iterfield=['input_image', 'transforms'],
                                 name='apply_ts_t1')
    merge_nodes = []
    if reference:
        iterfields = ['in1', 'in2', 'in3', 'in4']
        iterfields_t1 = ['in1', 'in2', 'in3']
        if_0 = 2
    else:
        iterfields = ['in1', 'in2', 'in3']
        iterfields_t1 = ['in1', 'in2']
        if_0 = 1

    for i in range(3):
        merge = nipype.MapNode(interface=Merge(len(iterfields)),
                               iterfield=iterfields,
                               name='merge{}'.format(i))
        merge.inputs.ravel_inputs = True
        merge_nodes.append(merge)
    # Merging transforms for normalization to T1_ref
    merge_nodes1 = []
    for i in range(3):
        merge = nipype.MapNode(interface=Merge(3),
                               iterfield=['in1', 'in2', 'in3'],
                               name='merge1{}'.format(i))
        merge.inputs.ravel_inputs = True
        merge_nodes1.append(merge)

    merge_ts_t1 = nipype.MapNode(interface=Merge(len(iterfields_t1)),
                                 iterfield=iterfields_t1,
                                 name='merge_t1')
    merge_ts_t1.inputs.ravel_inputs = True

    # Create a fake merge of the transformation from t1_0 to CT in order to
    # have the same number of matrices as inputs in the map node
    fake_merge = nipype.Node(interface=Merge(len(sessions)), name='fake_merge')

    datasink = nipype.Node(nipype.DataSink(base_directory=result_dir),
                           "datasink")

    substitutions = [('subid', sub_id)]
    for i, session in enumerate(sessions):
        substitutions += [('session', session)]
        substitutions += [('_masking0{}/antsregWarped_masked.nii.gz'.format(i),
                           session + '/' + 'CT1_preproc.nii.gz')]
        substitutions += [('_reg2T1{}/antsreg0GenericAffine.mat'.format(i),
                           session + '/' + 'reg2T1_ref.mat')]
        substitutions += [('_reg2T1{}/antsreg1Warp.nii.gz'.format(i),
                           session + '/' + 'reg2T1_ref_warp.nii.gz')]
        substitutions += [('_reg2T1{}/antsregWarped.nii.gz'.format(i),
                           session + '/' + 'T1_reg2T1_ref.nii.gz')]
        substitutions += [('_regT12CT{}/antsreg0GenericAffine.mat'.format(i),
                           '/regT1_ref2CT.mat')]
        substitutions += [('_masking1{}/antsregWarped_masked.nii.gz'.format(i),
                           session + '/' + 'T2_preproc.nii.gz')]
        substitutions += [('_masking2{}/antsregWarped_masked.nii.gz'.format(i),
                           session + '/' + 'FLAIR_preproc.nii.gz')]
        substitutions += [('_apply_ts0{}/CT1_trans.nii.gz'.format(i),
                           session + '/' + 'CT1_reg2CT.nii.gz')]
        substitutions += [('_apply_ts1{}/T2_trans.nii.gz'.format(i),
                           session + '/' + 'T2_reg2CT.nii.gz')]
        substitutions += [('_apply_ts2{}/FLAIR_trans.nii.gz'.format(i),
                           session + '/' + 'FLAIR_reg2CT.nii.gz')]
        substitutions += [('_apply_ts_t1{}/T1_trans.nii.gz'.format(i),
                           session + '/' + 'T1_reg2CT.nii.gz')]
        substitutions += [('_apply_ts10{}/CT1_trans.nii.gz'.format(i),
                           session + '/' + 'CT1_reg2T1_ref.nii.gz')]
        substitutions += [('_apply_ts11{}/T2_trans.nii.gz'.format(i),
                           session + '/' + 'T2_reg2T1_ref.nii.gz')]
        substitutions += [('_apply_ts12{}/FLAIR_trans.nii.gz'.format(i),
                           session + '/' + 'FLAIR_reg2T1_ref.nii.gz')]

    datasink.inputs.substitutions = substitutions
    # Create Workflow
    workflow = nipype.Workflow('registration_workflow', base_dir=nipype_cache)

    for i, reg in enumerate(reg_nodes):
        workflow.connect(datasource, SEQUENCES[i + 1], reg, 'input_file')
        workflow.connect(datasource, SEQUENCES[0], reg, 'ref_file')
    # bring every MR in CT space
    for i, node in enumerate(apply_ts_nodes):
        workflow.connect(datasource, SEQUENCES[i + 1], node, 'input_image')
        if reference:
            workflow.connect(datasource, 'reference', node, 'reference_image')
        else:
            workflow.connect(datasource, 't1_0', node, 'reference_image')
        workflow.connect(merge_nodes[i], 'out', node, 'transforms')
        workflow.connect(node, 'output_image', datasink,
                         'results.subid.@{}_reg2CT'.format(SEQUENCES[i + 1]))
    # bring every MR in T1_ref space
    for i, node in enumerate(apply_ts_nodes1):
        workflow.connect(datasource, SEQUENCES[i + 1], node, 'input_image')
        workflow.connect(datasource, 't1_0', node, 'reference_image')
        workflow.connect(merge_nodes1[i], 'out', node, 'transforms')
        workflow.connect(
            node, 'output_image', datasink,
            'results.subid.@{}_reg2T1_ref'.format(SEQUENCES[i + 1]))

    for i, node in enumerate(merge_nodes):
        workflow.connect(reg_nodes[i], 'regmat', node, 'in{}'.format(if_0 + 2))
        workflow.connect(reg2T1, 'regmat', node, 'in{}'.format(if_0 + 1))
        workflow.connect(reg2T1, 'warp_file', node, 'in{}'.format(if_0))
        if reference:
            workflow.connect(fake_merge, 'out', node, 'in1')

    for i, node in enumerate(merge_nodes1):
        workflow.connect(reg_nodes[i], 'regmat', node, 'in3')
        workflow.connect(reg2T1, 'regmat', node, 'in2')
        workflow.connect(reg2T1, 'warp_file', node, 'in1')

    for i, mask in enumerate(apply_mask_nodes):
        workflow.connect(reg_nodes[i], 'reg_file', mask, 'in_file')
        if bet_workflow is not None:
            workflow.connect(bet_workflow, 'bet.out_mask', mask, 'mask_file')
        else:
            workflow.connect(datasource, 't1_mask', mask, 'mask_file')
        workflow.connect(mask, 'out_file', datasink,
                         'results.subid.@{}_preproc'.format(SEQUENCES[i + 1]))
    if bet_workflow is not None:
        workflow.connect(bet_workflow, 'bet.out_file', reg2T1, 'input_file')
        workflow.connect(bet_workflow, 't1_0_bet.out_file', reg2T1, 'ref_file')
    else:
        workflow.connect(datasource, 't1_bet', reg2T1, 'input_file')
        workflow.connect(datasource, 't1_0_bet', reg2T1, 'ref_file')

    if reference:
        for i, sess in enumerate(sessions):
            workflow.connect(regT12CT, 'regmat', fake_merge,
                             'in{}'.format(i + 1))
            workflow.connect(regT12CT, 'regmat', datasink,
                             'results.subid.{0}.@regT12CT_mat'.format(sess))
        workflow.connect(datasource, 'reference', regT12CT, 'ref_file')
        workflow.connect(datasource, 't1_0', regT12CT, 'input_file')
        workflow.connect(fake_merge, 'out', merge_ts_t1, 'in1')
        workflow.connect(datasource, 'reference', apply_ts_t1,
                         'reference_image')
    else:
        workflow.connect(datasource, 't1_0', apply_ts_t1, 'reference_image')

    workflow.connect(datasource, 't1', apply_ts_t1, 'input_image')

    workflow.connect(merge_ts_t1, 'out', apply_ts_t1, 'transforms')
    workflow.connect(reg2T1, 'regmat', merge_ts_t1, 'in{}'.format(if_0 + 1))
    workflow.connect(reg2T1, 'warp_file', merge_ts_t1, 'in{}'.format(if_0))

    workflow.connect(reg2T1, 'warp_file', datasink,
                     'results.subid.@reg2CT_warp')
    workflow.connect(reg2T1, 'regmat', datasink, 'results.subid.@reg2CT_mat')
    workflow.connect(reg2T1, 'reg_file', datasink, 'results.subid.@T12T1_ref')
    workflow.connect(apply_ts_t1, 'output_image', datasink,
                     'results.subid.@T1_reg2CT')

    if bet_workflow is not None:
        workflow = datasink_base(datasink, datasource, workflow, sessions,
                                 reference)
    else:
        workflow = datasink_base(datasink,
                                 datasource,
                                 workflow,
                                 sessions,
                                 reference,
                                 extra_nodes=['t1_bet'])

    return workflow
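The `fake_merge` trick above is worth isolating: a single T1_ref-to-CT matrix is replicated once per session so that the per-session map node receives transform lists of equal length. A sketch (the file name is hypothetical):

from nipype.interfaces.utility import Merge

sessions = ['tp01', 'tp02']
fake_merge = Merge(len(sessions))
fake_merge.inputs.in1 = 'regT1_ref2CT.mat'   # the same matrix every time
fake_merge.inputs.in2 = 'regT1_ref2CT.mat'
res = fake_merge.run()
# res.outputs.out -> ['regT1_ref2CT.mat', 'regT1_ref2CT.mat']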
Example #13
def spm_anat_to_diff_coregistration(wf_name="spm_anat_to_diff_coregistration"):
    """ Co-register the anatomical image and other images in anatomical space to
    the average B0 image.

    This estimates an affine transform from anat to diff space and applies it
    to the brain mask and an atlas.

    Nipype Inputs
    -------------
    dti_co_input.avg_b0: traits.File
        path to the average B0 image from the diffusion MRI.
        This image should come from a motion and Eddy currents
        corrected diffusion image.

    dti_co_input.anat: traits.File
        path to the high-contrast anatomical image.

    dti_co_input.tissues: traits.File
        paths to the NewSegment c*.nii output files, in anatomical space

    dti_co_input.atlas_anat: traits.File
        Atlas in subject anatomical space.

    Nipype Outputs
    --------------
    dti_co_output.anat_diff: traits.File
        Anatomical image in diffusion space.

    dti_co_output.tissues_diff: traits.File
        Tissues images in diffusion space.

    dti_co_output.brain_mask_diff: traits.File
        Brain mask for diffusion image.

    dti_co_output.atlas_diff: traits.File
        Atlas image warped to diffusion space. Only produced when the
        `atlas_file` option points to an existing file and `normalize_atlas`
        is True.

    Nipype Workflow Dependencies
    ----------------------------
    This workflow depends on:
    - spm_anat_preproc

    Returns
    -------
    wf: nipype Workflow
    """
    # specify input and output fields
    in_fields = ["avg_b0", "tissues", "anat"]
    out_fields = [
        "anat_diff",
        "tissues_diff",
        "brain_mask_diff",
    ]

    do_atlas, _ = check_atlas_file()
    if do_atlas:
        in_fields += ["atlas_anat"]
        out_fields += ["atlas_diff"]

    # input interface
    dti_input = pe.Node(IdentityInterface(fields=in_fields,
                                          mandatory_inputs=True),
                        name="dti_co_input")

    gunzip_b0 = pe.Node(Gunzip(), name="gunzip_b0")
    coreg_b0 = setup_node(spm_coregister(cost_function="mi"), name="coreg_b0")

    # co-registration
    brain_sel = pe.Node(Select(index=[0, 1, 2]), name="brain_sel")
    coreg_split = pe.Node(Split(splits=[1, 2], squeeze=True),
                          name="coreg_split")

    brain_merge = setup_node(MultiImageMaths(), name="brain_merge")
    brain_merge.inputs.op_string = "-add '%s' -add '%s' -abs -kernel gauss 4 -dilM -ero -kernel gauss 1 -dilM -bin"
    brain_merge.inputs.out_file = "brain_mask_diff.nii.gz"

    # output interface
    dti_output = pe.Node(IdentityInterface(fields=out_fields),
                         name="dti_co_output")

    # Create the workflow object
    wf = pe.Workflow(name=wf_name)

    # Connect the nodes
    wf.connect([
        # co-registration
        (dti_input, coreg_b0, [("anat", "source")]),
        (dti_input, brain_sel, [("tissues", "inlist")]),
        (brain_sel, coreg_b0, [(("out", flatten_list), "apply_to_files")]),
        (dti_input, gunzip_b0, [("avg_b0", "in_file")]),
        (gunzip_b0, coreg_b0, [("out_file", "target")]),
        (coreg_b0, coreg_split, [("coregistered_files", "inlist")]),
        (coreg_split, brain_merge, [("out1", "in_file")]),
        (coreg_split, brain_merge, [("out2", "operand_files")]),

        # output
        (coreg_b0, dti_output, [("coregistered_source", "anat_diff")]),
        (coreg_b0, dti_output, [("coregistered_files", "tissues_diff")]),
        (brain_merge, dti_output, [("out_file", "brain_mask_diff")]),
    ])

    # add more nodes if performing atlas registration
    if do_atlas:
        coreg_atlas = setup_node(spm_coregister(cost_function="mi"),
                                 name="coreg_atlas")

        # set the registration interpolation to nearest neighbour.
        coreg_atlas.inputs.write_interp = 0
        wf.connect([
            (dti_input, coreg_atlas, [
                ("anat", "source"),
                ("atlas_anat", "apply_to_files"),
            ]),
            (gunzip_b0, coreg_atlas, [("out_file", "target")]),
            (coreg_atlas, dti_output, [("coregistered_files", "atlas_diff")]),
        ])

    return wf
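The `coreg_split` settings above match `MultiImageMaths`' input shapes: `splits=[1, 2]` with `squeeze=True` turns the 3-item list of coregistered tissues into one file for `in_file` plus a 2-item list for `operand_files`. In isolation (paths are hypothetical):

from nipype.interfaces.utility import Split

sp = Split(splits=[1, 2], squeeze=True)
sp.inputs.inlist = ['c1.nii', 'c2.nii', 'c3.nii']  # hypothetical tissue files
res = sp.run()
# res.outputs.out1 -> 'c1.nii'             (squeezed single file -> in_file)
# res.outputs.out2 -> ['c2.nii', 'c3.nii'] (-> operand_files)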