import os

from nipype.interfaces import fsl


def init_params(self):
    eddyc = fsl.Eddy()
    eddyc.inputs.in_file = self.in_file
    eddyc.inputs.in_bval = self.bval
    eddyc.inputs.in_bvec = self.bvec
    eddyc.inputs.in_index = self.idx_file
    eddyc.inputs.in_acqp = self.acq
    eddyc.inputs.in_mask = self.brain_mask
    eddyc.inputs.in_topup_fieldcoef = self.fieldcoef
    eddyc.inputs.in_topup_movpar = self.fieldcoef.replace(
        "fieldcoef.nii.gz", "movpar.txt")
    eddyc.inputs.out_base = os.path.join(os.path.dirname(self.in_file),
                                         "tmp", "diff_corrected")
    # eddyc.inputs.out_corrected = os.path.join(
    #     os.path.dirname(self.in_file), 'diff_corrected.nii.gz')
    print(eddyc.cmdline)
    return eddyc
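A minimal usage sketch, assuming `preproc` is an instance of the (not shown) class that carries the attributes referenced above:

eddyc = preproc.init_params()        # build the configured interface
result = eddyc.run()                 # execute FSL's eddy
print(result.outputs.out_corrected)  # path to the corrected 4D image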
Example #2
def run_eddy(self):
    eddy = fsl.Eddy()
    eddy.inputs.in_file = self.dwi_base + '.nii.gz'
    eddy.inputs.in_mask = self.mask_file
    eddy.inputs.in_index = self.index_file
    eddy.inputs.in_acqp = self.acq_file
    eddy.inputs.in_bvec = self.bvec_file
    eddy.inputs.in_bval = self.bval_file
    # Bug in older nipype releases: in_topup expected an existing file, but
    # eddy must be passed the topup *base name* to run. This was fixed
    # upstream; until the updated version is available, use this work-around:
    # eddy.inputs.in_topup = topup_out.outputs.out_fieldcoef
    eddy.inputs.args = '--topup=' + self.topup_out
    eddy.inputs.out_base = self.eddy_out
    eddy.inputs.environ = {'OMP_NUM_THREADS': '12'}
    # print(eddy.cmdline)
    res = eddy.run()
    return res
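With a current nipype release the work-around above should be unnecessary; a sketch of the fixed inputs, assuming `self.topup_out` holds the topup base name:

# nipype now exposes topup's outputs as separate eddy inputs, replacing
# the '--topup=' args hack:
eddy.inputs.in_topup_fieldcoef = self.topup_out + '_fieldcoef.nii.gz'
eddy.inputs.in_topup_movpar = self.topup_out + '_movpar.txt'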
Example #3
from pathlib import Path

from nipype.interfaces import fsl


def eddy_correct(
    in_file: Path,
    mask: Path,
    index: Path,
    acq: Path,
    bvec: Path,
    bval: Path,
    fieldcoef: Path,
    movpar: Path,
    out_base: Path,
):
    """
    Generating FSL's eddy-currents correction tool, eddy, with specific inputs.
    Arguments:
        in_file {Path} -- [Path to dwi file]
        mask {Path} -- [Path to brain mask file]
        index {Path} -- [Path to index.txt file]
        acq {Path} -- [Path to datain.txt file]
        bvec {Path} -- [Path to bvec file (extracted automatically when converting the dicom file)]
        bval {Path} -- [Path to bval file (extracted automatically when converting the dicom file)]
        fieldcoef {Path} -- [Path to field coeffient as extracted after topup procedure]
        movpar {Path} -- [Path to moving parameters as extracted after topup procedure]
        out_base {Path} -- [Path to eddy's output base name]

    Returns:
        eddy [type] -- [nipype's FSL eddy's interface, generated with specific inputs]
    """
    eddy = fsl.Eddy()
    eddy.inputs.in_file = in_file
    eddy.inputs.in_mask = mask
    eddy.inputs.in_index = index
    eddy.inputs.in_acqp = acq
    eddy.inputs.in_bvec = bvec
    eddy.inputs.in_bval = bval
    eddy.inputs.in_topup_fieldcoef = fieldcoef
    eddy.inputs.in_topup_movpar = movpar
    eddy.inputs.out_base = str(out_base)
    return eddy
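A minimal call sketch (all paths hypothetical):

eddy = eddy_correct(
    in_file=Path('dwi.nii.gz'),
    mask=Path('brain_mask.nii.gz'),
    index=Path('index.txt'),
    acq=Path('datain.txt'),
    bvec=Path('dwi.bvec'),
    bval=Path('dwi.bval'),
    fieldcoef=Path('topup_fieldcoef.nii.gz'),
    movpar=Path('topup_movpar.txt'),
    out_base=Path('eddy_corrected'),
)
res = eddy.run()  # outputs are written under out_base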
Example #4
from nipype.interfaces import fsl, utility as niu
from nipype.pipeline import engine as pe

# b0_average, sdc_peb and eddy_rotate_bvecs are helpers that ship with this
# workflow in nipype's nipype.workflows.dmri.fsl modules.


def all_fsl_pipeline(name='fsl_all_correct',
                     epi_params=dict(echospacing=0.77e-3,
                                     acc_factor=3,
                                     enc_dir='y-'),
                     altepi_params=dict(echospacing=0.77e-3,
                                        acc_factor=3,
                                        enc_dir='y')):
    """
    Workflow that integrates FSL ``topup`` and ``eddy``.


    .. warning:: this workflow rotates the gradients table (*b*-vectors)
      [Leemans09]_.


    .. warning:: this workflow does not perform jacobian modulation of each
      *DWI* [Jones10]_.


    Examples
    --------

    >>> from nipype.workflows.dmri.fsl.artifacts import all_fsl_pipeline
    >>> allcorr = all_fsl_pipeline()
    >>> allcorr.inputs.inputnode.in_file = 'epi.nii'
    >>> allcorr.inputs.inputnode.alt_file = 'epi_rev.nii'
    >>> allcorr.inputs.inputnode.in_bval = 'diffusion.bval'
    >>> allcorr.inputs.inputnode.in_bvec = 'diffusion.bvec'
    >>> allcorr.run() # doctest: +SKIP

    """

    inputnode = pe.Node(niu.IdentityInterface(
        fields=['in_file', 'in_bvec', 'in_bval', 'alt_file']),
                        name='inputnode')

    outputnode = pe.Node(
        niu.IdentityInterface(fields=['out_file', 'out_mask', 'out_bvec']),
        name='outputnode')

    def gen_index(in_file):
        import numpy as np
        import nibabel as nb
        import os
        out_file = os.path.abspath('index.txt')
        vols = nb.load(in_file).get_data().shape[-1]
        np.savetxt(out_file, np.ones((vols, )).T)
        return out_file

    gen_idx = pe.Node(niu.Function(input_names=['in_file'],
                                   output_names=['out_file'],
                                   function=gen_index),
                      name='gen_index')
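    # gen_index writes an all-ones index file: every volume is mapped to the
    # first row of the acquisition-parameters file, which is what eddy expects
    # when all volumes share a single phase-encoding scheme.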
    avg_b0_0 = pe.Node(niu.Function(input_names=['in_dwi', 'in_bval'],
                                    output_names=['out_file'],
                                    function=b0_average),
                       name='b0_avg_pre')
    bet_dwi0 = pe.Node(fsl.BET(frac=0.3, mask=True, robust=True),
                       name='bet_dwi_pre')

    sdc = sdc_peb(epi_params=epi_params, altepi_params=altepi_params)
    ecc = pe.Node(fsl.Eddy(method='jac'), name='fsl_eddy')
    rot_bvec = pe.Node(niu.Function(input_names=['in_bvec', 'eddy_params'],
                                    output_names=['out_file'],
                                    function=eddy_rotate_bvecs),
                       name='Rotate_Bvec')
    avg_b0_1 = pe.Node(niu.Function(input_names=['in_dwi', 'in_bval'],
                                    output_names=['out_file'],
                                    function=b0_average),
                       name='b0_avg_post')
    bet_dwi1 = pe.Node(fsl.BET(frac=0.3, mask=True, robust=True),
                       name='bet_dwi_post')

    wf = pe.Workflow(name=name)
    wf.connect([(inputnode, avg_b0_0, [('in_file', 'in_dwi'),
                                       ('in_bval', 'in_bval')]),
                (avg_b0_0, bet_dwi0, [('out_file', 'in_file')]),
                (bet_dwi0, sdc, [('mask_file', 'inputnode.in_mask')]),
                (inputnode, sdc, [('in_file', 'inputnode.in_file'),
                                  ('alt_file', 'inputnode.alt_file'),
                                  ('in_bval', 'inputnode.in_bval')]),
                (sdc, ecc, [('topup.out_enc_file', 'in_acqp'),
                            ('topup.out_fieldcoef', 'in_topup_fieldcoef'),
                            ('topup.out_movpar', 'in_topup_movpar')]),
                (bet_dwi0, ecc, [('mask_file', 'in_mask')]),
                (inputnode, gen_idx, [('in_file', 'in_file')]),
                (inputnode, ecc, [('in_file', 'in_file'),
                                  ('in_bval', 'in_bval'),
                                  ('in_bvec', 'in_bvec')]),
                (gen_idx, ecc, [('out_file', 'in_index')]),
                (inputnode, rot_bvec, [('in_bvec', 'in_bvec')]),
                (ecc, rot_bvec, [('out_parameter', 'eddy_params')]),
                (ecc, avg_b0_1, [('out_corrected', 'in_dwi')]),
                (inputnode, avg_b0_1, [('in_bval', 'in_bval')]),
                (avg_b0_1, bet_dwi1, [('out_file', 'in_file')]),
                (ecc, outputnode, [('out_corrected', 'out_file')]),
                (rot_bvec, outputnode, [('out_file', 'out_bvec')]),
                (bet_dwi1, outputnode, [('mask_file', 'out_mask')])])
    return wf
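To inspect the assembled graph before executing it, nipype's write_graph can be used (a minimal sketch):

allcorr = all_fsl_pipeline()
allcorr.write_graph(graph2use='colored')  # renders the node graph to the working directory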
Example #5
import nipype
import nipype.interfaces.fsl as fsl
import nipype.interfaces.io as io
import nipype.pipeline.engine as pe

from .dwi_preproc import custom_functions as custom_nodes

# Wraps the dwidenoise command.
dwidenoise = pe.Node(interface=custom_nodes.dwidenoise(), name='dwidenoise')

# Wraps the unring command.
unring = pe.Node(interface=custom_nodes.Unring(), name='unring')

# Wraps the fslval command.
fslval = pe.Node(interface=custom_nodes.fslval(), name='fslval')

# Wraps the executable command ``eddy_openmp``.
fsl_Eddy = pe.Node(interface=fsl.Eddy(), name='fsl_Eddy')

# Wraps the executable command ``flirt``.
fsl_FLIRT = pe.Node(interface=fsl.FLIRT(), name='fsl_FLIRT')

# Flexibly collect data from disk to feed into workflows.
io_SelectFiles = pe.Node(io.SelectFiles(templates={}), name='io_SelectFiles')

# Create a workflow to connect all of the nodes above
analysisflow = nipype.Workflow('MyWorkflow')
analysisflow.connect(dwidenoise, "out_file", unring, "in_file")
analysisflow.connect(fslval, "out_file", fsl_Eddy, "in_index")
analysisflow.connect(unring, "out_file", fsl_Eddy, "in_file")
analysisflow.connect(fsl_Eddy, "out_corrected", fsl_FLIRT, "in_file")
analysisflow.connect(io_SelectFiles, "T1", fsl_FLIRT, "reference")
analysisflow.connect(io_SelectFiles, "acqp", fsl_Eddy, "in_acqp")
#-----------------------------------------------------------------------------------------------------
# In[6]:
bval = '/home/in/aeed/Work/October_Acquistion/bval_20'
bvec = '/home/in/aeed/Work/October_Acquistion/bvec_20'
acqparams = '/home/in/aeed/Work/October_Acquistion/acqparams.txt'
index = '/home/in/aeed/Work/October_Acquistion/index_20.txt'

VBM_DTI_Template = '/home/in/aeed/Work/October_Acquistion/VBM_DTI.nii.gz'
Wax_FA_Template = '/home/in/aeed/Work/October_Acquistion/FMRIB58_FA_2mm.nii.gz'
Study_Template = '/home/in/aeed/Work/October_Acquistion/FA_Template_Cluster.nii.gz'
#-----------------------------------------------------------------------------------------------------
# In[7]:
# Eddy-current correction using the newer Eddy interface instead of eddy_correct

from nipype import Node

eddy = Node(fsl.Eddy(), name='eddy')
eddy.inputs.in_acqp = acqparams
eddy.inputs.in_bval = bval
eddy.inputs.in_bvec = bvec
eddy.inputs.in_index = index
eddy.inputs.niter = 10
#-----------------------------------------------------------------------------------------------------
# In[7]:
# Fit the tensor

fit_tensor = Node(fsl.DTIFit(), name='fit_tensor')
fit_tensor.inputs.bvals = '/home/in/aeed/Work/October_Acquistion/bval_20'
fit_tensor.inputs.bvecs = '/home/in/aeed/Work/October_Acquistion/bvec_20'
fit_tensor.inputs.save_tensor = True
# fit_tensor.inputs.wls = True  # fit the tensor with weighted least squares; worth trying
#-----------------------------------------------------------------------------------------------------
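The notebook cells above define the nodes but never wire them together; a minimal sketch of the missing connections (workflow name hypothetical):

from nipype import Workflow

dti_wf = Workflow(name='dti_preproc')
# feed the eddy-corrected series into the tensor fit
dti_wf.connect(eddy, 'out_corrected', fit_tensor, 'dwi')
# DTIFit also needs a brain mask, e.g. via fit_tensor.inputs.mask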
Example #7
# Assumes the usual nipype imports (os, Workflow, Node, Function,
# IdentityInterface, nipype.interfaces.fsl/ants, nipype.interfaces.io as nio)
# plus project-local wrappers and helpers (Dwidenoise, Dwi2mask, Dilatemask,
# Dwibiascorrect, Dwi2tensor, Tensor2metric, AntsRegistrationSyn[Quick],
# prepare_eddy_textfiles_fct, create_bet_mask_from_dwi, extract_jhu).
def run_process_dwi(wf_dir, subject, sessions, args, study, prep_pipe="mrtrix", acq_str="", ants_quick=False):
    wf_name = "dwi__prep_{}".format(prep_pipe)
    wf = Workflow(name=wf_name)
    wf.base_dir = wf_dir
    wf.config['execution']['crashdump_dir'] = os.path.join(args.output_dir, wf_name, "crash")

    if sessions:
        n_cpus_big_jobs = int(args.n_cpus / len(sessions)) if args.n_cpus >= len(sessions) else int(args.n_cpus / 2)
    else:
        n_cpus_big_jobs = args.n_cpus
    n_cpus_big_jobs = 1 if n_cpus_big_jobs < 1 else n_cpus_big_jobs

    template_file = os.path.join(os.environ["FSLDIR"], "data/atlases", "JHU/JHU-ICBM-FA-1mm.nii.gz")

    if study == "lhab":
        masking_algo = "mrtrix"
    elif study == "camcan":
        masking_algo = "bet"
    elif study == "olm":
        masking_algo = "mrtrix"
    else:
        raise Exception("Study not known " + study)

    ########################
    # INPUT
    ########################
    if "{TotalReadoutTime}" in acq_str:
        use_json_file = True
    else:
        use_json_file = False

    if sessions:
        templates = {
            'dwi': 'sub-{subject_id}/ses-{session_id}/dwi/sub-{subject_id}_ses-{session_id}*_dwi.nii.gz',
            'bvec': 'sub-{subject_id}/ses-{session_id}/dwi/sub-{subject_id}_ses-{session_id}*_dwi.bvec',
            'bval': 'sub-{subject_id}/ses-{session_id}/dwi/sub-{subject_id}_ses-{session_id}*_dwi.bval',
        }
        if use_json_file:
            templates['json'] = 'sub-{subject_id}/ses-{session_id}/dwi/sub-{subject_id}_ses-{session_id}*_dwi.json'
    else:
        templates = {
            'dwi': 'sub-{subject_id}/dwi/sub-{subject_id}_*dwi.nii.gz{session_id}',  # session_id needed; "" is fed in
            'bvec': 'sub-{subject_id}/dwi/sub-{subject_id}_*dwi.bvec{session_id}',
            'bval': 'sub-{subject_id}/dwi/sub-{subject_id}_*dwi.bval{session_id}',
        }
        if use_json_file:
            templates['json'] = 'sub-{subject_id}/dwi/sub-{subject_id}_*dwi.json{session_id}'
        sessions = [""]

    sessions_interface = Node(IdentityInterface(fields=["session"]), "sessions_interface")
    sessions_interface.iterables = ("session", sessions)

    selectfiles = Node(nio.SelectFiles(templates,
                                       base_directory=args.bids_dir),
                       name="selectfiles")
    selectfiles.inputs.subject_id = subject
    wf.connect(sessions_interface, "session", selectfiles, "session_id")

    def format_subject_session_fct(subject, session=""):
        subject_label = "sub-" + subject
        session_label = "ses-" + session if session else ""
        subject_session_label = subject_label + ("_" + session_label if session_label else "")
        subject_session_prefix = subject_session_label + "_"
        subject_session_path = subject_label + ("/" + session_label if session_label else "")
        return subject_label, session_label, subject_session_label, subject_session_prefix, subject_session_path
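    # For example (hypothetical values):
    #   format_subject_session_fct('01', 'tp1') ->
    #   ('sub-01', 'ses-tp1', 'sub-01_ses-tp1', 'sub-01_ses-tp1_', 'sub-01/ses-tp1')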

    format_subject_session = Node(Function(input_names=["subject", "session"],
                                           output_names=["subject_label", "session_label", "subject_session_label",
                                                         "subject_session_prefix", "subject_session_path"],
                                           function=format_subject_session_fct), "format_subject_session")
    format_subject_session.inputs.subject = subject
    wf.connect(sessions_interface, "session", format_subject_session, "session")

    ########################
    # Set up outputs
    ########################
    sinker_preproc = Node(nio.DataSink(), name='sinker_preproc')
    sinker_preproc.inputs.base_directory = os.path.join(args.output_dir, "dwi_preprocessed")
    sinker_preproc.inputs.parameterization = False
    wf.connect(format_subject_session, 'subject_session_path', sinker_preproc, 'container')
    substitutions = [("_biascorr", ""),
                     ("_tensor", ""),
                     ('.eddy_rotated_bvecs', '.bvec'),
                     ('_acq-ap_run-1_dwi', ''),
                     ("_dwi", ""),
                     ("_b0s_mean_brain", "")
                     ]
    sinker_preproc.inputs.substitutions = substitutions

    sinker_plots = Node(nio.DataSink(), name='sinker_plots')
    sinker_plots.inputs.base_directory = args.output_dir
    sinker_plots.inputs.parameterization = False

    sinker_tract_plots = Node(nio.DataSink(), name='sinker_tract_plots')
    sinker_tract_plots.inputs.base_directory = args.output_dir
    sinker_tract_plots.inputs.parameterization = False

    sinker_extracted = Node(nio.DataSink(), name='sinker_extracted')
    sinker_extracted.inputs.base_directory = args.output_dir
    sinker_extracted.inputs.parameterization = False

    dwi_preprocessed = Node(IdentityInterface(fields=['dwi', 'mask', 'bvec', 'bval']), name='dwi_preprocessed')

    ########################
    # PREPROCESSING
    ########################
    # http://mrtrix.readthedocs.io/en/0.3.16/workflows/DWI_preprocessing_for_quantitative_analysis.html
    denoise = Node(Dwidenoise(), "denoise")
    wf.connect(selectfiles, "dwi", denoise, "in_file")
    wf.connect(denoise, "noise_file", sinker_preproc, "qa.@noise")

    prepare_eddy_textfiles = Node(interface=Function(input_names=["bval_file", "acq_str", "json_file"],
                                                     output_names=["acq_file", "index_file"],
                                                     function=prepare_eddy_textfiles_fct),
                                  name="prepare_eddy_textfiles")
    prepare_eddy_textfiles.inputs.acq_str = acq_str
    wf.connect(selectfiles, "bval", prepare_eddy_textfiles, "bval_file")
    if use_json_file:
        wf.connect(selectfiles, "json", prepare_eddy_textfiles, "json_file")

    init_mask_dil = Node(Dilatemask(), "init_mask_dil")

    if masking_algo == "mrtrix":
        init_mask = Node(Dwi2mask(), "init_mask")

        wf.connect(denoise, "out_file", init_mask, "in_file")
        wf.connect(selectfiles, "bvec", init_mask, "bvec")
        wf.connect(selectfiles, "bval", init_mask, "bval")

        wf.connect(init_mask, "out_mask_file", init_mask_dil, "in_file")
    elif masking_algo == "bet":
        init_mask = create_bet_mask_from_dwi(name="init_mask", do_realignment=True)

        wf.connect(denoise, "out_file", init_mask, "inputnode.dwi")
        wf.connect(selectfiles, "bvec", init_mask, "inputnode.bvec")
        wf.connect(selectfiles, "bval", init_mask, "inputnode.bval")

        wf.connect(init_mask, "outputnode.mask_file", init_mask_dil, "in_file")

    eddy = Node(fsl.Eddy(), "eddy")
    eddy.inputs.slm = "linear"
    eddy.inputs.repol = True
    eddy.inputs.num_threads = n_cpus_big_jobs
    wf.connect(prepare_eddy_textfiles, "acq_file", eddy, "in_acqp")
    wf.connect(prepare_eddy_textfiles, "index_file", eddy, "in_index")
    wf.connect(selectfiles, "bval", eddy, "in_bval")
    wf.connect(selectfiles, "bvec", eddy, "in_bvec")
    wf.connect(denoise, "out_file", eddy, "in_file")
    wf.connect(init_mask_dil, 'out_file', eddy, "in_mask")
    wf.connect(format_subject_session, 'subject_session_label', eddy, "out_base")

    bias = Node(Dwibiascorrect(), "bias")
    wf.connect(eddy, "out_corrected", bias, "in_file")
    wf.connect(selectfiles, "bval", bias, "bval")
    wf.connect(eddy, "out_rotated_bvecs", bias, "bvec")
    wf.connect(bias, "out_bias_file", sinker_preproc, "qa.@bias")

    if masking_algo == "mrtrix":
        mask = Node(Dwi2mask(), "mask")
        wf.connect(bias, "out_file", mask, "in_file")
        wf.connect(selectfiles, "bvec", mask, "bvec")
        wf.connect(selectfiles, "bval", mask, "bval")

    elif masking_algo == "bet":
        mask = create_bet_mask_from_dwi(name="mask", do_realignment=False)
        wf.connect(bias, "out_file", mask, "inputnode.dwi")
        wf.connect(selectfiles, "bvec", mask, "inputnode.bvec")
        wf.connect(selectfiles, "bval", mask, "inputnode.bval")

    # output eddy text files
    eddy_out = fsl.Eddy().output_spec.class_editable_traits()
    eddy_out = list(set(eddy_out) - {'out_corrected', 'out_rotated_bvecs'})
    for t in eddy_out:
        wf.connect(eddy, t, sinker_preproc, "dwi.eddy.@{}".format(t))

    def plot_motion_fnc(motion_file, subject_session):
        import os
        import pandas as pd
        import matplotlib as mpl
        mpl.use('Agg')
        import matplotlib.pyplot as plt

        df = pd.read_csv(motion_file, sep="  ", header=None,
                         names=["rms_movement_vs_first", "rms_movement_vs_previous"], engine='python')
        df.plot(title=subject_session)
        out_file = os.path.abspath(subject_session + "_motion.pdf")
        plt.savefig(out_file)
        return out_file

    motion_plot = Node(Function(input_names=["motion_file", "subject_session"], output_names=["out_file"],
                                function=plot_motion_fnc),
                       "motion_plot")
    wf.connect(eddy, "out_restricted_movement_rms", motion_plot, "motion_file")
    wf.connect(format_subject_session, "subject_session_label", motion_plot, "subject_session")
    wf.connect(motion_plot, "out_file", sinker_plots, "motion")

    wf.connect(bias, "out_file", dwi_preprocessed, "dwi")
    wf.connect(eddy, "out_rotated_bvecs", dwi_preprocessed, "bvec")
    wf.connect(selectfiles, "bval", dwi_preprocessed, "bval")
    if masking_algo == "mrtrix":
        wf.connect(mask, "out_mask_file", dwi_preprocessed, "mask")
    elif masking_algo == "bet":
        wf.connect(mask, "outputnode.mask_file", dwi_preprocessed, "mask")

    wf.connect(dwi_preprocessed, "dwi", sinker_preproc, "dwi.@dwi")
    wf.connect(dwi_preprocessed, "bvec", sinker_preproc, "dwi.@bvec")
    wf.connect(dwi_preprocessed, "bval", sinker_preproc, "dwi.@bval")
    wf.connect(dwi_preprocessed, "mask", sinker_preproc, "dwi.@mask")

    ########################
    # Tensor fit
    ########################
    # mrtrix tensor fit
    tensor = Node(Dwi2tensor(), "tensor")
    wf.connect(dwi_preprocessed, "dwi", tensor, "in_file")
    wf.connect(dwi_preprocessed, 'mask', tensor, 'mask_file')
    wf.connect(dwi_preprocessed, 'bvec', tensor, 'bvec')
    wf.connect(dwi_preprocessed, 'bval', tensor, 'bval')

    tensor_metrics = Node(Tensor2metric(), "tensor_metrics")
    wf.connect(tensor, "out_file", tensor_metrics, "in_file")
    for t in Tensor2metric().output_spec.class_editable_traits():
        wf.connect(tensor_metrics, t, sinker_preproc, "tensor_metrics.@{}".format(t))

    ########################
    # MNI
    ########################
    # ANTs registration to the FA template
    ants_reg = Node(AntsRegistrationSynQuick() if ants_quick else AntsRegistrationSyn(), "ants_reg")
    wf.connect(tensor_metrics, "out_file_fa", ants_reg, "in_file")
    ants_reg.inputs.template_file = template_file

    ants_reg.inputs.num_threads = n_cpus_big_jobs
    wf.connect(format_subject_session, "subject_session_prefix", ants_reg, "output_prefix")

    wf.connect(ants_reg, "out_matrix", sinker_preproc, "mni_transformation.@out_matrix")
    wf.connect(ants_reg, "forward_warp_field", sinker_preproc, "mni_transformation.@forward_warp_field")

    def make_transform_list_fct(linear, warp):
        return [warp, linear]

    make_transform_list = Node(Function(input_names=["linear", "warp"],
                                        output_names=["out_list"],
                                        function=make_transform_list_fct),
                               "make_transform_list")
    wf.connect(ants_reg, "out_matrix", make_transform_list, "linear")
    wf.connect(ants_reg, "forward_warp_field", make_transform_list, "warp")

    # now transform all metrics to MNI
    transform_fa = Node(ants.resampling.ApplyTransforms(), "transform_fa")
    transform_fa.inputs.out_postfix = "_mni"
    transform_fa.inputs.reference_image = template_file
    wf.connect(make_transform_list, "out_list", transform_fa, "transforms")
    wf.connect(tensor_metrics, "out_file_fa", transform_fa, "input_image")
    wf.connect(transform_fa, "output_image", sinker_preproc, "tensor_metrics_mni.@transform_fa")

    transform_md = Node(ants.resampling.ApplyTransforms(), "transform_md")
    transform_md.inputs.out_postfix = "_mni"
    transform_md.inputs.reference_image = template_file
    wf.connect(make_transform_list, "out_list", transform_md, "transforms")
    wf.connect(tensor_metrics, "out_file_md", transform_md, "input_image")
    wf.connect(transform_md, "output_image", sinker_preproc, "tensor_metrics_mni.@transform_md")

    transform_ad = Node(ants.resampling.ApplyTransforms(), "transform_ad")
    transform_ad.inputs.out_postfix = "_mni"
    transform_ad.inputs.reference_image = template_file
    wf.connect(make_transform_list, "out_list", transform_ad, "transforms")
    wf.connect(tensor_metrics, "out_file_ad", transform_ad, "input_image")
    wf.connect(transform_ad, "output_image", sinker_preproc, "tensor_metrics_mni.@transform_ad")

    transform_rd = Node(ants.resampling.ApplyTransforms(), "transform_rd")
    transform_rd.inputs.out_postfix = "_mni"
    transform_rd.inputs.reference_image = template_file
    wf.connect(make_transform_list, "out_list", transform_rd, "transforms")
    wf.connect(tensor_metrics, "out_file_rd", transform_rd, "input_image")
    wf.connect(transform_rd, "output_image", sinker_preproc, "tensor_metrics_mni.@transform_rd")

    def reg_plot_fct(in_file, template_file, subject_session):
        from nilearn import plotting
        import os
        out_file_reg = os.path.abspath(subject_session + "_reg.pdf")
        display = plotting.plot_anat(in_file, title=subject_session)
        display.add_edges(template_file)
        display.savefig(out_file_reg)
        return out_file_reg

    def tract_plot_fct(in_file, subject_session, atlas):
        from nilearn import plotting
        import os

        if atlas == "JHU25":
            thr = 25
        elif atlas == "JHU50":
            thr = 50
        else:
            raise Exception("Atlas unknown " + atlas)

        atlas_file = os.path.join(os.environ["FSLDIR"], "data/atlases",
                                  "JHU/JHU-ICBM-tracts-maxprob-thr{}-1mm.nii.gz".format(thr))

        out_file_tract = os.path.abspath(subject_session + "_atlas-{}_tract.pdf".format(atlas))
        display = plotting.plot_anat(in_file, title=subject_session)
        display.add_contours(atlas_file)
        display.savefig(out_file_tract)
        return out_file_tract

    atlas_interface = Node(IdentityInterface(fields=["atlas"]), "atlas_interface")
    atlas_interface.iterables = ("atlas", ["JHU25", "JHU50"])

    reg_plot = Node(Function(input_names=["in_file", "template_file", "subject_session"],
                             output_names=["out_file_reg"],
                             function=reg_plot_fct),
                    "reg_plot")
    wf.connect(transform_fa, "output_image", reg_plot, "in_file")
    reg_plot.inputs.template_file = template_file
    wf.connect(format_subject_session, "subject_session_label", reg_plot, "subject_session")
    wf.connect(reg_plot, "out_file_reg", sinker_plots, "regplots")

    tract_plot = Node(Function(input_names=["in_file", "subject_session", "atlas"],
                               output_names=["out_file_tract"],
                               function=tract_plot_fct),
                      "tract_plot")
    wf.connect(transform_fa, "output_image", tract_plot, "in_file")
    wf.connect(format_subject_session, "subject_session_label", tract_plot, "subject_session")
    wf.connect(atlas_interface, "atlas", tract_plot, "atlas")
    wf.connect(tract_plot, "out_file_tract", sinker_tract_plots, "tractplots")

    def concat_filenames_fct(in_file_fa, in_file_md, in_file_ad, in_file_rd):
        return [in_file_fa, in_file_md, in_file_ad, in_file_rd]

    concat_filenames = Node(Function(input_names=["in_file_fa", "in_file_md", "in_file_ad", "in_file_rd"],
                                     output_names=["out_list"],
                                     function=concat_filenames_fct),
                            "concat_filenames")
    metrics_labels = ["fa", "md", "ad", "rd"]
    wf.connect(transform_fa, "output_image", concat_filenames, "in_file_fa")
    wf.connect(transform_md, "output_image", concat_filenames, "in_file_md")
    wf.connect(transform_ad, "output_image", concat_filenames, "in_file_ad")
    wf.connect(transform_rd, "output_image", concat_filenames, "in_file_rd")

    merge = Node(fsl.Merge(), "merge")
    merge.inputs.dimension = "t"
    wf.connect(concat_filenames, "out_list", merge, "in_files")

    # create an fa mask
    fa_mask = Node(fsl.Threshold(), "fa_mask")
    wf.connect(transform_fa, "output_image", fa_mask, "in_file")
    fa_mask.inputs.thresh = 0.2
    fa_mask.inputs.args = "-bin"

    merge_masked = Node(fsl.ApplyMask(), "merge_masked")
    wf.connect(merge, "merged_file", merge_masked, "in_file")
    wf.connect(fa_mask, "out_file", merge_masked, "mask_file")

    extract = Node(Function(input_names=["in_file", "metric_labels", "subject", "session", "atlas"],
                            output_names=["out_file"],
                            function=extract_jhu),
                   "extract")
    extract.inputs.subject = subject
    extract.inputs.metric_labels = metrics_labels
    wf.connect(sessions_interface, "session", extract, "session")
    wf.connect(merge_masked, "out_file", extract, "in_file")
    wf.connect(atlas_interface, "atlas", extract, "atlas")

    wf.connect(extract, "out_file", sinker_extracted, "extracted_metrics")

    return wf
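A sketch of how this builder might be invoked and executed (values and the argparse-style `args` namespace are hypothetical):

wf = run_process_dwi('/scratch/wf', subject='01', sessions=['tp1'], args=args,
                     study='lhab', acq_str='0 -1 0 {TotalReadoutTime}')
wf.run(plugin='MultiProc', plugin_args={'n_procs': args.n_cpus})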