Example #1
0
def create_machine_learning_workflow(name="CreateEdgeProbabilityMap",
                                     resample=True,
                                     plugin_args=None):
    """
    Build a workflow that predicts GM/WM edge probability maps.

    Feature images (``rho``, ``phi``, ``theta`` and ``posteriors``) are
    gathered by a ``CollectFeatureFiles`` node configured with
    ``inverse_transform=True``, ``acpc_transform`` as its transform file and
    ``t1_file`` as its reference file. The collected feature files, the T1
    image and the GM/WM classifier files are fed to a
    ``PredictEdgeProbability`` node. Its GM/WM outputs are exposed on the
    ``output_spec`` node as ``gm_probability_map`` / ``wm_probability_map``.

    :param name: name of the returned workflow.
    :param resample: when True (default), build the variant that collects
        the feature files through ``CollectFeatureFiles``. The non-resampling
        variant is not implemented: a ``RuntimeWarning`` is issued and
        ``None`` is returned.
    :param plugin_args: optional plugin arguments attached to the
        ``PredictEdgeProbability`` node; ignored when falsy.
    :return: the assembled workflow, or ``None`` when ``resample`` is False.
    """
    if not resample:
        # TODO: create workflow that does not resample the input images.
        # Warn (instead of printing to stdout) before building any nodes;
        # the None return value is kept for backward compatibility.
        import warnings
        warnings.warn("workflow not yet created", RuntimeWarning,
                      stacklevel=2)
        return None

    workflow = Workflow(name)
    input_spec = Node(IdentityInterface([
        "rho", "phi", "theta", "posteriors", "t1_file", "acpc_transform",
        "gm_classifier_file", "wm_classifier_file"
    ]),
                      name="input_spec")

    predict_edge_probability = Node(PredictEdgeProbability(),
                                    name="PredictEdgeProbability")
    if plugin_args:
        predict_edge_probability.plugin_args = plugin_args
    workflow.connect([(input_spec, predict_edge_probability,
                       [("t1_file", "t1_file"),
                        ("gm_classifier_file", "gm_classifier_file"),
                        ("wm_classifier_file", "wm_classifier_file")])])

    # Gather the feature images using the ACPC transform (inverted) and the
    # T1 image as reference, then pass them to the predictor.
    collect_features = Node(CollectFeatureFiles(),
                            name="CollectFeatureFiles")
    collect_features.inputs.inverse_transform = True
    workflow.connect([(input_spec, collect_features,
                       [("rho", "rho"), ("phi", "phi"), ("theta", "theta"),
                        ("posteriors", "posterior_files"),
                        ("t1_file", "reference_file"),
                        ("acpc_transform", "transform_file")])])

    workflow.connect([(collect_features, predict_edge_probability,
                       [("feature_files", "additional_files")])])

    output_spec = Node(IdentityInterface(
        ["gm_probability_map", "wm_probability_map"]),
                       name="output_spec")
    workflow.connect(predict_edge_probability, "gm_edge_probability",
                     output_spec, "gm_probability_map")
    workflow.connect(predict_edge_probability, "wm_edge_probability",
                     output_spec, "wm_probability_map")

    return workflow
def create_workflow(files,
                    target_file,
                    subject_id,
                    TR,
                    slice_times,
                    norm_threshold=1,
                    num_components=5,
                    vol_fwhm=None,
                    surf_fwhm=None,
                    lowpass_freq=-1,
                    highpass_freq=-1,
                    subjects_dir=None,
                    sink_directory=None,
                    target_subject=None,
                    name='resting'):
    """Build the resting-state preprocessing workflow.

    The runs are renamed, realigned with ``SpaceTimeRealigner``, checked
    for TSNR and artifacts, and registered (through ``create_reg_workflow``)
    to the FreeSurfer anatomy and to ``target_file``. Motion/artifact and
    CompCor regressors are removed with two GLM passes. The residuals are
    band-pass filtered and smoothed, warped to target space, and sampled into
    aparc ROIs, subcortical ROIs and onto the surface(s) of
    ``target_subject``. Outputs are written by two DataSink nodes.

    Parameters
    ----------
    files : list
        Functional runs; each is renamed to ``rest_NN`` (NN = run number).
    target_file : str
        Registration target image, also used as the reference image when
        warping the time series.
    subject_id : str
        FreeSurfer subject id; also used as the DataSink container.
    TR : float
        Repetition time passed to the realigner; ``1 / TR`` is used as the
        band-pass sampling frequency.
    slice_times : list
        Slice timings passed to the realigner.
    norm_threshold : float
        Composite-norm threshold for ``ArtifactDetect``.
    num_components : int
        Number of CompCor components.
    vol_fwhm : float or None
        FWHM for ``fsl.IsotropicSmooth``.
    surf_fwhm : float or None
        Surface smoothing passed to ``SampleToSurface``.
    lowpass_freq, highpass_freq : float
        Cut-offs forwarded to ``bandpass_filter``.
    subjects_dir : str or None
        FreeSurfer subjects directory.
    sink_directory : str or None
        DataSink base directory; defaults to the current working directory
        at call time.
    target_subject : str or list or None
        Surface target subject(s), iterated over; defaults to
        ``['fsaverage3', 'fsaverage4']``.
    name : str
        Name of the returned workflow.

    Returns
    -------
    Workflow
        The assembled workflow (not run).
    """
    # Resolve defaults at call time. A default of ``os.getcwd()`` would be
    # frozen at import time, and a list default would be shared across calls.
    if sink_directory is None:
        sink_directory = os.getcwd()
    if target_subject is None:
        target_subject = ['fsaverage3', 'fsaverage4']

    wf = Workflow(name=name)

    # Rename files in case they are named identically
    name_unique = MapNode(Rename(format_string='rest_%(run)02d'),
                          iterfield=['in_file', 'run'],
                          name='rename')
    name_unique.inputs.keep_ext = True
    name_unique.inputs.run = list(range(1, len(files) + 1))
    name_unique.inputs.in_file = files

    realign = Node(nipy.SpaceTimeRealigner(), name="spacetime_realign")
    realign.inputs.slice_times = slice_times
    realign.inputs.tr = TR
    realign.inputs.slice_info = 2
    realign.plugin_args = {'sbatch_args': '-c%d' % 4}

    # Compute TSNR on realigned data regressing polynomials up to order 2
    tsnr = MapNode(TSNR(regress_poly=2), iterfield=['in_file'], name='tsnr')
    wf.connect(realign, "out_file", tsnr, "in_file")

    # Compute the median image across runs
    calc_median = Node(Function(input_names=['in_files'],
                                output_names=['median_file'],
                                function=median,
                                imports=imports),
                       name='median')
    wf.connect(tsnr, 'detrended_file', calc_median, 'in_files')

    # Segment and register
    registration = create_reg_workflow(name='registration')
    wf.connect(calc_median, 'median_file', registration, 'inputspec.mean_image')
    registration.inputs.inputspec.subject_id = subject_id
    registration.inputs.inputspec.subjects_dir = subjects_dir
    registration.inputs.inputspec.target_image = target_file

    # Quantify TSNR in each freesurfer ROI
    get_roi_tsnr = MapNode(fs.SegStats(default_color_table=True),
                           iterfield=['in_file'], name='get_aparc_tsnr')
    get_roi_tsnr.inputs.avgwf_txt_file = True
    wf.connect(tsnr, 'tsnr_file', get_roi_tsnr, 'in_file')
    wf.connect(registration, 'outputspec.aparc', get_roi_tsnr, 'segmentation_file')

    # Use nipype.algorithms.rapidart to determine which of the images in the
    # functional series are outliers based on deviations in intensity or
    # movement.
    art = Node(interface=ArtifactDetect(), name="art")
    art.inputs.use_differences = [True, True]
    art.inputs.use_norm = True
    art.inputs.norm_threshold = norm_threshold
    art.inputs.zintensity_threshold = 9
    art.inputs.mask_type = 'spm_global'
    art.inputs.parameter_source = 'NiPy'

    # Feed the renamed runs to the realigner, and the realigned images plus
    # motion parameters to artifact detection.
    wf.connect([(name_unique, realign, [('out_file', 'in_file')]),
                (realign, art, [('out_file', 'realigned_files')]),
                (realign, art, [('par_file', 'realignment_parameters')]),
                ])

    def selectindex(files, idx):
        import numpy as np
        from nipype.utils.filemanip import filename_to_list, list_to_filename
        return list_to_filename(np.array(filename_to_list(files))[idx].tolist())

    # Brain mask of the median functional image
    mask = Node(fsl.BET(), name='getmask')
    mask.inputs.mask = True
    wf.connect(calc_median, 'median_file', mask, 'in_file')

    # Compute motion regressors
    motreg = Node(Function(input_names=['motion_params', 'order',
                                        'derivatives'],
                           output_names=['out_files'],
                           function=motion_regressors,
                           imports=imports),
                  name='getmotionregress')
    wf.connect(realign, 'par_file', motreg, 'motion_params')

    # Create a filter to remove motion and art confounds
    createfilter1 = Node(Function(input_names=['motion_params', 'comp_norm',
                                               'outliers', 'detrend_poly'],
                                  output_names=['out_files'],
                                  function=build_filter1,
                                  imports=imports),
                         name='makemotionbasedfilter')
    createfilter1.inputs.detrend_poly = 2
    wf.connect(motreg, 'out_files', createfilter1, 'motion_params')
    wf.connect(art, 'norm_files', createfilter1, 'comp_norm')
    wf.connect(art, 'outlier_files', createfilter1, 'outliers')

    filter1 = MapNode(fsl.GLM(out_f_name='F_mcart.nii.gz',
                              out_pf_name='pF_mcart.nii.gz',
                              demean=True),
                      iterfield=['in_file', 'design', 'out_res_name'],
                      name='filtermotion')

    wf.connect(realign, 'out_file', filter1, 'in_file')
    wf.connect(realign, ('out_file', rename, '_filtermotart'),
               filter1, 'out_res_name')
    wf.connect(createfilter1, 'out_files', filter1, 'design')

    # CompCor noise components from the motion-filtered residuals, using the
    # segmentation files at indices 0 and 2 as noise masks.
    createfilter2 = MapNode(ACompCor(),
                            iterfield=['realigned_file', 'extra_regressors'],
                            name='makecompcorrfilter')
    createfilter2.inputs.components_file = 'noise_components.txt'
    createfilter2.inputs.num_components = num_components

    wf.connect(createfilter1, 'out_files', createfilter2, 'extra_regressors')
    wf.connect(filter1, 'out_res', createfilter2, 'realigned_file')
    wf.connect(registration, ('outputspec.segmentation_files', selectindex, [0, 2]),
               createfilter2, 'mask_file')

    filter2 = MapNode(fsl.GLM(out_f_name='F.nii.gz',
                              out_pf_name='pF.nii.gz',
                              demean=True),
                      iterfield=['in_file', 'design', 'out_res_name'],
                      name='filter_noise_nosmooth')
    wf.connect(filter1, 'out_res', filter2, 'in_file')
    wf.connect(filter1, ('out_res', rename, '_cleaned'),
               filter2, 'out_res_name')
    wf.connect(createfilter2, 'components_file', filter2, 'design')
    wf.connect(mask, 'mask_file', filter2, 'mask')

    bandpass = Node(Function(input_names=['files', 'lowpass_freq',
                                          'highpass_freq', 'fs'],
                             output_names=['out_files'],
                             function=bandpass_filter,
                             imports=imports),
                    name='bandpass_unsmooth')
    bandpass.inputs.fs = 1. / TR
    bandpass.inputs.highpass_freq = highpass_freq
    bandpass.inputs.lowpass_freq = lowpass_freq
    wf.connect(filter2, 'out_res', bandpass, 'files')

    # Smooth the functional data using fsl.IsotropicSmooth
    smooth = MapNode(interface=fsl.IsotropicSmooth(), name="smooth", iterfield=["in_file"])
    smooth.inputs.fwhm = vol_fwhm

    wf.connect(bandpass, 'out_files', smooth, 'in_file')

    # Smoothed (in1) and unsmoothed (in2) streams are processed together below
    collector = Node(Merge(2), name='collect_streams')
    wf.connect(smooth, 'out_file', collector, 'in1')
    wf.connect(bandpass, 'out_files', collector, 'in2')

    # Transform the remaining images. First to anatomical and then to target
    warpall = MapNode(ants.ApplyTransforms(), iterfield=['input_image'],
                      name='warpall')
    warpall.inputs.input_image_type = 3
    warpall.inputs.interpolation = 'Linear'
    warpall.inputs.invert_transform_flags = [False, False]
    warpall.inputs.terminal_output = 'file'
    warpall.inputs.reference_image = target_file
    warpall.inputs.args = '--float'
    warpall.inputs.num_threads = 2
    warpall.plugin_args = {'sbatch_args': '-c%d' % 2}

    # transform to target
    wf.connect(collector, 'out', warpall, 'input_image')
    wf.connect(registration, 'outputspec.transforms', warpall, 'transforms')

    mask_target = Node(fsl.ImageMaths(op_string='-bin'), name='target_mask')

    wf.connect(registration, 'outputspec.anat2target', mask_target, 'in_file')

    maskts = MapNode(fsl.ApplyMask(), iterfield=['in_file'], name='ts_masker')
    wf.connect(warpall, 'output_image', maskts, 'in_file')
    wf.connect(mask_target, 'out_file', maskts, 'mask_file')

    # Sample the average time series in aparc ROIs (aparc is already in
    # subject functional space, provided by the registration workflow)
    sampleaparc = MapNode(freesurfer.SegStats(default_color_table=True),
                          iterfield=['in_file', 'summary_file',
                                     'avgwf_txt_file'],
                          name='aparc_ts')
    sampleaparc.inputs.segment_id = ([8] + list(range(10, 14)) + [17, 18, 26, 47] +
                                     list(range(49, 55)) + [58] + list(range(1001, 1036)) +
                                     list(range(2001, 2036)))

    wf.connect(registration, 'outputspec.aparc',
               sampleaparc, 'segmentation_file')
    wf.connect(collector, 'out', sampleaparc, 'in_file')

    def get_names(files, suffix):
        """Generate appropriate names for output files
        """
        from nipype.utils.filemanip import (split_filename, filename_to_list,
                                            list_to_filename)
        import os
        out_names = []
        for filename in files:
            path, name, _ = split_filename(filename)
            out_names.append(os.path.join(path, name + suffix))
        return list_to_filename(out_names)

    wf.connect(collector, ('out', get_names, '_avgwf.txt'),
               sampleaparc, 'avgwf_txt_file')
    wf.connect(collector, ('out', get_names, '_summary.stats'),
               sampleaparc, 'summary_file')

    # Sample the time series onto the surface of the target surface. Performs
    # sampling into left and right hemisphere
    target = Node(IdentityInterface(fields=['target_subject']), name='target')
    target.iterables = ('target_subject', filename_to_list(target_subject))

    samplerlh = MapNode(freesurfer.SampleToSurface(),
                        iterfield=['source_file'],
                        name='sampler_lh')
    samplerlh.inputs.sampling_method = "average"
    samplerlh.inputs.sampling_range = (0.1, 0.9, 0.1)
    samplerlh.inputs.sampling_units = "frac"
    samplerlh.inputs.interp_method = "trilinear"
    samplerlh.inputs.smooth_surf = surf_fwhm
    # samplerlh.inputs.cortex_mask = True
    samplerlh.inputs.out_type = 'niigz'
    samplerlh.inputs.subjects_dir = subjects_dir

    # Clone before setting the hemisphere so both samplers share settings
    samplerrh = samplerlh.clone('sampler_rh')

    samplerlh.inputs.hemi = 'lh'
    wf.connect(collector, 'out', samplerlh, 'source_file')
    wf.connect(registration, 'outputspec.out_reg_file', samplerlh, 'reg_file')
    wf.connect(target, 'target_subject', samplerlh, 'target_subject')

    samplerrh.set_input('hemi', 'rh')
    wf.connect(collector, 'out', samplerrh, 'source_file')
    wf.connect(registration, 'outputspec.out_reg_file', samplerrh, 'reg_file')
    wf.connect(target, 'target_subject', samplerrh, 'target_subject')

    # Combine left and right hemisphere to text file
    combiner = MapNode(Function(input_names=['left', 'right'],
                                output_names=['out_file'],
                                function=combine_hemi,
                                imports=imports),
                       iterfield=['left', 'right'],
                       name="combiner")
    wf.connect(samplerlh, 'out_file', combiner, 'left')
    wf.connect(samplerrh, 'out_file', combiner, 'right')

    # Sample the time series file for each subcortical roi
    ts2txt = MapNode(Function(input_names=['timeseries_file', 'label_file',
                                           'indices'],
                              output_names=['out_file'],
                              function=extract_subrois,
                              imports=imports),
                     iterfield=['timeseries_file'],
                     name='getsubcortts')
    ts2txt.inputs.indices = [8] + list(range(10, 14)) + [17, 18, 26, 47] +\
        list(range(49, 55)) + [58]
    ts2txt.inputs.label_file = \
        os.path.abspath(('OASIS-TRT-20_jointfusion_DKT31_CMA_labels_in_MNI152_'
                         '2mm_v2.nii.gz'))
    wf.connect(maskts, 'out_file', ts2txt, 'timeseries_file')

    # Output path clean-up. Indices are applied in descending order so that
    # e.g. "_smooth10" is replaced before "_smooth1" can match inside it.
    substitutions = [('_target_subject_', ''),
                     ('_filtermotart_cleaned_bp_trans_masked', ''),
                     ('_filtermotart_cleaned_bp', ''),
                     ]
    substitutions += [("_smooth%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_ts_masker%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_getsubcortts%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_combiner%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_filtermotion%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_filter_noise_nosmooth%d" % i, "") for i in range(11)[::-1]]
    # Must match the CompCor MapNode name 'makecompcorrfilter' (double "r");
    # the previous "_makecompcorfilter" spelling never matched.
    substitutions += [("_makecompcorrfilter%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_get_aparc_tsnr%d/" % i, "run%d_" % (i + 1)) for i in range(11)[::-1]]

    substitutions += [("T1_out_brain_pve_0_maths_warped", "compcor_csf"),
                      ("T1_out_brain_pve_1_maths_warped", "compcor_gm"),
                      ("T1_out_brain_pve_2_maths_warped", "compcor_wm"),
                      ("output_warped_image_maths", "target_brain_mask"),
                      ("median_brain_mask", "native_brain_mask"),
                      ("corr_", "")]

    regex_subs = [('_combiner.*/sar', '/smooth/'),
                  ('_combiner.*/ar', '/unsmooth/'),
                  ('_aparc_ts.*/sar', '/smooth/'),
                  ('_aparc_ts.*/ar', '/unsmooth/'),
                  ('_getsubcortts.*/sar', '/smooth/'),
                  ('_getsubcortts.*/ar', '/unsmooth/'),
                  ('series/sar', 'series/smooth/'),
                  ('series/ar', 'series/unsmooth/'),
                  ('_inverse_transform./', ''),
                  ]
    # Save the relevant data into an output directory
    datasink = Node(interface=DataSink(), name="datasink")
    datasink.inputs.base_directory = sink_directory
    datasink.inputs.container = subject_id
    datasink.inputs.substitutions = substitutions
    datasink.inputs.regexp_substitutions = regex_subs  # (r'(/_.*(\d+/))', r'/run\2')
    wf.connect(realign, 'par_file', datasink, 'resting.qa.motion')
    wf.connect(art, 'norm_files', datasink, 'resting.qa.art.@norm')
    wf.connect(art, 'intensity_files', datasink, 'resting.qa.art.@intensity')
    wf.connect(art, 'outlier_files', datasink, 'resting.qa.art.@outlier_files')
    wf.connect(registration, 'outputspec.segmentation_files', datasink, 'resting.mask_files')
    wf.connect(registration, 'outputspec.anat2target', datasink, 'resting.qa.ants')
    wf.connect(mask, 'mask_file', datasink, 'resting.mask_files.@brainmask')
    wf.connect(mask_target, 'out_file', datasink, 'resting.mask_files.target')
    wf.connect(filter1, 'out_f', datasink, 'resting.qa.compmaps.@mc_F')
    wf.connect(filter1, 'out_pf', datasink, 'resting.qa.compmaps.@mc_pF')
    wf.connect(filter2, 'out_f', datasink, 'resting.qa.compmaps')
    wf.connect(filter2, 'out_pf', datasink, 'resting.qa.compmaps.@p')
    wf.connect(registration, 'outputspec.min_cost_file', datasink, 'resting.qa.mincost')
    wf.connect(tsnr, 'tsnr_file', datasink, 'resting.qa.tsnr.@map')
    wf.connect([(get_roi_tsnr, datasink, [('avgwf_txt_file', 'resting.qa.tsnr'),
                                          ('summary_file', 'resting.qa.tsnr.@summary')])])

    wf.connect(bandpass, 'out_files', datasink, 'resting.timeseries.@bandpassed')
    wf.connect(smooth, 'out_file', datasink, 'resting.timeseries.@smoothed')
    wf.connect(createfilter1, 'out_files',
               datasink, 'resting.regress.@regressors')
    wf.connect(createfilter2, 'components_file',
               datasink, 'resting.regress.@compcorr')
    wf.connect(maskts, 'out_file', datasink, 'resting.timeseries.target')
    wf.connect(sampleaparc, 'summary_file',
               datasink, 'resting.parcellations.aparc')
    wf.connect(sampleaparc, 'avgwf_txt_file',
               datasink, 'resting.parcellations.aparc.@avgwf')
    wf.connect(ts2txt, 'out_file',
               datasink, 'resting.parcellations.grayo.@subcortical')

    # Separate sink for the surface outputs, which are iterated over
    # target_subject
    datasink2 = Node(interface=DataSink(), name="datasink2")
    datasink2.inputs.base_directory = sink_directory
    datasink2.inputs.container = subject_id
    datasink2.inputs.substitutions = substitutions
    datasink2.inputs.regexp_substitutions = regex_subs  # (r'(/_.*(\d+/))', r'/run\2')
    wf.connect(combiner, 'out_file',
               datasink2, 'resting.parcellations.grayo.@surface')
    return wf
def create_reg_workflow(name='registration'):
    """Create a FreeSurfer/ANTs registration workflow.

    The functional mean image is coregistered to the subject's FreeSurfer
    anatomy with BBRegister. The skull-stripped anatomy (T1 masked by a
    dilated, binarized aparc+aseg) is registered to
    ``inputspec.target_image`` with ANTs (Rigid, Affine, SyN). FAST tissue
    maps and the aparc+aseg volume are resampled into functional space with
    the inverse BBRegister transform. The BBRegister transform, converted to
    ITK format, is paired with the ANTs composite transform so functional
    images can be warped to the target in one step.

    Parameters
    ----------

        name : name of workflow (default: 'registration')

    Inputs::

        inputspec.source_files : declared but not connected in this workflow
        inputspec.mean_image : functional reference image to register
        inputspec.subject_id : FreeSurfer subject id
        inputspec.subjects_dir : FreeSurfer subjects directory
        inputspec.target_image : registration target

    Outputs::

        outputspec.func2anat_transform : BBRegister transform (FSL format)
        outputspec.out_reg_file : BBRegister registration file
        outputspec.anat2target_transform : ANTs composite transform
        outputspec.transforms : [ANTs composite, ITK func->anat] list
        outputspec.transformed_mean : mean image warped to target space
        outputspec.segmentation_files : binarized FAST partial-volume maps
                                        in functional space
        outputspec.anat2target : skull-stripped anatomy warped to target
        outputspec.aparc : aparc+aseg resampled into functional space
        outputspec.min_cost_file : BBRegister minimum-cost file
    """

    register = Workflow(name=name)

    inputnode = Node(interface=IdentityInterface(fields=['source_files',
                                                         'mean_image',
                                                         'subject_id',
                                                         'subjects_dir',
                                                         'target_image']),
                     name='inputspec')

    outputnode = Node(interface=IdentityInterface(fields=['func2anat_transform',
                                                          'out_reg_file',
                                                          'anat2target_transform',
                                                          'transforms',
                                                          'transformed_mean',
                                                          'segmentation_files',
                                                          'anat2target',
                                                          'aparc',
                                                          'min_cost_file'
                                                          ]),
                      name='outputspec')

    # Get the subject's freesurfer source directory
    fssource = Node(FreeSurferSource(),
                    name='fssource')
    fssource.run_without_submitting = True
    register.connect(inputnode, 'subject_id', fssource, 'subject_id')
    register.connect(inputnode, 'subjects_dir', fssource, 'subjects_dir')

    convert = Node(freesurfer.MRIConvert(out_type='nii'),
                   name="convert")
    register.connect(fssource, 'T1', convert, 'in_file')

    # Coregister the median to the surface
    bbregister = Node(freesurfer.BBRegister(),
                      name='bbregister')
    bbregister.inputs.init = 'fsl'
    bbregister.inputs.contrast_type = 't2'
    bbregister.inputs.out_fsl_file = True
    bbregister.inputs.epi_mask = True
    register.connect(inputnode, 'subject_id', bbregister, 'subject_id')
    register.connect(inputnode, 'mean_image', bbregister, 'source_file')
    register.connect(inputnode, 'subjects_dir', bbregister, 'subjects_dir')

    # Estimate the tissue classes from the anatomical image, using
    # aparc+aseg's (dilated) brain mask for skull stripping.
    binarize_aparc = Node(fs.Binarize(min=0.5, out_type="nii.gz", dilate=1),
                          name="binarize_aparc")
    register.connect(fssource, ("aparc_aseg", get_aparc_aseg), binarize_aparc, "in_file")
    stripper = Node(fsl.ApplyMask(), name='stripper')
    register.connect(binarize_aparc, "binary_file", stripper, "mask_file")
    register.connect(convert, 'out_file', stripper, 'in_file')

    fast = Node(fsl.FAST(), name='fast')
    register.connect(stripper, 'out_file', fast, 'in_files')

    # Binarize the segmentation (threshold 0.9, erode, binarize)
    binarize = MapNode(fsl.ImageMaths(op_string='-nan -thr 0.9 -ero -bin'),
                       iterfield=['in_file'],
                       name='binarize')
    register.connect(fast, 'partial_volume_files', binarize, 'in_file')

    # Apply inverse transform to take segmentations to functional space
    applyxfm = MapNode(freesurfer.ApplyVolTransform(inverse=True,
                                                    interp='nearest'),
                       iterfield=['target_file'],
                       name='inverse_transform')
    register.connect(inputnode, 'subjects_dir', applyxfm, 'subjects_dir')
    register.connect(bbregister, 'out_reg_file', applyxfm, 'reg_file')
    register.connect(binarize, 'out_file', applyxfm, 'target_file')
    register.connect(inputnode, 'mean_image', applyxfm, 'source_file')

    # Apply inverse transform to aparc file
    aparcxfm = Node(freesurfer.ApplyVolTransform(inverse=True,
                                                 interp='nearest'),
                    name='aparc_inverse_transform')
    register.connect(inputnode, 'subjects_dir', aparcxfm, 'subjects_dir')
    register.connect(bbregister, 'out_reg_file', aparcxfm, 'reg_file')
    register.connect(fssource, ('aparc_aseg', get_aparc_aseg),
                     aparcxfm, 'target_file')
    register.connect(inputnode, 'mean_image', aparcxfm, 'source_file')

    # Convert the BBRegister transformation to ANTS ITK format
    convert2itk = Node(C3dAffineTool(), name='convert2itk')
    convert2itk.inputs.fsl2ras = True
    convert2itk.inputs.itk_transform = True
    register.connect(bbregister, 'out_fsl_file', convert2itk, 'transform_file')
    register.connect(inputnode, 'mean_image', convert2itk, 'source_file')
    register.connect(stripper, 'out_file', convert2itk, 'reference_file')

    # Compute registration between the subject's structural and the target.
    # This is currently set to perform a very quick registration. However,
    # the registration can be made significantly more accurate for cortical
    # structures by increasing the number of iterations.
    # All parameters are set using the example from:
    # https://github.com/stnava/ANTs/blob/master/Scripts/newAntsExample.sh
    reg = Node(ants.Registration(), name='antsRegister')
    reg.inputs.output_transform_prefix = "output_"
    reg.inputs.transforms = ['Rigid', 'Affine', 'SyN']
    reg.inputs.transform_parameters = [(0.1,), (0.1,), (0.2, 3.0, 0.0)]
    reg.inputs.number_of_iterations = [[10000, 11110, 11110]] * 2 + [[100, 30, 20]]
    reg.inputs.dimension = 3
    reg.inputs.write_composite_transform = True
    reg.inputs.collapse_output_transforms = True
    reg.inputs.initial_moving_transform_com = True
    reg.inputs.metric = ['Mattes'] * 2 + [['Mattes', 'CC']]
    reg.inputs.metric_weight = [1] * 2 + [[0.5, 0.5]]
    reg.inputs.radius_or_number_of_bins = [32] * 2 + [[32, 4]]
    reg.inputs.sampling_strategy = ['Regular'] * 2 + [[None, None]]
    reg.inputs.sampling_percentage = [0.3] * 2 + [[None, None]]
    reg.inputs.convergence_threshold = [1.e-8] * 2 + [-0.01]
    reg.inputs.convergence_window_size = [20] * 2 + [5]
    reg.inputs.smoothing_sigmas = [[4, 2, 1]] * 2 + [[1, 0.5, 0]]
    reg.inputs.sigma_units = ['vox'] * 3
    reg.inputs.shrink_factors = [[3, 2, 1]] * 2 + [[4, 2, 1]]
    reg.inputs.use_estimate_learning_rate_once = [True] * 3
    reg.inputs.use_histogram_matching = [False] * 2 + [True]
    reg.inputs.winsorize_lower_quantile = 0.005
    reg.inputs.winsorize_upper_quantile = 0.995
    reg.inputs.float = True
    reg.inputs.output_warped_image = 'output_warped_image.nii.gz'
    reg.inputs.num_threads = 4
    reg.plugin_args = {'sbatch_args': '-c%d' % 4}
    register.connect(stripper, 'out_file', reg, 'moving_image')
    register.connect(inputnode, 'target_image', reg, 'fixed_image')

    # Concatenate the ANTs composite (in1) and the ITK func->anat affine
    # (in2) into a list. This is a plain Node, so no ``iterfield`` applies
    # (the former ``iterfield=['in2']`` argument was silently ignored).
    merge = Node(Merge(2), name='mergexfm')
    register.connect(convert2itk, 'itk_transform', merge, 'in2')
    register.connect(reg, 'composite_transform', merge, 'in1')

    # Transform the mean image. First to anatomical and then to target
    warpmean = Node(ants.ApplyTransforms(), name='warpmean')
    warpmean.inputs.input_image_type = 3
    warpmean.inputs.interpolation = 'Linear'
    warpmean.inputs.invert_transform_flags = [False, False]
    warpmean.inputs.terminal_output = 'file'
    warpmean.inputs.args = '--float'
    warpmean.inputs.num_threads = 4
    warpmean.plugin_args = {'sbatch_args': '-c%d' % 4}

    register.connect(inputnode, 'target_image', warpmean, 'reference_image')
    register.connect(inputnode, 'mean_image', warpmean, 'input_image')
    register.connect(merge, 'out', warpmean, 'transforms')

    # Assign all the output files
    register.connect(reg, 'warped_image', outputnode, 'anat2target')
    register.connect(warpmean, 'output_image', outputnode, 'transformed_mean')
    register.connect(applyxfm, 'transformed_file',
                     outputnode, 'segmentation_files')
    register.connect(aparcxfm, 'transformed_file',
                     outputnode, 'aparc')
    register.connect(bbregister, 'out_fsl_file',
                     outputnode, 'func2anat_transform')
    register.connect(bbregister, 'out_reg_file',
                     outputnode, 'out_reg_file')
    register.connect(reg, 'composite_transform',
                     outputnode, 'anat2target_transform')
    register.connect(merge, 'out', outputnode, 'transforms')
    register.connect(bbregister, 'min_cost_file',
                     outputnode, 'min_cost_file')

    return register
def create_reg_workflow(name='registration'):
    """Create a BBRegister + ANTs registration workflow using FreeSurfer data.

    The mean image is coregistered to the subject's FreeSurfer anatomy with
    BBRegister. The BET-stripped anatomy is registered to ``target_image``
    with ANTs (Rigid + Affine + SyN). Both transforms are chained to warp
    the mean image into target space. FAST tissue partial-volume maps
    (binarized) and the aparc+aseg volume are also resampled into the space
    of the mean image.

    Parameters
    ----------

        name : name of workflow (default: 'registration')

    Inputs::

        inputspec.source_files : files (filename or list of filenames to register);
            declared on the interface but not consumed by any node here
        inputspec.mean_image : mean functional image to register
        inputspec.subject_id : FreeSurfer subject id
        inputspec.subjects_dir : FreeSurfer subjects directory
        inputspec.target_image : registration target

    Outputs::

        outputspec.func2anat_transform : FSL-format transform from BBRegister
        outputspec.out_reg_file : BBRegister registration file
        outputspec.anat2target_transform : ANTs composite transform (anatomy -> target)
        outputspec.transforms : [anat2target composite, func2anat ITK affine]
        outputspec.transformed_mean : mean image in target space
        outputspec.segmentation_files : binarized FAST partial-volume maps in
            mean-image space
        outputspec.anat2target : anatomy warped into target space
        outputspec.aparc : aparc+aseg volume resampled into mean-image space
    """

    register = Workflow(name=name)

    inputnode = Node(
        interface=IdentityInterface(fields=[
            'source_files', 'mean_image', 'subject_id', 'subjects_dir',
            'target_image'
        ]),
        name='inputspec')

    outputnode = Node(
        interface=IdentityInterface(fields=[
            'func2anat_transform', 'out_reg_file', 'anat2target_transform',
            'transforms', 'transformed_mean', 'segmentation_files',
            'anat2target', 'aparc'
        ]),
        name='outputspec')

    # Get the subject's freesurfer source directory
    fssource = Node(FreeSurferSource(), name='fssource')
    fssource.run_without_submitting = True
    register.connect(inputnode, 'subject_id', fssource, 'subject_id')
    register.connect(inputnode, 'subjects_dir', fssource, 'subjects_dir')

    # Convert the FreeSurfer T1 volume to NIfTI for the FSL/ANTs steps below
    convert = Node(freesurfer.MRIConvert(out_type='nii'), name="convert")
    register.connect(fssource, 'T1', convert, 'in_file')

    # Coregister the mean image (inputspec.mean_image) to the surface
    bbregister = Node(freesurfer.BBRegister(), name='bbregister')
    bbregister.inputs.init = 'fsl'
    bbregister.inputs.contrast_type = 't2'
    bbregister.inputs.out_fsl_file = True
    bbregister.inputs.epi_mask = True
    register.connect(inputnode, 'subject_id', bbregister, 'subject_id')
    register.connect(inputnode, 'mean_image', bbregister, 'source_file')
    register.connect(inputnode, 'subjects_dir', bbregister, 'subjects_dir')

    # Skull-strip the anatomy with BET, then estimate tissue classes with
    # FSL FAST.
    stripper = Node(fsl.BET(), name='stripper')
    register.connect(convert, 'out_file', stripper, 'in_file')
    fast = Node(fsl.FAST(), name='fast')
    register.connect(stripper, 'out_file', fast, 'in_files')

    # Binarize each partial-volume map (threshold 0.9, erode, binarize)
    binarize = MapNode(fsl.ImageMaths(op_string='-nan -thr 0.9 -ero -bin'),
                       iterfield=['in_file'],
                       name='binarize')
    register.connect(fast, 'partial_volume_files', binarize, 'in_file')

    # Apply the inverse BBRegister transform to take the segmentations to
    # functional (mean-image) space
    applyxfm = MapNode(freesurfer.ApplyVolTransform(inverse=True,
                                                    interp='nearest'),
                       iterfield=['target_file'],
                       name='inverse_transform')
    register.connect(inputnode, 'subjects_dir', applyxfm, 'subjects_dir')
    register.connect(bbregister, 'out_reg_file', applyxfm, 'reg_file')
    register.connect(binarize, 'out_file', applyxfm, 'target_file')
    register.connect(inputnode, 'mean_image', applyxfm, 'source_file')

    # Apply the same inverse transform to the aparc+aseg file
    aparcxfm = Node(freesurfer.ApplyVolTransform(inverse=True,
                                                 interp='nearest'),
                    name='aparc_inverse_transform')
    register.connect(inputnode, 'subjects_dir', aparcxfm, 'subjects_dir')
    register.connect(bbregister, 'out_reg_file', aparcxfm, 'reg_file')
    register.connect(fssource, ('aparc_aseg', get_aparc_aseg), aparcxfm,
                     'target_file')
    register.connect(inputnode, 'mean_image', aparcxfm, 'source_file')

    # Convert the BBRegister (FSL-format) transformation to ANTS ITK format
    convert2itk = Node(C3dAffineTool(), name='convert2itk')
    convert2itk.inputs.fsl2ras = True
    convert2itk.inputs.itk_transform = True
    register.connect(bbregister, 'out_fsl_file', convert2itk, 'transform_file')
    register.connect(inputnode, 'mean_image', convert2itk, 'source_file')
    register.connect(stripper, 'out_file', convert2itk, 'reference_file')

    # Compute registration between the subject's structural and the target
    # template. This is currently set to perform a very quick registration;
    # it can be made significantly more accurate for cortical structures by
    # increasing the number of iterations. All parameters follow:
    # https://github.com/stnava/ANTs/blob/master/Scripts/newAntsExample.sh
    reg = Node(ants.Registration(), name='antsRegister')
    reg.inputs.output_transform_prefix = "output_"
    reg.inputs.transforms = ['Rigid', 'Affine', 'SyN']
    reg.inputs.transform_parameters = [(0.1, ), (0.1, ), (0.2, 3.0, 0.0)]
    reg.inputs.number_of_iterations = [[10000, 11110, 11110]] * 2 + [[
        100, 30, 20
    ]]
    reg.inputs.dimension = 3
    reg.inputs.write_composite_transform = True
    reg.inputs.collapse_output_transforms = True
    reg.inputs.initial_moving_transform_com = True
    reg.inputs.metric = ['Mattes'] * 2 + [['Mattes', 'CC']]
    reg.inputs.metric_weight = [1] * 2 + [[0.5, 0.5]]
    reg.inputs.radius_or_number_of_bins = [32] * 2 + [[32, 4]]
    reg.inputs.sampling_strategy = ['Regular'] * 2 + [[None, None]]
    reg.inputs.sampling_percentage = [0.3] * 2 + [[None, None]]
    reg.inputs.convergence_threshold = [1.e-8] * 2 + [-0.01]
    reg.inputs.convergence_window_size = [20] * 2 + [5]
    reg.inputs.smoothing_sigmas = [[4, 2, 1]] * 2 + [[1, 0.5, 0]]
    reg.inputs.sigma_units = ['vox'] * 3
    reg.inputs.shrink_factors = [[3, 2, 1]] * 2 + [[4, 2, 1]]
    reg.inputs.use_estimate_learning_rate_once = [True] * 3
    reg.inputs.use_histogram_matching = [False] * 2 + [True]
    reg.inputs.winsorize_lower_quantile = 0.005
    reg.inputs.winsorize_upper_quantile = 0.995
    reg.inputs.float = True
    reg.inputs.output_warped_image = 'output_warped_image.nii.gz'
    reg.inputs.num_threads = 4
    reg.plugin_args = {'qsub_args': '-l nodes=1:ppn=4'}
    register.connect(stripper, 'out_file', reg, 'moving_image')
    register.connect(inputnode, 'target_image', reg, 'fixed_image')

    # Concatenate the ANTs and BBRegister transforms into a list:
    # [anat->target composite, func->anat affine]. antsApplyTransforms
    # applies the list last-to-first, i.e. func->anat, then anat->target.
    # A plain Node is used: `iterfield` only applies to MapNode.
    merge = Node(Merge(2), name='mergexfm')
    register.connect(convert2itk, 'itk_transform', merge, 'in2')
    register.connect(reg, 'composite_transform', merge, 'in1')

    # Transform the mean image: first to anatomical and then to target
    warpmean = Node(ants.ApplyTransforms(), name='warpmean')
    warpmean.inputs.input_image_type = 3
    warpmean.inputs.interpolation = 'Linear'
    warpmean.inputs.invert_transform_flags = [False, False]
    warpmean.inputs.terminal_output = 'file'
    warpmean.inputs.args = '--float'
    warpmean.inputs.num_threads = 4

    register.connect(inputnode, 'target_image', warpmean, 'reference_image')
    register.connect(inputnode, 'mean_image', warpmean, 'input_image')
    register.connect(merge, 'out', warpmean, 'transforms')

    # Assign all the output files
    register.connect(reg, 'warped_image', outputnode, 'anat2target')
    register.connect(warpmean, 'output_image', outputnode, 'transformed_mean')
    register.connect(applyxfm, 'transformed_file', outputnode,
                     'segmentation_files')
    register.connect(aparcxfm, 'transformed_file', outputnode, 'aparc')
    register.connect(bbregister, 'out_fsl_file', outputnode,
                     'func2anat_transform')
    register.connect(bbregister, 'out_reg_file', outputnode, 'out_reg_file')
    register.connect(reg, 'composite_transform', outputnode,
                     'anat2target_transform')
    register.connect(merge, 'out', outputnode, 'transforms')

    return register
def create_fs_reg_workflow(name="registration"):
    """Create a BBRegister + ANTs registration workflow using FreeSurfer data.

    The mean image is coregistered to the subject's FreeSurfer anatomy with
    BBRegister. The anatomy, masked by a dilated aparc+aseg brain mask, is
    registered to ``target_image`` with ANTs (Rigid + Affine + SyN). Both
    transforms are chained to bring the mean image and every source file
    into target space.

    Parameters
    ----------

    ::

        name : name of workflow (default: 'registration')

    Inputs::

        inputspec.source_files : files (filename or list of filenames to register)
        inputspec.mean_image : reference image to use
        inputspec.subject_id : FreeSurfer subject id
        inputspec.subjects_dir : FreeSurfer subjects directory
        inputspec.target_image : registration target

    Outputs::

        outputspec.func2anat_transform : FSL-format transform from BBRegister
        outputspec.out_reg_file : BBRegister registration file
        outputspec.min_cost_file : BBRegister minimum-cost file
        outputspec.mean2anat_mask : BET mask of the mean image resampled to anatomy
        outputspec.anat2target_transform : ANTs composite transform (anatomy -> target)
        outputspec.transforms : [anat2target composite, func2anat ITK affine]
        outputspec.transformed_mean : mean image in target space
        outputspec.transformed_files : source files in target space
        outputspec.anat2target : masked anatomy warped into target space
        outputspec.aparc : aparc+aseg volume resampled into mean-image space
    """

    register = Workflow(name=name)

    inputnode = Node(
        interface=IdentityInterface(
            fields=["source_files", "mean_image", "subject_id", "subjects_dir", "target_image"]
        ),
        name="inputspec",
    )

    outputnode = Node(
        interface=IdentityInterface(
            fields=[
                "func2anat_transform",
                "out_reg_file",
                "anat2target_transform",
                "transforms",
                "transformed_mean",
                "transformed_files",
                "min_cost_file",
                "anat2target",
                "aparc",
                "mean2anat_mask",
            ]
        ),
        name="outputspec",
    )

    # Get the subject's freesurfer source directory
    fssource = Node(FreeSurferSource(), name="fssource")
    fssource.run_without_submitting = True
    register.connect(inputnode, "subject_id", fssource, "subject_id")
    register.connect(inputnode, "subjects_dir", fssource, "subjects_dir")

    # Convert the FreeSurfer T1 volume to NIfTI for the FSL/ANTs steps below
    convert = Node(freesurfer.MRIConvert(out_type="nii"), name="convert")
    register.connect(fssource, "T1", convert, "in_file")

    # Coregister the mean image to the surface; registered_file=True also
    # writes the mean image resampled into anatomical space.
    bbregister = Node(freesurfer.BBRegister(registered_file=True), name="bbregister")
    bbregister.inputs.init = "fsl"
    bbregister.inputs.contrast_type = "t2"
    bbregister.inputs.out_fsl_file = True
    bbregister.inputs.epi_mask = True
    register.connect(inputnode, "subject_id", bbregister, "subject_id")
    register.connect(inputnode, "mean_image", bbregister, "source_file")
    register.connect(inputnode, "subjects_dir", bbregister, "subjects_dir")

    # Create a mask of the mean image coregistered to the anatomical image
    mean2anat_mask = Node(fsl.BET(mask=True), name="mean2anat_mask")
    register.connect(bbregister, "registered_file", mean2anat_mask, "in_file")

    # Use aparc+aseg's (dilated) brain mask to skull-strip the anatomy
    binarize = Node(freesurfer.Binarize(min=0.5, out_type="nii.gz", dilate=1), name="binarize_aparc")
    register.connect(fssource, ("aparc_aseg", get_aparc_aseg), binarize, "in_file")

    stripper = Node(fsl.ApplyMask(), name="stripper")
    register.connect(binarize, "binary_file", stripper, "mask_file")
    register.connect(convert, "out_file", stripper, "in_file")

    # Apply the inverse BBRegister transform to the aparc+aseg file
    aparcxfm = Node(freesurfer.ApplyVolTransform(inverse=True, interp="nearest"), name="aparc_inverse_transform")
    register.connect(inputnode, "subjects_dir", aparcxfm, "subjects_dir")
    register.connect(bbregister, "out_reg_file", aparcxfm, "reg_file")
    register.connect(fssource, ("aparc_aseg", get_aparc_aseg), aparcxfm, "target_file")
    register.connect(inputnode, "mean_image", aparcxfm, "source_file")

    # Convert the BBRegister (FSL-format) transformation to ANTS ITK format
    convert2itk = Node(C3dAffineTool(), name="convert2itk")
    convert2itk.inputs.fsl2ras = True
    convert2itk.inputs.itk_transform = True
    register.connect(bbregister, "out_fsl_file", convert2itk, "transform_file")
    register.connect(inputnode, "mean_image", convert2itk, "source_file")
    register.connect(stripper, "out_file", convert2itk, "reference_file")

    # Compute registration between the subject's structural and the target
    # template. This is currently set to perform a very quick registration;
    # it can be made significantly more accurate for cortical structures by
    # increasing the number of iterations. All parameters follow:
    # https://github.com/stnava/ANTs/blob/master/Scripts/newAntsExample.sh
    reg = Node(ants.Registration(), name="antsRegister")
    reg.inputs.output_transform_prefix = "output_"
    reg.inputs.transforms = ["Rigid", "Affine", "SyN"]
    reg.inputs.transform_parameters = [(0.1,), (0.1,), (0.2, 3.0, 0.0)]
    reg.inputs.number_of_iterations = [[10000, 11110, 11110]] * 2 + [[100, 30, 20]]
    reg.inputs.dimension = 3
    reg.inputs.write_composite_transform = True
    reg.inputs.collapse_output_transforms = True
    reg.inputs.initial_moving_transform_com = True
    reg.inputs.metric = ["Mattes"] * 2 + [["Mattes", "CC"]]
    reg.inputs.metric_weight = [1] * 2 + [[0.5, 0.5]]
    reg.inputs.radius_or_number_of_bins = [32] * 2 + [[32, 4]]
    reg.inputs.sampling_strategy = ["Regular"] * 2 + [[None, None]]
    reg.inputs.sampling_percentage = [0.3] * 2 + [[None, None]]
    reg.inputs.convergence_threshold = [1.0e-8] * 2 + [-0.01]
    reg.inputs.convergence_window_size = [20] * 2 + [5]
    reg.inputs.smoothing_sigmas = [[4, 2, 1]] * 2 + [[1, 0.5, 0]]
    reg.inputs.sigma_units = ["vox"] * 3
    reg.inputs.shrink_factors = [[3, 2, 1]] * 2 + [[4, 2, 1]]
    reg.inputs.use_estimate_learning_rate_once = [True] * 3
    reg.inputs.use_histogram_matching = [False] * 2 + [True]
    reg.inputs.winsorize_lower_quantile = 0.005
    reg.inputs.winsorize_upper_quantile = 0.995
    reg.inputs.args = "--float"
    reg.inputs.output_warped_image = "output_warped_image.nii.gz"
    reg.inputs.num_threads = 4
    reg.plugin_args = {"qsub_args": "-pe orte 4", "sbatch_args": "--mem=6G -c 4"}
    register.connect(stripper, "out_file", reg, "moving_image")
    register.connect(inputnode, "target_image", reg, "fixed_image")

    # Concatenate the ANTs and BBRegister transforms into a list:
    # [anat->target composite, func->anat affine]. antsApplyTransforms
    # applies the list last-to-first, i.e. func->anat, then anat->target.
    # The composite transform is a single file (it is also wired straight to
    # outputspec.anat2target_transform below), so it is connected as-is;
    # indexing it with [0] would pass only the first character of its path.
    # A plain Node is used: `iterfield` only applies to MapNode.
    merge = Node(Merge(2), name="mergexfm")
    register.connect(convert2itk, "itk_transform", merge, "in2")
    register.connect(reg, "composite_transform", merge, "in1")

    # Transform the mean image: first to anatomical and then to target
    warpmean = Node(ants.ApplyTransforms(), name="warpmean")
    warpmean.inputs.input_image_type = 0
    warpmean.inputs.interpolation = "Linear"
    warpmean.inputs.invert_transform_flags = [False, False]
    warpmean.inputs.terminal_output = "file"
    warpmean.inputs.args = "--float"
    register.connect(inputnode, "target_image", warpmean, "reference_image")
    register.connect(inputnode, "mean_image", warpmean, "input_image")
    register.connect(merge, "out", warpmean, "transforms")

    # Transform the remaining images the same way, one ApplyTransforms per file
    warpall = MapNode(ants.ApplyTransforms(), iterfield=["input_image"], name="warpall")
    warpall.inputs.input_image_type = 0
    warpall.inputs.interpolation = "Linear"
    warpall.inputs.invert_transform_flags = [False, False]
    warpall.inputs.terminal_output = "file"
    warpall.inputs.args = "--float"
    warpall.inputs.num_threads = 2
    warpall.plugin_args = {"sbatch_args": "--mem=6G -c 2"}
    register.connect(inputnode, "target_image", warpall, "reference_image")
    register.connect(inputnode, "source_files", warpall, "input_image")
    register.connect(merge, "out", warpall, "transforms")

    # Assign all the output files
    register.connect(warpmean, "output_image", outputnode, "transformed_mean")
    register.connect(warpall, "output_image", outputnode, "transformed_files")
    register.connect(reg, "warped_image", outputnode, "anat2target")
    register.connect(aparcxfm, "transformed_file", outputnode, "aparc")
    register.connect(bbregister, "out_fsl_file", outputnode, "func2anat_transform")
    register.connect(bbregister, "out_reg_file", outputnode, "out_reg_file")
    register.connect(bbregister, "min_cost_file", outputnode, "min_cost_file")
    register.connect(mean2anat_mask, "mask_file", outputnode, "mean2anat_mask")
    register.connect(reg, "composite_transform", outputnode, "anat2target_transform")
    register.connect(merge, "out", outputnode, "transforms")

    return register
Exemple #6
0
# Create a datasource node to get the T1 file.
# NOTE(review): `info`, `data_dir` and `infosource` are defined outside this
# snippet. `info` is used both as the DataGrabber `outfields` and as
# `template_args`, so it is presumably a dict mapping output field names to
# template arguments — confirm against the definition.
datasource = Node(DataGrabber(infields=['subject_id'], outfields=info.keys()),
                  name='datasource')
datasource.inputs.template = '%s/%s'
datasource.inputs.base_directory = os.path.abspath(data_dir)
# The T1 field has its own template, used instead of the generic '%s/%s'.
datasource.inputs.field_template = dict(T1='%s/s1/anatomy/T1_002.nii.gz')
datasource.inputs.template_args = info
datasource.inputs.sort_filelist = True

# FreeSurfer recon-all node: openmp=2 requests parallel processing, and
# `args` appends an extra flag to the recon-all command line.
reconall_node = Node(ReconAll(), name='reconall_node')
reconall_node.inputs.openmp = 2
reconall_node.inputs.args = '-hippocampal-subfields-T1'
reconall_node.inputs.subjects_dir = '/home/data/madlab/surfaces/emuR01'
# Per-node SLURM submission options (-n 2, presumably to match openmp=2).
# NOTE(review): 'overwrite': True is presumably meant to replace, rather than
# extend, the workflow-level sbatch_args — confirm against nipype docs.
reconall_node.plugin_args = {
    'sbatch_args': ('-p investor --qos pq_madlab -n 2'),
    'overwrite': True
}

wf = Workflow(name='fsrecon')

# Feed each subject id to both the data grabber and recon-all, and route the
# grabbed T1 image into recon-all.
wf.connect(infosource, 'subject_id', datasource, 'subject_id')
wf.connect(infosource, 'subject_id', reconall_node, 'subject_id')
wf.connect(datasource, 'T1', reconall_node, 'T1_files')

wf.base_dir = os.path.abspath('/scratch/madlab/emu/')
#wf.config['execution']['job_finished_timeout'] = 65

wf.run(plugin='SLURM',
       plugin_args={
           'sbatch_args': ('-p investor --qos pq_madlab -N 1 -n 1'),
           'overwrite': True
Exemple #7
0
def create_workflow(files,
                    subject_id,
                    n_vol=0,
                    despike=True,
                    TR=None,
                    slice_times=None,
                    slice_thickness=None,
                    fieldmap_images=[],
                    norm_threshold=1,
                    num_components=6,
                    vol_fwhm=None,
                    surf_fwhm=None,
                    lowpass_freq=-1,
                    highpass_freq=-1,
                    sink_directory=os.getcwd(),
                    FM_TEdiff=2.46,
                    FM_sigma=2,
                    FM_echo_spacing=.7,
                    target_subject=['fsaverage3', 'fsaverage4'],
                    name='resting'):

    wf = Workflow(name=name)

    # Skip starting volumes
    remove_vol = MapNode(fsl.ExtractROI(t_min=n_vol, t_size=-1),
                         iterfield=['in_file'],
                         name="remove_volumes")
    remove_vol.inputs.in_file = files

    # Run AFNI's despike. This is always run, however, whether this is fed to
    # realign depends on the input configuration
    despiker = MapNode(afni.Despike(outputtype='NIFTI_GZ'),
                       iterfield=['in_file'],
                       name='despike')
    #despiker.plugin_args = {'qsub_args': '-l nodes=1:ppn='}

    wf.connect(remove_vol, 'roi_file', despiker, 'in_file')

    # Run Nipy joint slice timing and realignment algorithm
    realign = Node(nipy.SpaceTimeRealigner(), name='realign')
    realign.inputs.tr = TR
    realign.inputs.slice_times = slice_times
    realign.inputs.slice_info = 2

    if despike:
        wf.connect(despiker, 'out_file', realign, 'in_file')
    else:
        wf.connect(remove_vol, 'roi_file', realign, 'in_file')

    # Comute TSNR on realigned data regressing polynomials upto order 2
    tsnr = MapNode(TSNR(regress_poly=2), iterfield=['in_file'], name='tsnr')
    wf.connect(realign, 'out_file', tsnr, 'in_file')

    # Compute the median image across runs
    calc_median = Node(Function(input_names=['in_files'],
                                output_names=['median_file'],
                                function=median,
                                imports=imports),
                       name='median')
    wf.connect(tsnr, 'detrended_file', calc_median, 'in_files')

    # Coregister the median to the surface
    register = Node(freesurfer.BBRegister(), name='bbregister')
    register.inputs.subject_id = subject_id
    register.inputs.init = 'fsl'
    register.inputs.contrast_type = 't2'
    register.inputs.out_fsl_file = True
    register.inputs.epi_mask = True

    # Compute fieldmaps and unwarp using them
    if fieldmap_images:
        fieldmap = Node(interface=EPIDeWarp(), name='fieldmap_unwarp')
        fieldmap.inputs.tediff = FM_TEdiff
        fieldmap.inputs.esp = FM_echo_spacing
        fieldmap.inputs.sigma = FM_sigma
        fieldmap.inputs.mag_file = fieldmap_images[0]
        fieldmap.inputs.dph_file = fieldmap_images[1]
        wf.connect(calc_median, 'median_file', fieldmap, 'exf_file')

        dewarper = MapNode(interface=fsl.FUGUE(),
                           iterfield=['in_file'],
                           name='dewarper')
        wf.connect(tsnr, 'detrended_file', dewarper, 'in_file')
        wf.connect(fieldmap, 'exf_mask', dewarper, 'mask_file')
        wf.connect(fieldmap, 'vsm_file', dewarper, 'shift_in_file')
        wf.connect(fieldmap, 'exfdw', register, 'source_file')
    else:
        wf.connect(calc_median, 'median_file', register, 'source_file')

    # Get the subject's freesurfer source directory
    fssource = Node(FreeSurferSource(), name='fssource')
    fssource.inputs.subject_id = subject_id
    fssource.inputs.subjects_dir = os.environ['SUBJECTS_DIR']

    # Extract wm+csf, brain masks by eroding freesurfer labels and then
    # transform the masks into the space of the median
    wmcsf = Node(freesurfer.Binarize(), name='wmcsfmask')
    mask = wmcsf.clone('anatmask')
    wmcsftransform = Node(freesurfer.ApplyVolTransform(inverse=True,
                                                       interp='nearest'),
                          name='wmcsftransform')
    wmcsftransform.inputs.subjects_dir = os.environ['SUBJECTS_DIR']
    wmcsf.inputs.wm_ven_csf = True
    wmcsf.inputs.match = [4, 5, 14, 15, 24, 31, 43, 44, 63]
    wmcsf.inputs.binary_file = 'wmcsf.nii.gz'
    wmcsf.inputs.erode = int(np.ceil(slice_thickness))
    wf.connect(fssource, ('aparc_aseg', get_aparc_aseg), wmcsf, 'in_file')
    if fieldmap_images:
        wf.connect(fieldmap, 'exf_mask', wmcsftransform, 'source_file')
    else:
        wf.connect(calc_median, 'median_file', wmcsftransform, 'source_file')
    wf.connect(register, 'out_reg_file', wmcsftransform, 'reg_file')
    wf.connect(wmcsf, 'binary_file', wmcsftransform, 'target_file')

    mask.inputs.binary_file = 'mask.nii.gz'
    mask.inputs.dilate = int(np.ceil(slice_thickness)) + 1
    mask.inputs.erode = int(np.ceil(slice_thickness))
    mask.inputs.min = 0.5
    wf.connect(fssource, ('aparc_aseg', get_aparc_aseg), mask, 'in_file')
    masktransform = wmcsftransform.clone("masktransform")
    if fieldmap_images:
        wf.connect(fieldmap, 'exf_mask', masktransform, 'source_file')
    else:
        wf.connect(calc_median, 'median_file', masktransform, 'source_file')
    wf.connect(register, 'out_reg_file', masktransform, 'reg_file')
    wf.connect(mask, 'binary_file', masktransform, 'target_file')

    # Compute Art outliers
    art = Node(interface=ArtifactDetect(use_differences=[True, False],
                                        use_norm=True,
                                        norm_threshold=norm_threshold,
                                        zintensity_threshold=3,
                                        parameter_source='NiPy',
                                        bound_by_brainmask=True,
                                        save_plot=False,
                                        mask_type='file'),
               name="art")
    if fieldmap_images:
        wf.connect(dewarper, 'unwarped_file', art, 'realigned_files')
    else:
        wf.connect(tsnr, 'detrended_file', art, 'realigned_files')
    wf.connect(realign, 'par_file', art, 'realignment_parameters')
    wf.connect(masktransform, 'transformed_file', art, 'mask_file')

    # Compute motion regressors
    motreg = Node(Function(
        input_names=['motion_params', 'order', 'derivatives'],
        output_names=['out_files'],
        function=motion_regressors,
        imports=imports),
                  name='getmotionregress')
    wf.connect(realign, 'par_file', motreg, 'motion_params')

    # Create a filter to remove motion and art confounds
    createfilter1 = Node(Function(
        input_names=['motion_params', 'comp_norm', 'outliers'],
        output_names=['out_files'],
        function=build_filter1,
        imports=imports),
                         name='makemotionbasedfilter')
    wf.connect(motreg, 'out_files', createfilter1, 'motion_params')
    wf.connect(art, 'norm_files', createfilter1, 'comp_norm')
    wf.connect(art, 'outlier_files', createfilter1, 'outliers')

    # Filter the motion and art confounds
    filter1 = MapNode(fsl.GLM(out_res_name='timeseries.nii.gz', demean=True),
                      iterfield=['in_file', 'design'],
                      name='filtermotion')
    if fieldmap_images:
        wf.connect(dewarper, 'unwarped_file', filter1, 'in_file')
    else:
        wf.connect(tsnr, 'detrended_file', filter1, 'in_file')
    wf.connect(createfilter1, 'out_files', filter1, 'design')
    wf.connect(masktransform, 'transformed_file', filter1, 'mask')

    # Create a filter to remove noise components based on white matter and CSF
    createfilter2 = MapNode(Function(
        input_names=['realigned_file', 'mask_file', 'num_components'],
        output_names=['out_files'],
        function=extract_noise_components,
        imports=imports),
                            iterfield=['realigned_file'],
                            name='makecompcorrfilter')
    createfilter2.inputs.num_components = num_components
    wf.connect(filter1, 'out_res', createfilter2, 'realigned_file')
    wf.connect(masktransform, 'transformed_file', createfilter2, 'mask_file')

    # Filter noise components
    filter2 = MapNode(fsl.GLM(out_res_name='timeseries_cleaned.nii.gz',
                              demean=True),
                      iterfield=['in_file', 'design'],
                      name='filtercompcorr')
    wf.connect(filter1, 'out_res', filter2, 'in_file')
    wf.connect(createfilter2, 'out_files', filter2, 'design')
    wf.connect(masktransform, 'transformed_file', filter2, 'mask')

    # Smoothing using surface and volume smoothing
    smooth = MapNode(freesurfer.Smooth(), iterfield=['in_file'], name='smooth')
    smooth.inputs.proj_frac_avg = (0.1, 0.9, 0.1)
    if surf_fwhm is None:
        surf_fwhm = 5 * slice_thickness
    smooth.inputs.surface_fwhm = surf_fwhm
    if vol_fwhm is None:
        vol_fwhm = 2 * slice_thickness
    smooth.inputs.vol_fwhm = vol_fwhm
    wf.connect(filter2, 'out_res', smooth, 'in_file')
    wf.connect(register, 'out_reg_file', smooth, 'reg_file')

    # Bandpass filter the data
    bandpass = MapNode(fsl.TemporalFilter(),
                       iterfield=['in_file'],
                       name='bandpassfilter')
    if highpass_freq < 0:
        bandpass.inputs.highpass_sigma = -1
    else:
        bandpass.inputs.highpass_sigma = 1. / (2 * TR * highpass_freq)
    if lowpass_freq < 0:
        bandpass.inputs.lowpass_sigma = -1
    else:
        bandpass.inputs.lowpass_sigma = 1. / (2 * TR * lowpass_freq)
    wf.connect(smooth, 'smoothed_file', bandpass, 'in_file')

    # Convert aparc to subject functional space
    aparctransform = wmcsftransform.clone("aparctransform")
    if fieldmap_images:
        wf.connect(fieldmap, 'exf_mask', aparctransform, 'source_file')
    else:
        wf.connect(calc_median, 'median_file', aparctransform, 'source_file')
    wf.connect(register, 'out_reg_file', aparctransform, 'reg_file')
    wf.connect(fssource, ('aparc_aseg', get_aparc_aseg), aparctransform,
               'target_file')

    # Sample the average time series in aparc ROIs
    sampleaparc = MapNode(freesurfer.SegStats(avgwf_txt_file=True,
                                              default_color_table=True),
                          iterfield=['in_file'],
                          name='aparc_ts')
    sampleaparc.inputs.segment_id = ([8] + range(10, 14) + [17, 18, 26, 47] +
                                     range(49, 55) + [58] + range(1001, 1036) +
                                     range(2001, 2036))

    wf.connect(aparctransform, 'transformed_file', sampleaparc,
               'segmentation_file')
    wf.connect(bandpass, 'out_file', sampleaparc, 'in_file')

    # Sample the time series onto the surface of the target surface. Performs
    # sampling into left and right hemisphere
    target = Node(IdentityInterface(fields=['target_subject']), name='target')
    target.iterables = ('target_subject', filename_to_list(target_subject))

    samplerlh = MapNode(freesurfer.SampleToSurface(),
                        iterfield=['source_file'],
                        name='sampler_lh')
    samplerlh.inputs.sampling_method = "average"
    samplerlh.inputs.sampling_range = (0.1, 0.9, 0.1)
    samplerlh.inputs.sampling_units = "frac"
    samplerlh.inputs.interp_method = "trilinear"
    #samplerlh.inputs.cortex_mask = True
    samplerlh.inputs.out_type = 'niigz'
    samplerlh.inputs.subjects_dir = os.environ['SUBJECTS_DIR']

    samplerrh = samplerlh.clone('sampler_rh')

    samplerlh.inputs.hemi = 'lh'
    wf.connect(bandpass, 'out_file', samplerlh, 'source_file')
    wf.connect(register, 'out_reg_file', samplerlh, 'reg_file')
    wf.connect(target, 'target_subject', samplerlh, 'target_subject')

    samplerrh.set_input('hemi', 'rh')
    wf.connect(bandpass, 'out_file', samplerrh, 'source_file')
    wf.connect(register, 'out_reg_file', samplerrh, 'reg_file')
    wf.connect(target, 'target_subject', samplerrh, 'target_subject')

    # Combine left and right hemisphere to text file
    combiner = MapNode(Function(input_names=['left', 'right'],
                                output_names=['out_file'],
                                function=combine_hemi,
                                imports=imports),
                       iterfield=['left', 'right'],
                       name="combiner")
    wf.connect(samplerlh, 'out_file', combiner, 'left')
    wf.connect(samplerrh, 'out_file', combiner, 'right')

    # Compute registration between the subject's structural and MNI template
    # This is currently set to perform a very quick registration. However, the
    # registration can be made significantly more accurate for cortical
    # structures by increasing the number of iterations
    # All parameters are set using the example from:
    # https://github.com/stnava/ANTs/blob/master/Scripts/newAntsExample.sh
    reg = Node(ants.Registration(), name='antsRegister')
    reg.inputs.output_transform_prefix = "output_"
    reg.inputs.transforms = ['Translation', 'Rigid', 'Affine', 'SyN']
    reg.inputs.transform_parameters = [(0.1, ), (0.1, ), (0.1, ),
                                       (0.2, 3.0, 0.0)]
    # reg.inputs.number_of_iterations = ([[10000, 111110, 11110]]*3 +
    #                                    [[100, 50, 30]])
    reg.inputs.number_of_iterations = [[100, 100, 100]] * 3 + [[100, 20, 10]]
    reg.inputs.dimension = 3
    reg.inputs.write_composite_transform = True
    reg.inputs.collapse_output_transforms = False
    reg.inputs.metric = ['Mattes'] * 3 + [['Mattes', 'CC']]
    reg.inputs.metric_weight = [1] * 3 + [[0.5, 0.5]]
    reg.inputs.radius_or_number_of_bins = [32] * 3 + [[32, 4]]
    reg.inputs.sampling_strategy = ['Regular'] * 3 + [[None, None]]
    reg.inputs.sampling_percentage = [0.3] * 3 + [[None, None]]
    reg.inputs.convergence_threshold = [1.e-8] * 3 + [-0.01]
    reg.inputs.convergence_window_size = [20] * 3 + [5]
    reg.inputs.smoothing_sigmas = [[4, 2, 1]] * 3 + [[1, 0.5, 0]]
    reg.inputs.sigma_units = ['vox'] * 4
    reg.inputs.shrink_factors = [[6, 4, 2]] + [[3, 2, 1]] * 2 + [[4, 2, 1]]
    reg.inputs.use_estimate_learning_rate_once = [True] * 4
    reg.inputs.use_histogram_matching = [False] * 3 + [True]
    reg.inputs.output_warped_image = 'output_warped_image.nii.gz'
    reg.inputs.fixed_image = \
        os.path.abspath('OASIS-30_Atropos_template_in_MNI152_2mm.nii.gz')
    reg.inputs.num_threads = 4
    reg.plugin_args = {'qsub_args': '-l nodes=1:ppn=4'}

    # Convert T1.mgz to nifti for using with ANTS
    convert = Node(freesurfer.MRIConvert(out_type='niigz'), name='convert2nii')
    wf.connect(fssource, 'T1', convert, 'in_file')

    # Mask the T1.mgz file with the brain mask computed earlier
    maskT1 = Node(fsl.BinaryMaths(operation='mul'), name='maskT1')
    wf.connect(mask, 'binary_file', maskT1, 'operand_file')
    wf.connect(convert, 'out_file', maskT1, 'in_file')
    wf.connect(maskT1, 'out_file', reg, 'moving_image')

    # Convert the BBRegister transformation to ANTS ITK format
    convert2itk = MapNode(C3dAffineTool(),
                          iterfield=['transform_file', 'source_file'],
                          name='convert2itk')
    convert2itk.inputs.fsl2ras = True
    convert2itk.inputs.itk_transform = True
    wf.connect(register, 'out_fsl_file', convert2itk, 'transform_file')
    if fieldmap_images:
        wf.connect(fieldmap, 'exf_mask', convert2itk, 'source_file')
    else:
        wf.connect(calc_median, 'median_file', convert2itk, 'source_file')
    wf.connect(convert, 'out_file', convert2itk, 'reference_file')

    # Concatenate the affine and ants transforms into a list
    pickfirst = lambda x: x[0]
    merge = MapNode(Merge(2), iterfield=['in2'], name='mergexfm')
    wf.connect(convert2itk, 'itk_transform', merge, 'in2')
    wf.connect(reg, ('composite_transform', pickfirst), merge, 'in1')

    # Apply the combined transform to the time series file
    sample2mni = MapNode(ants.ApplyTransforms(),
                         iterfield=['input_image', 'transforms'],
                         name='sample2mni')
    sample2mni.inputs.input_image_type = 3
    sample2mni.inputs.interpolation = 'BSpline'
    sample2mni.inputs.invert_transform_flags = [False, False]
    sample2mni.inputs.reference_image = \
        os.path.abspath('OASIS-30_Atropos_template_in_MNI152_2mm.nii.gz')
    sample2mni.inputs.terminal_output = 'file'
    wf.connect(bandpass, 'out_file', sample2mni, 'input_image')
    wf.connect(merge, 'out', sample2mni, 'transforms')

    # Sample the time series file for each subcortical roi
    ts2txt = MapNode(Function(
        input_names=['timeseries_file', 'label_file', 'indices'],
        output_names=['out_file'],
        function=extract_subrois,
        imports=imports),
                     iterfield=['timeseries_file'],
                     name='getsubcortts')
    ts2txt.inputs.indices = [8] + range(10, 14) + [17, 18, 26, 47] +\
                            range(49, 55) + [58]
    ts2txt.inputs.label_file = \
        os.path.abspath(('OASIS-TRT-20_jointfusion_DKT31_CMA_labels_in_MNI152_'
                         '2mm.nii.gz'))
    wf.connect(sample2mni, 'output_image', ts2txt, 'timeseries_file')

    # Save the relevant data into an output directory
    datasink = Node(interface=DataSink(), name="datasink")
    datasink.inputs.base_directory = sink_directory
    datasink.inputs.container = subject_id
    datasink.inputs.substitutions = [('_target_subject_', '')]
    datasink.inputs.regexp_substitutions = (r'(/_.*(\d+/))', r'/run\2')
    wf.connect(despiker, 'out_file', datasink, 'resting.qa.despike')
    wf.connect(realign, 'par_file', datasink, 'resting.qa.motion')
    wf.connect(tsnr, 'tsnr_file', datasink, 'resting.qa.tsnr')
    wf.connect(tsnr, 'mean_file', datasink, 'resting.qa.tsnr.@mean')
    wf.connect(tsnr, 'stddev_file', datasink, 'resting.qa.@tsnr_stddev')
    if fieldmap_images:
        wf.connect(fieldmap, 'exf_mask', datasink, 'resting.reference')
    else:
        wf.connect(calc_median, 'median_file', datasink, 'resting.reference')
    wf.connect(art, 'norm_files', datasink, 'resting.qa.art.@norm')
    wf.connect(art, 'intensity_files', datasink, 'resting.qa.art.@intensity')
    wf.connect(art, 'outlier_files', datasink, 'resting.qa.art.@outlier_files')
    wf.connect(mask, 'binary_file', datasink, 'resting.mask')
    wf.connect(masktransform, 'transformed_file', datasink,
               'resting.mask.@transformed_file')
    wf.connect(register, 'out_reg_file', datasink,
               'resting.registration.bbreg')
    wf.connect(reg, ('composite_transform', pickfirst), datasink,
               'resting.registration.ants')
    wf.connect(register, 'min_cost_file', datasink,
               'resting.qa.bbreg.@mincost')
    wf.connect(smooth, 'smoothed_file', datasink,
               'resting.timeseries.fullpass')
    wf.connect(bandpass, 'out_file', datasink, 'resting.timeseries.bandpassed')
    wf.connect(sample2mni, 'output_image', datasink, 'resting.timeseries.mni')
    wf.connect(createfilter1, 'out_files', datasink,
               'resting.regress.@regressors')
    wf.connect(createfilter2, 'out_files', datasink,
               'resting.regress.@compcorr')
    wf.connect(sampleaparc, 'summary_file', datasink,
               'resting.parcellations.aparc')
    wf.connect(sampleaparc, 'avgwf_txt_file', datasink,
               'resting.parcellations.aparc.@avgwf')
    wf.connect(ts2txt, 'out_file', datasink,
               'resting.parcellations.grayo.@subcortical')
    datasink2 = Node(interface=DataSink(), name="datasink2")
    datasink2.inputs.base_directory = sink_directory
    datasink2.inputs.container = subject_id
    datasink2.inputs.substitutions = [('_target_subject_', '')]
    datasink2.inputs.regexp_substitutions = (r'(/_.*(\d+/))', r'/run\2')
    wf.connect(combiner, 'out_file', datasink2,
               'resting.parcellations.grayo.@surface')
    return wf
def group_multregress_openfmri(dataset_dir, model_id=None, task_id=None, l1output_dir=None, out_dir=None, 
                               no_reversal=False, plugin=None, plugin_args=None, flamemodel='flame1',
                               nonparametric=False, use_spm=False):
    """Build a group-level multiple-regression meta-workflow.

    For every task in ``task_id`` and every contrast returned by
    ``get_sub_vars`` a sub-workflow is created that grabs the first-level
    copes/varcopes, builds a ``MultipleRegressDesign``, fits it with FLAMEO
    inside the MNI152 2mm brain mask, estimates smoothness, clusters the
    z-statistics, converts them to p-values and sinks the results.  All
    sub-workflows are added to one ``Workflow`` named ``'mult_regress'``.

    :param dataset_dir: dataset directory, passed to ``get_taskname``,
        ``l1_contrasts_num`` and ``get_sub_vars``.
    :param model_id: model number; it is formatted with ``%03d``, so the
        ``None`` default cannot be used -- a number must be supplied.
    :param task_id: iterable of task numbers; it is iterated directly, so the
        ``None`` default cannot be used.
    :param l1output_dir: root directory of the first-level outputs used in
        the DataGrabber template.
    :param out_dir: root for the DataSink; each sub-workflow writes to
        ``out_dir/task%03d/<first contrast name>``.
    :param no_reversal: when False, the negated z-statistics are also
        clustered and sunk.
    :param plugin: not used in this function.
    :param plugin_args: not used in this function.
    :param flamemodel: FLAMEO ``run_mode``; when ``'ols'`` the varcopes are
        neither merged nor passed to FLAMEO.
    :param nonparametric: when True, PALM is also run via ``run_palm``.
    :param use_spm: when True, copes/varcopes are grabbed from ``spm/``
        sub-directories of the first-level output.
    :return: the meta ``Workflow``; it is not run here.
    """

    meta_workflow = Workflow(name='mult_regress')
    # NOTE(review): `work_dir` is not a parameter; it must be a module-level
    # global defined elsewhere in the script.
    meta_workflow.base_dir = work_dir
    for task in task_id:
        task_name = get_taskname(dataset_dir, task)
        cope_ids = l1_contrasts_num(model_id, task_name, dataset_dir)
        # regressors_needed is indexed per contrast below (regressors_needed[idx]);
        # presumably get_sub_vars returns one entry per contrast -- confirm.
        regressors_needed, contrasts, groups, subj_list = get_sub_vars(dataset_dir, task_name, model_id)
        for idx, contrast in enumerate(contrasts):
            # One sub-workflow per contrast, named after the contrast's first
            # element's first item (contrast[0][0]).
            wk = Workflow(name='model_%03d_task_%03d_contrast_%s' % (model_id, task, contrast[0][0]))

            # Only model_id and task_id are connected below; dataset_dir is
            # set but never connected and subj_list is never set.
            info = Node(util.IdentityInterface(fields=['model_id', 'task_id', 'dataset_dir', 'subj_list']),
                        name='infosource')
            info.inputs.model_id = model_id
            info.inputs.task_id = task
            info.inputs.dataset_dir = dataset_dir
            
            dg = Node(DataGrabber(infields=['model_id', 'task_id', 'cope_id'],
                                  outfields=['copes', 'varcopes']), name='grabber')
            # Template placeholders, in order: model, task, subject(s),
            # 'var' dir prefix, 'spm/' dir, 'var' file prefix, cope number,
            # file-extension suffix.
            dg.inputs.template = os.path.join(l1output_dir,
                                              'model%03d/task%03d/%s/%scopes/%smni/%scope%02d.nii%s')
            if use_spm:
                # NOTE(review): copes use no '.gz' suffix here but varcopes
                # use '.gz'; confirm this matches the SPM first-level outputs.
                dg.inputs.template_args['copes'] = [['model_id', 'task_id', subj_list, '', 'spm/',
                                                     '', 'cope_id', '']]
                dg.inputs.template_args['varcopes'] = [['model_id', 'task_id', subj_list, 'var', 'spm/',
                                                        'var', 'cope_id', '.gz']]
            else:
                dg.inputs.template_args['copes'] = [['model_id', 'task_id', subj_list, '', '', '', 
                                                     'cope_id', '.gz']]
                dg.inputs.template_args['varcopes'] = [['model_id', 'task_id', subj_list, 'var', '',
                                                        'var', 'cope_id', '.gz']]
            # Expand the sub-workflow over every first-level cope id.
            dg.iterables=('cope_id', cope_ids)
            dg.inputs.sort_filelist = False

            wk.connect(info, 'model_id', dg, 'model_id')
            wk.connect(info, 'task_id', dg, 'task_id')

            # Second-level design (design matrix, contrasts, groups).
            model = Node(MultipleRegressDesign(), name='l2model')
            model.inputs.groups = groups
            model.inputs.contrasts = contrasts[idx]
            model.inputs.regressors = regressors_needed[idx]
            
            # Stack the per-subject copes along time for FLAMEO.
            mergecopes = Node(Merge(dimension='t'), name='merge_copes')
            wk.connect(dg, 'copes', mergecopes, 'in_files')
            
            # Varcopes are only needed by the non-OLS FLAME modes.
            if flamemodel != 'ols':
                mergevarcopes = Node(Merge(dimension='t'), name='merge_varcopes')
                wk.connect(dg, 'varcopes', mergevarcopes, 'in_files')
            
            mask_file = fsl.Info.standard_image('MNI152_T1_2mm_brain_mask.nii.gz')
            flame = Node(FLAMEO(), name='flameo')
            flame.inputs.mask_file =  mask_file
            flame.inputs.run_mode = flamemodel
            #flame.inputs.infer_outliers = True

            wk.connect(model, 'design_mat', flame, 'design_file')
            wk.connect(model, 'design_con', flame, 't_con_file')
            wk.connect(mergecopes, 'merged_file', flame, 'cope_file')
            if flamemodel != 'ols':
                wk.connect(mergevarcopes, 'merged_file', flame, 'var_cope_file')
            wk.connect(model, 'design_grp', flame, 'cov_split_file')
            
            # Optional non-parametric inference through the run_palm helper.
            if nonparametric:
                palm = Node(Function(input_names=['cope_file', 'design_file', 'contrast_file', 
                                                  'group_file', 'mask_file', 'cluster_threshold'],
                                     output_names=['palm_outputs'],
                                     function=run_palm),
                            name='palm')
                palm.inputs.cluster_threshold = 3.09
                palm.inputs.mask_file = mask_file
                palm.plugin_args = {'sbatch_args': '-p om_all_nodes -N1 -c2 --mem=10G', 'overwrite': True}
                wk.connect(model, 'design_mat', palm, 'design_file')
                wk.connect(model, 'design_con', palm, 'contrast_file')
                wk.connect(mergecopes, 'merged_file', palm, 'cope_file')
                wk.connect(model, 'design_grp', palm, 'group_file')
                
            # Smoothness estimate feeds the cluster-extent correction.
            smoothest = Node(SmoothEstimate(), name='smooth_estimate')
            wk.connect(flame, 'zstats', smoothest, 'zstat_file')
            smoothest.inputs.mask_file = mask_file
        
            # Cluster the z-stats at z > 2.3, cluster p < 0.05, 26-connectivity.
            cluster = Node(Cluster(), name='cluster')
            wk.connect(smoothest,'dlh', cluster, 'dlh')
            wk.connect(smoothest, 'volume', cluster, 'volume')
            cluster.inputs.connectivity = 26
            cluster.inputs.threshold = 2.3
            cluster.inputs.pthreshold = 0.05
            cluster.inputs.out_threshold_file = True
            cluster.inputs.out_index_file = True
            cluster.inputs.out_localmax_txt_file = True
            
            wk.connect(flame, 'zstats', cluster, 'in_file')
    
            # NOTE(review): the z2pval output (and ztopval2 below) is never
            # connected to the sinker in this function.
            ztopval = Node(ImageMaths(op_string='-ztop', suffix='_pval'),
                           name='z2pval')
            wk.connect(flame, 'zstats', ztopval,'in_file')
            
            sinker = Node(DataSink(), name='sinker')
            sinker.inputs.base_directory = os.path.join(out_dir, 'task%03d' % task, contrast[0][0])
            sinker.inputs.substitutions = [('_cope_id', 'contrast'),
                                           ('_maths_', '_reversed_')]
            
            wk.connect(flame, 'zstats', sinker, 'stats')
            wk.connect(cluster, 'threshold_file', sinker, 'stats.@thr')
            wk.connect(cluster, 'index_file', sinker, 'stats.@index')
            wk.connect(cluster, 'localmax_txt_file', sinker, 'stats.@localmax')
            if nonparametric:
                wk.connect(palm, 'palm_outputs', sinker, 'stats.palm')

            # Unless disabled, multiply the z-stats by -1 and run the same
            # clustering on the reversed contrast.
            if not no_reversal:
                zstats_reverse = Node( BinaryMaths()  , name='zstats_reverse')
                zstats_reverse.inputs.operation = 'mul'
                zstats_reverse.inputs.operand_value = -1
                wk.connect(flame, 'zstats', zstats_reverse, 'in_file')
                
                cluster2=cluster.clone(name='cluster2')
                wk.connect(smoothest, 'dlh', cluster2, 'dlh')
                wk.connect(smoothest, 'volume', cluster2, 'volume')
                wk.connect(zstats_reverse, 'out_file', cluster2, 'in_file')
                
                ztopval2 = ztopval.clone(name='ztopval2')
                wk.connect(zstats_reverse, 'out_file', ztopval2, 'in_file')
                
                wk.connect(zstats_reverse, 'out_file', sinker, 'stats.@neg')
                wk.connect(cluster2, 'threshold_file', sinker, 'stats.@neg_thr')
                wk.connect(cluster2, 'index_file',sinker, 'stats.@neg_index')
                wk.connect(cluster2, 'localmax_txt_file', sinker, 'stats.@neg_localmax')
            meta_workflow.add_nodes([wk])
    return meta_workflow
# Point the `life` node at the b-value / b-vector files.
life.inputs.bvals = fbval
life.inputs.bvecs = fbvec

# Workflow inputs: the first subject id and the atlas image.
infosource = Node(IdentityInterface(fields=['subject_id', 'atlas_file']),
                  name='infosource')
infosource.inputs.subject_id = sids[0]
infosource.inputs.atlas_file = atlas_file

# Output sink (non-parameterized directory layout) rooted at out_dir.
ds = Node(DataSink(parameterization=False), name='sinker')
ds.inputs.base_directory = out_dir
ds.plugin_args = {'overwrite': True}

# Assemble the workflow; crash files are written as plain text.
wf = Workflow(name='exvivo')
wf.config['execution']['crashfile_format'] = 'txt'

# Region extraction -> streamline filtering, plus the target-region extractor
# that reuses the atlas passed through region_extracter.
wf.connect([
    (infosource, region_extracter, [('atlas_file', 'atlas_file')]),
    (iden, region_extracter, [('label', 'label')]),
    (region_extracter, filter_streamlines, [('single_region', 'target_mask'),
                                            ('label', 'label')]),
    (iden_target, region_extracter_target, [('target_label', 'label')]),
    (region_extracter, region_extracter_target, [('atlas_file', 'atlas_file')]),
])
Exemple #10
0
def create_workflow(files,
                    target_file,
                    subject_id,
                    TR,
                    slice_times,
                    norm_threshold=1,
                    num_components=5,
                    vol_fwhm=None,
                    surf_fwhm=None,
                    lowpass_freq=-1,
                    highpass_freq=-1,
                    subjects_dir=None,
                    sink_directory=os.getcwd(),
                    target_subject=['fsaverage3', 'fsaverage4'],
                    name='resting'):
    """Build the resting-state preprocessing workflow.

    The workflow realigns the functional runs, computes TSNR and a median
    image, registers to the anatomy and the target image, detects artifacts,
    regresses out motion/artifact and CompCor noise components, band-pass
    filters, smooths, warps the time series into the target space and samples
    them into aparc ROIs, subcortical ROIs and onto the ``target_subject``
    surfaces.  Results are written by two DataSink nodes.

    :param files: functional run(s), set as the realigner's ``in_file``.
    :param target_file: target-space image; used as the registration target
        and as the reference image when warping time series.
    :param subject_id: FreeSurfer subject id for the registration workflow.
    :param TR: repetition time given to the realigner; the band-pass filter
        uses ``fs = 1/TR`` (presumably seconds -- TODO confirm).
    :param slice_times: slice acquisition times for the realigner.
    :param norm_threshold: composite-motion threshold for ArtifactDetect.
    :param num_components: number of CompCor noise components to extract.
    :param vol_fwhm: FWHM for volumetric isotropic smoothing.
    :param surf_fwhm: surface smoothing passed to SampleToSurface.
    :param lowpass_freq: low-pass cutoff passed to ``bandpass_filter``.
    :param highpass_freq: high-pass cutoff passed to ``bandpass_filter``.
    :param subjects_dir: FreeSurfer subjects directory.
    :param sink_directory: base directory of both DataSinks.  The default
        ``os.getcwd()`` is evaluated once, when this function is defined,
        not on every call.
    :param target_subject: surface subject(s) to sample onto; iterated over.
    :param name: name of the returned workflow.
    :return: the configured, not yet run, nipype ``Workflow``.
    """

    wf = Workflow(name=name)

    # Combined slice-timing and motion correction.
    realign = Node(nipy.SpaceTimeRealigner(), name="spacetime_realign")
    realign.inputs.slice_times = slice_times
    realign.inputs.tr = TR
    realign.inputs.slice_info = 2
    realign.inputs.in_file = files
    realign.plugin_args = {'sbatch_args': '-c%d' % 4}

    # Compute TSNR on realigned data, regressing polynomials up to order 2
    tsnr = MapNode(TSNR(regress_poly=2), iterfield=['in_file'], name='tsnr')
    wf.connect(realign, "out_file", tsnr, "in_file")

    # Compute the median image across runs
    calc_median = Node(Function(input_names=['in_files'],
                                output_names=['median_file'],
                                function=median,
                                imports=imports),
                       name='median')
    wf.connect(tsnr, 'detrended_file', calc_median, 'in_files')

    # Segment and register
    registration = create_reg_workflow(name='registration')
    wf.connect(calc_median, 'median_file', registration, 'inputspec.mean_image')
    registration.inputs.inputspec.subject_id = subject_id
    registration.inputs.inputspec.subjects_dir = subjects_dir
    registration.inputs.inputspec.target_image = target_file

    # Quantify TSNR in each freesurfer ROI
    get_roi_tsnr = MapNode(fs.SegStats(default_color_table=True),
                           iterfield=['in_file'], name='get_aparc_tsnr')
    get_roi_tsnr.inputs.avgwf_txt_file = True
    wf.connect(tsnr, 'tsnr_file', get_roi_tsnr, 'in_file')
    wf.connect(registration, 'outputspec.aparc', get_roi_tsnr, 'segmentation_file')

    # Use nipype.algorithms.rapidart to determine which of the images in the
    # functional series are outliers based on deviations in intensity or
    # movement.
    art = Node(interface=ArtifactDetect(), name="art")
    art.inputs.use_differences = [True, True]
    art.inputs.use_norm = True
    art.inputs.norm_threshold = norm_threshold
    art.inputs.zintensity_threshold = 9
    art.inputs.mask_type = 'spm_global'
    art.inputs.parameter_source = 'NiPy'
    art.inputs.save_plot = False  # dbg temporary while matplotlib is not available

    # Feed the realigned series and the motion parameters to artifact detection.
    wf.connect([(realign, art, [('out_file', 'realigned_files')]),
                (realign, art, [('par_file', 'realignment_parameters')]),
                ])

    # Connection helper: pick the entries at positions `idx` from `files`.
    # Imports live inside because nipype re-creates connection functions from
    # their source text.
    def selectindex(files, idx):
        import numpy as np
        from nipype.utils.filemanip import filename_to_list, list_to_filename
        return list_to_filename(np.array(filename_to_list(files))[idx].tolist())

    # Native-space brain mask from the median image.
    mask = Node(fsl.BET(), name='getmask')
    mask.inputs.mask = True
    wf.connect(calc_median, 'median_file', mask, 'in_file')

    # Compute motion regressors
    motreg = Node(Function(input_names=['motion_params', 'order',
                                        'derivatives'],
                           output_names=['out_files'],
                           function=motion_regressors,
                           imports=imports),
                  name='getmotionregress')
    wf.connect(realign, 'par_file', motreg, 'motion_params')

    # Create a filter to remove motion and art confounds
    createfilter1 = Node(Function(input_names=['motion_params', 'comp_norm',
                                               'outliers', 'detrend_poly'],
                                  output_names=['out_files'],
                                  function=build_filter1,
                                  imports=imports),
                         name='makemotionbasedfilter')
    createfilter1.inputs.detrend_poly = 2
    wf.connect(motreg, 'out_files', createfilter1, 'motion_params')
    wf.connect(art, 'norm_files', createfilter1, 'comp_norm')
    wf.connect(art, 'outlier_files', createfilter1, 'outliers')

    # Regress the motion/artifact design out of each realigned run.
    filter1 = MapNode(fsl.GLM(out_f_name='F_mcart.nii.gz',
                              out_pf_name='pF_mcart.nii.gz',
                              demean=True),
                      iterfield=['in_file', 'design', 'out_res_name'],
                      name='filtermotion')

    wf.connect(realign, 'out_file', filter1, 'in_file')
    wf.connect(realign, ('out_file', rename, '_filtermotart'),
               filter1, 'out_res_name')
    wf.connect(createfilter1, 'out_files', filter1, 'design')

    # CompCor: extract noise components from the segmentation files at
    # positions 0 and 2, combined with the motion/artifact regressors.
    createfilter2 = MapNode(Function(input_names=['realigned_file', 'mask_file',
                                                  'num_components',
                                                  'extra_regressors'],
                                     output_names=['out_files'],
                                     function=extract_noise_components,
                                     imports=imports),
                            iterfield=['realigned_file', 'extra_regressors'],
                            name='makecompcorrfilter')
    createfilter2.inputs.num_components = num_components

    wf.connect(createfilter1, 'out_files', createfilter2, 'extra_regressors')
    wf.connect(filter1, 'out_res', createfilter2, 'realigned_file')
    wf.connect(registration, ('outputspec.segmentation_files', selectindex, [0, 2]),
               createfilter2, 'mask_file')

    # Regress the CompCor design out of the motion-filtered residuals.
    filter2 = MapNode(fsl.GLM(out_f_name='F.nii.gz',
                              out_pf_name='pF.nii.gz',
                              demean=True),
                      iterfield=['in_file', 'design', 'out_res_name'],
                      name='filter_noise_nosmooth')
    wf.connect(filter1, 'out_res', filter2, 'in_file')
    wf.connect(filter1, ('out_res', rename, '_cleaned'),
               filter2, 'out_res_name')
    wf.connect(createfilter2, 'out_files', filter2, 'design')
    wf.connect(mask, 'mask_file', filter2, 'mask')

    bandpass = Node(Function(input_names=['files', 'lowpass_freq',
                                          'highpass_freq', 'fs'],
                             output_names=['out_files'],
                             function=bandpass_filter,
                             imports=imports),
                    name='bandpass_unsmooth')
    bandpass.inputs.fs = 1./TR
    bandpass.inputs.highpass_freq = highpass_freq
    bandpass.inputs.lowpass_freq = lowpass_freq
    wf.connect(filter2, 'out_res', bandpass, 'files')

    # Smooth the functional data using fsl.IsotropicSmooth.
    smooth = MapNode(interface=fsl.IsotropicSmooth(), name="smooth", iterfield=["in_file"])
    smooth.inputs.fwhm = vol_fwhm

    wf.connect(bandpass, 'out_files', smooth, 'in_file')

    # Smoothed and unsmoothed streams are processed together from here on.
    collector = Node(Merge(2), name='collect_streams')
    wf.connect(smooth, 'out_file', collector, 'in1')
    wf.connect(bandpass, 'out_files', collector, 'in2')

    # Transform the remaining images. First to anatomical and then to target
    warpall = MapNode(ants.ApplyTransforms(), iterfield=['input_image'],
                      name='warpall')
    warpall.inputs.input_image_type = 3
    warpall.inputs.interpolation = 'Linear'
    warpall.inputs.invert_transform_flags = [False, False]
    warpall.inputs.terminal_output = 'file'
    warpall.inputs.reference_image = target_file
    warpall.inputs.args = '--float'
    warpall.inputs.num_threads = 2
    warpall.plugin_args = {'sbatch_args': '-c%d' % 2}

    # transform to target
    wf.connect(collector, 'out', warpall, 'input_image')
    wf.connect(registration, 'outputspec.transforms', warpall, 'transforms')

    mask_target = Node(fsl.ImageMaths(op_string='-bin'), name='target_mask')

    wf.connect(registration, 'outputspec.anat2target', mask_target, 'in_file')

    maskts = MapNode(fsl.ApplyMask(), iterfield=['in_file'], name='ts_masker')
    wf.connect(warpall, 'output_image', maskts, 'in_file')
    wf.connect(mask_target, 'out_file', maskts, 'mask_file')

    # Remaining steps below:
    # map to surface
    # extract aparc+aseg ROIs
    # extract subcortical ROIs
    # extract target space ROIs
    # combine subcortical and cortical rois into a single cifti file

    # Label ids of the subcortical ROIs, shared by the aparc sampler and the
    # subcortical extractor.  range() must be wrapped in list(): on Python 3
    # `list + range` raises TypeError.
    subcortical_ids = ([8] + list(range(10, 14)) + [17, 18, 26, 47] +
                       list(range(49, 55)) + [58])

    # Sample the average time series in aparc ROIs
    sampleaparc = MapNode(freesurfer.SegStats(default_color_table=True),
                          iterfield=['in_file', 'summary_file',
                                     'avgwf_txt_file'],
                          name='aparc_ts')
    sampleaparc.inputs.segment_id = (subcortical_ids +
                                     list(range(1001, 1036)) +
                                     list(range(2001, 2036)))

    wf.connect(registration, 'outputspec.aparc',
               sampleaparc, 'segmentation_file')
    wf.connect(collector, 'out', sampleaparc, 'in_file')

    def get_names(files, suffix):
        """Generate appropriate names for output files
        """
        from nipype.utils.filemanip import split_filename, list_to_filename
        import os
        out_names = []
        for filename in files:
            path, name, _ = split_filename(filename)
            out_names.append(os.path.join(path, name + suffix))
        return list_to_filename(out_names)

    wf.connect(collector, ('out', get_names, '_avgwf.txt'),
               sampleaparc, 'avgwf_txt_file')
    wf.connect(collector, ('out', get_names, '_summary.stats'),
               sampleaparc, 'summary_file')

    # Sample the time series onto the surface of the target surface. Performs
    # sampling into left and right hemisphere
    target = Node(IdentityInterface(fields=['target_subject']), name='target')
    target.iterables = ('target_subject', filename_to_list(target_subject))

    samplerlh = MapNode(freesurfer.SampleToSurface(),
                        iterfield=['source_file'],
                        name='sampler_lh')
    samplerlh.inputs.sampling_method = "average"
    samplerlh.inputs.sampling_range = (0.1, 0.9, 0.1)
    samplerlh.inputs.sampling_units = "frac"
    samplerlh.inputs.interp_method = "trilinear"
    samplerlh.inputs.smooth_surf = surf_fwhm
    #samplerlh.inputs.cortex_mask = True
    samplerlh.inputs.out_type = 'niigz'
    samplerlh.inputs.subjects_dir = subjects_dir

    # Clone before setting the hemisphere so each sampler gets its own value.
    samplerrh = samplerlh.clone('sampler_rh')

    samplerlh.inputs.hemi = 'lh'
    wf.connect(collector, 'out', samplerlh, 'source_file')
    wf.connect(registration, 'outputspec.out_reg_file', samplerlh, 'reg_file')
    wf.connect(target, 'target_subject', samplerlh, 'target_subject')

    samplerrh.set_input('hemi', 'rh')
    wf.connect(collector, 'out', samplerrh, 'source_file')
    wf.connect(registration, 'outputspec.out_reg_file', samplerrh, 'reg_file')
    wf.connect(target, 'target_subject', samplerrh, 'target_subject')

    # Combine left and right hemisphere to text file
    combiner = MapNode(Function(input_names=['left', 'right'],
                                output_names=['out_file'],
                                function=combine_hemi,
                                imports=imports),
                       iterfield=['left', 'right'],
                       name="combiner")
    wf.connect(samplerlh, 'out_file', combiner, 'left')
    wf.connect(samplerrh, 'out_file', combiner, 'right')

    # Sample the time series file for each subcortical roi
    ts2txt = MapNode(Function(input_names=['timeseries_file', 'label_file',
                                           'indices'],
                              output_names=['out_file'],
                              function=extract_subrois,
                              imports=imports),
                     iterfield=['timeseries_file'],
                     name='getsubcortts')
    ts2txt.inputs.indices = subcortical_ids
    ts2txt.inputs.label_file = \
        os.path.abspath(('OASIS-TRT-20_jointfusion_DKT31_CMA_labels_in_MNI152_'
                         '2mm_v2.nii.gz'))
    wf.connect(maskts, 'out_file', ts2txt, 'timeseries_file')

    ######

    # Output-path clean-up: strip MapNode/iterable directory names.
    substitutions = [('_target_subject_', ''),
                     ('_filtermotart_cleaned_bp_trans_masked', ''),
                     ('_filtermotart_cleaned_bp', ''),
                     ]
    substitutions += [("_smooth%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_ts_masker%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_getsubcortts%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_combiner%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_filtermotion%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_filter_noise_nosmooth%d" % i, "") for i in range(11)[::-1]]
    # NOTE(review): the CompCor MapNode is named 'makecompcorrfilter' (double
    # 'r'), so this pattern never matches.  Left as-is because stripping the
    # index could make same-named per-run outputs collide -- confirm intent.
    substitutions += [("_makecompcorfilter%d" % i, "") for i in range(11)[::-1]]
    substitutions += [("_get_aparc_tsnr%d/" % i, "run%d_" % (i + 1)) for i in range(11)[::-1]]

    substitutions += [("T1_out_brain_pve_0_maths_warped", "compcor_csf"),
                      ("T1_out_brain_pve_1_maths_warped", "compcor_gm"),
                      ("T1_out_brain_pve_2_maths_warped", "compcor_wm"),
                      ("output_warped_image_maths", "target_brain_mask"),
                      ("median_brain_mask", "native_brain_mask"),
                      ("corr_", "")]

    regex_subs = [('_combiner.*/sar', '/smooth/'),
                  ('_combiner.*/ar', '/unsmooth/'),
                  ('_aparc_ts.*/sar', '/smooth/'),
                  ('_aparc_ts.*/ar', '/unsmooth/'),
                  ('_getsubcortts.*/sar', '/smooth/'),
                  ('_getsubcortts.*/ar', '/unsmooth/'),
                  ('series/sar', 'series/smooth/'),
                  ('series/ar', 'series/unsmooth/'),
                  ('_inverse_transform./', ''),
                  ]
    # Save the relevant data into an output directory
    datasink = Node(interface=DataSink(), name="datasink")
    datasink.inputs.base_directory = sink_directory
    #datasink.inputs.container = subject_id
    datasink.inputs.substitutions = substitutions
    datasink.inputs.regexp_substitutions = regex_subs  #(r'(/_.*(\d+/))', r'/run\2')
    wf.connect(realign, 'par_file', datasink, 'qa.motion')
    wf.connect(art, 'norm_files', datasink, 'qa.art.@norm')
    wf.connect(art, 'intensity_files', datasink, 'qa.art.@intensity')
    wf.connect(art, 'outlier_files', datasink, 'qa.art.@outlier_files')
    wf.connect(registration, 'outputspec.segmentation_files', datasink, 'mask_files')
    wf.connect(registration, 'outputspec.anat2target', datasink, 'qa.ants')
    wf.connect(mask, 'mask_file', datasink, 'mask_files.@brainmask')
    wf.connect(mask_target, 'out_file', datasink, 'mask_files.target')
    wf.connect(filter1, 'out_f', datasink, 'qa.compmaps.@mc_F')
    wf.connect(filter1, 'out_pf', datasink, 'qa.compmaps.@mc_pF')
    wf.connect(filter2, 'out_f', datasink, 'qa.compmaps')
    wf.connect(filter2, 'out_pf', datasink, 'qa.compmaps.@p')
    wf.connect(registration, 'outputspec.min_cost_file', datasink, 'qa.mincost')
    wf.connect(tsnr, 'tsnr_file', datasink, 'qa.tsnr.@map')
    wf.connect([(get_roi_tsnr, datasink, [('avgwf_txt_file', 'qa.tsnr'),
                                          ('summary_file', 'qa.tsnr.@summary')])])

    wf.connect(bandpass, 'out_files', datasink, 'timeseries.@bandpassed')
    wf.connect(smooth, 'out_file', datasink, 'timeseries.@smoothed')
    wf.connect(createfilter1, 'out_files',
               datasink, 'regress.@regressors')
    wf.connect(createfilter2, 'out_files',
               datasink, 'regress.@compcorr')
    wf.connect(maskts, 'out_file', datasink, 'timeseries.target')
    wf.connect(sampleaparc, 'summary_file',
               datasink, 'parcellations.aparc')
    wf.connect(sampleaparc, 'avgwf_txt_file',
               datasink, 'parcellations.aparc.@avgwf')
    wf.connect(ts2txt, 'out_file',
               datasink, 'parcellations.grayo.@subcortical')

    # Separate sink for the surface-combined time series.
    datasink2 = Node(interface=DataSink(), name="datasink2")
    datasink2.inputs.base_directory = sink_directory
    #datasink2.inputs.container = subject_id
    datasink2.inputs.substitutions = substitutions
    datasink2.inputs.regexp_substitutions = regex_subs  #(r'(/_.*(\d+/))', r'/run\2')
    wf.connect(combiner, 'out_file',
               datasink2, 'parcellations.grayo.@surface')
    return wf
def group_multregress_openfmri(dataset_dir,
                               model_id=None,
                               task_id=None,
                               l1output_dir=None,
                               out_dir=None,
                               no_reversal=False,
                               plugin=None,
                               plugin_args=None,
                               flamemodel='flame1',
                               nonparametric=False,
                               use_spm=False):
    """Build a second-level multiple-regression workflow for an OpenfMRI dataset.

    For every task in ``task_id`` and every contrast returned by
    ``get_sub_vars``, a sub-workflow named
    ``model_<model>_task_<task>_contrast_<contrast[0][0]>`` is added to a meta
    workflow called ``mult_regress``.  Each sub-workflow:

    * grabs the first-level copes (and, unless ``flamemodel == 'ols'``, the
      varcopes) for the subjects in ``subj_list``, iterating over the cope ids
      returned by ``l1_contrasts_num``;
    * merges them along the time dimension;
    * fits FSL FLAMEO with a ``MultipleRegressDesign`` model inside the MNI152
      2mm brain mask;
    * estimates smoothness and cluster-thresholds the z-stats (z > 2.3,
      cluster p < 0.05, 26-connectivity);
    * writes the results with a DataSink under
      ``out_dir/task<task>/<contrast[0][0]>``.

    :param dataset_dir: dataset root handed to ``get_taskname``,
        ``l1_contrasts_num`` and ``get_sub_vars``.
    :param model_id: model number (formatted with ``%03d``).
    :param task_id: iterable of task numbers (each formatted with ``%03d``).
    :param l1output_dir: root of the first-level outputs.  Files are looked up as
        ``model%03d/task%03d/<subject>/{copes,varcopes}/[spm/]mni/{cope,varcope}%02d.nii[.gz]``.
    :param out_dir: base directory for the per-contrast DataSinks.
    :param no_reversal: when False, the z-stats are also multiplied by -1 and
        cluster-thresholded again, so that negative effects are reported too.
    :param plugin: accepted but not used inside this function.
    :param plugin_args: accepted but not used inside this function.
    :param flamemodel: FLAMEO ``run_mode``.  With ``'ols'`` no varcopes are
        merged or passed to FLAMEO.
    :param nonparametric: also add a PALM node (``run_palm``) on the merged
        copes.
    :param use_spm: read the first-level files from the ``spm/`` sub-directory.
    :return: the meta ``Workflow`` (constructed, not run).

    NOTE(review): ``work_dir`` is not a parameter.  It is read from the
    enclosing (module) scope and raises NameError if it is not defined there.
    """

    meta_workflow = Workflow(name='mult_regress')
    # NOTE(review): ``work_dir`` must exist as a global; confirm against the script.
    meta_workflow.base_dir = work_dir
    for task in task_id:
        task_name = get_taskname(dataset_dir, task)
        cope_ids = l1_contrasts_num(model_id, task_name, dataset_dir)
        # regressors_needed and contrasts are indexed in parallel by ``idx`` below.
        regressors_needed, contrasts, groups, subj_list = get_sub_vars(
            dataset_dir, task_name, model_id)
        for idx, contrast in enumerate(contrasts):
            # contrast[0][0] is used as the contrast label (presumably its name).
            wk = Workflow(name='model_%03d_task_%03d_contrast_%s' %
                          (model_id, task, contrast[0][0]))

            # Only model_id and task_id are connected downstream; dataset_dir is
            # set and subj_list is declared, but neither is wired to another node.
            info = Node(util.IdentityInterface(
                fields=['model_id', 'task_id', 'dataset_dir', 'subj_list']),
                        name='infosource')
            info.inputs.model_id = model_id
            info.inputs.task_id = task
            info.inputs.dataset_dir = dataset_dir

            # Template arguments (8 slots): model, task, subject(s), 'var' prefix of
            # the directory, optional 'spm/' sub-directory, 'var' prefix of the
            # file name, cope number, and optional '.gz' suffix.  ``subj_list``
            # is a list, so DataGrabber presumably expands one file per subject.
            dg = Node(DataGrabber(infields=['model_id', 'task_id', 'cope_id'],
                                  outfields=['copes', 'varcopes']),
                      name='grabber')
            dg.inputs.template = os.path.join(
                l1output_dir,
                'model%03d/task%03d/%s/%scopes/%smni/%scope%02d.nii%s')
            if use_spm:
                # SPM copes are looked up without '.gz'; varcopes keep '.gz'.
                dg.inputs.template_args['copes'] = [[
                    'model_id', 'task_id', subj_list, '', 'spm/', '',
                    'cope_id', ''
                ]]
                dg.inputs.template_args['varcopes'] = [[
                    'model_id', 'task_id', subj_list, 'var', 'spm/', 'var',
                    'cope_id', '.gz'
                ]]
            else:
                dg.inputs.template_args['copes'] = [[
                    'model_id', 'task_id', subj_list, '', '', '', 'cope_id',
                    '.gz'
                ]]
                dg.inputs.template_args['varcopes'] = [[
                    'model_id', 'task_id', subj_list, 'var', '', 'var',
                    'cope_id', '.gz'
                ]]
            # One pass of the sub-workflow per first-level cope.
            dg.iterables = ('cope_id', cope_ids)
            # Keep files in subj_list order (presumably the order the design
            # rows assume — verify against get_sub_vars).
            dg.inputs.sort_filelist = False

            wk.connect(info, 'model_id', dg, 'model_id')
            wk.connect(info, 'task_id', dg, 'task_id')

            model = Node(MultipleRegressDesign(), name='l2model')
            model.inputs.groups = groups
            model.inputs.contrasts = contrasts[idx]
            model.inputs.regressors = regressors_needed[idx]

            # Stack the per-subject copes along time for FLAMEO.
            mergecopes = Node(Merge(dimension='t'), name='merge_copes')
            wk.connect(dg, 'copes', mergecopes, 'in_files')

            if flamemodel != 'ols':
                mergevarcopes = Node(Merge(dimension='t'),
                                     name='merge_varcopes')
                wk.connect(dg, 'varcopes', mergevarcopes, 'in_files')

            mask_file = fsl.Info.standard_image(
                'MNI152_T1_2mm_brain_mask.nii.gz')
            flame = Node(FLAMEO(), name='flameo')
            flame.inputs.mask_file = mask_file
            flame.inputs.run_mode = flamemodel
            #flame.inputs.infer_outliers = True

            wk.connect(model, 'design_mat', flame, 'design_file')
            wk.connect(model, 'design_con', flame, 't_con_file')
            wk.connect(mergecopes, 'merged_file', flame, 'cope_file')
            if flamemodel != 'ols':
                wk.connect(mergevarcopes, 'merged_file', flame,
                           'var_cope_file')
            wk.connect(model, 'design_grp', flame, 'cov_split_file')

            if nonparametric:
                palm = Node(Function(input_names=[
                    'cope_file', 'design_file', 'contrast_file', 'group_file',
                    'mask_file', 'cluster_threshold'
                ],
                                     output_names=['palm_outputs'],
                                     function=run_palm),
                            name='palm')
                palm.inputs.cluster_threshold = 3.09
                palm.inputs.mask_file = mask_file
                # Hard-coded scheduler arguments for the PALM job only.
                palm.plugin_args = {
                    'sbatch_args': '-p om_all_nodes -N1 -c2 --mem=10G',
                    'overwrite': True
                }
                wk.connect(model, 'design_mat', palm, 'design_file')
                wk.connect(model, 'design_con', palm, 'contrast_file')
                wk.connect(mergecopes, 'merged_file', palm, 'cope_file')
                wk.connect(model, 'design_grp', palm, 'group_file')

            smoothest = Node(SmoothEstimate(), name='smooth_estimate')
            wk.connect(flame, 'zstats', smoothest, 'zstat_file')
            smoothest.inputs.mask_file = mask_file

            # Cluster-based thresholding using the smoothness estimates.
            cluster = Node(Cluster(), name='cluster')
            wk.connect(smoothest, 'dlh', cluster, 'dlh')
            wk.connect(smoothest, 'volume', cluster, 'volume')
            cluster.inputs.connectivity = 26
            cluster.inputs.threshold = 2.3
            cluster.inputs.pthreshold = 0.05
            cluster.inputs.out_threshold_file = True
            cluster.inputs.out_index_file = True
            cluster.inputs.out_localmax_txt_file = True

            wk.connect(flame, 'zstats', cluster, 'in_file')

            # z -> p conversion.  NOTE(review): its output is not connected to the
            # sinker below.
            ztopval = Node(ImageMaths(op_string='-ztop', suffix='_pval'),
                           name='z2pval')
            wk.connect(flame, 'zstats', ztopval, 'in_file')

            sinker = Node(DataSink(), name='sinker')
            sinker.inputs.base_directory = os.path.join(
                out_dir, 'task%03d' % task, contrast[0][0])
            sinker.inputs.substitutions = [('_cope_id', 'contrast'),
                                           ('_maths_', '_reversed_')]

            wk.connect(flame, 'zstats', sinker, 'stats')
            wk.connect(cluster, 'threshold_file', sinker, 'stats.@thr')
            wk.connect(cluster, 'index_file', sinker, 'stats.@index')
            wk.connect(cluster, 'localmax_txt_file', sinker, 'stats.@localmax')
            if nonparametric:
                wk.connect(palm, 'palm_outputs', sinker, 'stats.palm')

            if not no_reversal:
                # Negate the z-stats to threshold negative effects as well.
                zstats_reverse = Node(BinaryMaths(), name='zstats_reverse')
                zstats_reverse.inputs.operation = 'mul'
                zstats_reverse.inputs.operand_value = -1
                wk.connect(flame, 'zstats', zstats_reverse, 'in_file')

                # Cloned after ``cluster``'s inputs were set, presumably so the
                # same thresholds carry over; connections are re-made below.
                cluster2 = cluster.clone(name='cluster2')
                wk.connect(smoothest, 'dlh', cluster2, 'dlh')
                wk.connect(smoothest, 'volume', cluster2, 'volume')
                wk.connect(zstats_reverse, 'out_file', cluster2, 'in_file')

                # NOTE(review): like ``ztopval``, this output is not sunk.
                ztopval2 = ztopval.clone(name='ztopval2')
                wk.connect(zstats_reverse, 'out_file', ztopval2, 'in_file')

                wk.connect(zstats_reverse, 'out_file', sinker, 'stats.@neg')
                wk.connect(cluster2, 'threshold_file', sinker,
                           'stats.@neg_thr')
                wk.connect(cluster2, 'index_file', sinker, 'stats.@neg_index')
                wk.connect(cluster2, 'localmax_txt_file', sinker,
                           'stats.@neg_localmax')
            meta_workflow.add_nodes([wk])
    return meta_workflow
def create_workflow(files,
                    subject_id,
                    n_vol=0,
                    despike=True,
                    TR=None,
                    slice_times=None,
                    slice_thickness=None,
                    fieldmap_images=None,
                    norm_threshold=1,
                    num_components=6,
                    vol_fwhm=None,
                    surf_fwhm=None,
                    lowpass_freq=-1,
                    highpass_freq=-1,
                    sink_directory=os.getcwd(),
                    FM_TEdiff=2.46,
                    FM_sigma=2,
                    FM_echo_spacing=.7,
                    target_subject=None,
                    name='resting'):
    """Build a surface- and volume-based resting-state preprocessing workflow.

    Stages (per functional run where a MapNode is used):

    1. Drop the first ``n_vol`` volumes, then AFNI despike (always computed).
    2. Run nipy joint slice-timing/motion realignment on either the despiked
       data or the trimmed data, depending on ``despike``.
    3. Compute polynomial-detrended TSNR, then the median image across runs.
    4. BBRegister the median to the FreeSurfer surface, optionally after
       EPIDeWarp/FUGUE fieldmap unwarping.
    5. Build WM/CSF and brain masks from aparc+aseg and move them into the
       functional space.
    6. Run ArtifactDetect, motion/ART regression, then CompCor regression.
    7. Apply surface+volume smoothing, then temporal band-pass filtering.
    8. Extract aparc ROI time series and sample onto each ``target_subject``
       surface.
    9. Warp to MNI space with ANTs and extract subcortical ROI time series.

    Results go to two DataSinks under ``sink_directory/<subject_id>/resting``.

    :param files: functional run files; each is one MapNode iteration.
    :param subject_id: FreeSurfer subject id (looked up in ``$SUBJECTS_DIR``).
    :param n_vol: number of leading volumes removed (``ExtractROI.t_min``).
    :param despike: if True, the despiked data (rather than the trimmed raw
        data) feeds the realignment.
    :param TR: repetition time, passed to the realigner and used to convert
        the filter cut-offs into sigmas (``1 / (2 * TR * freq)``).
    :param slice_times: slice acquisition times for the realigner.
    :param slice_thickness: must be numeric.  It drives mask erosion/dilation
        and the default smoothing kernels.
    :param fieldmap_images: ``(magnitude, phase-difference)`` images.  When
        given, fieldmap unwarping is inserted.  Default: no fieldmaps.
    :param norm_threshold: ArtifactDetect ``norm_threshold``.
    :param num_components: number of CompCor noise components.
    :param vol_fwhm: volume smoothing FWHM (default ``2 * slice_thickness``).
    :param surf_fwhm: surface smoothing FWHM (default ``5 * slice_thickness``).
    :param lowpass_freq: low-pass cut-off (presumably Hz with TR in seconds).
        A negative value disables it.
    :param highpass_freq: high-pass cut-off.  A negative value disables it.
    :param sink_directory: DataSink base directory.  NOTE: the default is
        evaluated once, when this function is defined.
    :param FM_TEdiff: EPIDeWarp ``tediff``.
    :param FM_sigma: EPIDeWarp ``sigma``.
    :param FM_echo_spacing: EPIDeWarp ``esp``.
    :param target_subject: surface subject(s) to sample onto (iterated).
        Default: ``['fsaverage3', 'fsaverage4']``.
    :param name: workflow name.
    :return: the constructed (not yet run) ``Workflow``.
    :raises KeyError: if ``SUBJECTS_DIR`` is not set in the environment.
    """
    # None sentinels instead of shared mutable default lists.
    if fieldmap_images is None:
        fieldmap_images = []
    if target_subject is None:
        target_subject = ['fsaverage3', 'fsaverage4']

    wf = Workflow(name=name)

    # Skip starting volumes
    remove_vol = MapNode(fsl.ExtractROI(t_min=n_vol, t_size=-1),
                         iterfield=['in_file'],
                         name="remove_volumes")
    remove_vol.inputs.in_file = files

    # Run AFNI's despike. This always runs; whether its output feeds the
    # realignment depends on ``despike``.
    despiker = MapNode(afni.Despike(outputtype='NIFTI_GZ'),
                       iterfield=['in_file'],
                       name='despike')
    wf.connect(remove_vol, 'roi_file', despiker, 'in_file')

    # Run Nipy joint slice timing and realignment algorithm
    realign = Node(nipy.SpaceTimeRealigner(), name='realign')
    realign.inputs.tr = TR
    realign.inputs.slice_times = slice_times
    realign.inputs.slice_info = 2

    if despike:
        wf.connect(despiker, 'out_file', realign, 'in_file')
    else:
        wf.connect(remove_vol, 'roi_file', realign, 'in_file')

    # Compute TSNR on realigned data regressing polynomials up to order 2
    tsnr = MapNode(TSNR(regress_poly=2), iterfield=['in_file'], name='tsnr')
    wf.connect(realign, 'out_file', tsnr, 'in_file')

    # Compute the median image across runs
    calc_median = Node(Function(input_names=['in_files'],
                                output_names=['median_file'],
                                function=median,
                                imports=imports),
                       name='median')
    wf.connect(tsnr, 'detrended_file', calc_median, 'in_files')

    # Coregister the median to the surface
    register = Node(freesurfer.BBRegister(),
                    name='bbregister')
    register.inputs.subject_id = subject_id
    register.inputs.init = 'fsl'
    register.inputs.contrast_type = 't2'
    register.inputs.out_fsl_file = True
    register.inputs.epi_mask = True

    # Compute fieldmaps and unwarp using them.  Also pick, once, which node
    # provides the functional-space reference image (ref_*) and which provides
    # the detrended time series (ts_*) consumed by ART and the motion filter.
    if fieldmap_images:
        fieldmap = Node(interface=EPIDeWarp(), name='fieldmap_unwarp')
        fieldmap.inputs.tediff = FM_TEdiff
        fieldmap.inputs.esp = FM_echo_spacing
        fieldmap.inputs.sigma = FM_sigma
        fieldmap.inputs.mag_file = fieldmap_images[0]
        fieldmap.inputs.dph_file = fieldmap_images[1]
        wf.connect(calc_median, 'median_file', fieldmap, 'exf_file')

        dewarper = MapNode(interface=fsl.FUGUE(), iterfield=['in_file'],
                           name='dewarper')
        wf.connect(tsnr, 'detrended_file', dewarper, 'in_file')
        wf.connect(fieldmap, 'exf_mask', dewarper, 'mask_file')
        wf.connect(fieldmap, 'vsm_file', dewarper, 'shift_in_file')
        wf.connect(fieldmap, 'exfdw', register, 'source_file')

        ref_node, ref_field = fieldmap, 'exf_mask'
        ts_node, ts_field = dewarper, 'unwarped_file'
    else:
        wf.connect(calc_median, 'median_file', register, 'source_file')

        ref_node, ref_field = calc_median, 'median_file'
        ts_node, ts_field = tsnr, 'detrended_file'

    # Get the subject's freesurfer source directory
    fssource = Node(FreeSurferSource(),
                    name='fssource')
    fssource.inputs.subject_id = subject_id
    fssource.inputs.subjects_dir = os.environ['SUBJECTS_DIR']

    # Extract wm+csf, brain masks by eroding freesurfer labels and then
    # transform the masks into the space of the median
    wmcsf = Node(freesurfer.Binarize(), name='wmcsfmask')
    # Cloned before wmcsf's inputs are set, so the brain mask starts from a
    # default Binarize and is configured separately below.
    mask = wmcsf.clone('anatmask')
    wmcsftransform = Node(freesurfer.ApplyVolTransform(inverse=True,
                                                       interp='nearest'),
                          name='wmcsftransform')
    wmcsftransform.inputs.subjects_dir = os.environ['SUBJECTS_DIR']
    wmcsf.inputs.wm_ven_csf = True
    wmcsf.inputs.match = [4, 5, 14, 15, 24, 31, 43, 44, 63]
    wmcsf.inputs.binary_file = 'wmcsf.nii.gz'
    wmcsf.inputs.erode = int(np.ceil(slice_thickness))
    wf.connect(fssource, ('aparc_aseg', get_aparc_aseg), wmcsf, 'in_file')
    wf.connect(ref_node, ref_field, wmcsftransform, 'source_file')
    wf.connect(register, 'out_reg_file', wmcsftransform, 'reg_file')
    wf.connect(wmcsf, 'binary_file', wmcsftransform, 'target_file')

    mask.inputs.binary_file = 'mask.nii.gz'
    mask.inputs.dilate = int(np.ceil(slice_thickness)) + 1
    mask.inputs.erode = int(np.ceil(slice_thickness))
    mask.inputs.min = 0.5
    wf.connect(fssource, ('aparc_aseg', get_aparc_aseg), mask, 'in_file')
    masktransform = wmcsftransform.clone("masktransform")
    wf.connect(ref_node, ref_field, masktransform, 'source_file')
    wf.connect(register, 'out_reg_file', masktransform, 'reg_file')
    wf.connect(mask, 'binary_file', masktransform, 'target_file')

    # Compute Art outliers
    art = Node(interface=ArtifactDetect(use_differences=[True, False],
                                        use_norm=True,
                                        norm_threshold=norm_threshold,
                                        zintensity_threshold=3,
                                        parameter_source='NiPy',
                                        bound_by_brainmask=True,
                                        save_plot=False,
                                        mask_type='file'),
               name="art")
    wf.connect(ts_node, ts_field, art, 'realigned_files')
    wf.connect(realign, 'par_file',
               art, 'realignment_parameters')
    wf.connect(masktransform, 'transformed_file', art, 'mask_file')

    # Compute motion regressors
    motreg = Node(Function(input_names=['motion_params', 'order',
                                        'derivatives'],
                           output_names=['out_files'],
                           function=motion_regressors,
                           imports=imports),
                  name='getmotionregress')
    wf.connect(realign, 'par_file', motreg, 'motion_params')

    # Create a filter to remove motion and art confounds
    createfilter1 = Node(Function(input_names=['motion_params', 'comp_norm',
                                               'outliers'],
                                  output_names=['out_files'],
                                  function=build_filter1,
                                  imports=imports),
                         name='makemotionbasedfilter')
    wf.connect(motreg, 'out_files', createfilter1, 'motion_params')
    wf.connect(art, 'norm_files', createfilter1, 'comp_norm')
    wf.connect(art, 'outlier_files', createfilter1, 'outliers')

    # Filter the motion and art confounds
    filter1 = MapNode(fsl.GLM(out_res_name='timeseries.nii.gz',
                              demean=True),
                      iterfield=['in_file', 'design'],
                      name='filtermotion')
    wf.connect(ts_node, ts_field, filter1, 'in_file')
    wf.connect(createfilter1, 'out_files', filter1, 'design')
    wf.connect(masktransform, 'transformed_file', filter1, 'mask')

    # Create a filter to remove noise components based on white matter and CSF
    createfilter2 = MapNode(Function(input_names=['realigned_file', 'mask_file',
                                                  'num_components'],
                                     output_names=['out_files'],
                                     function=extract_noise_components,
                                     imports=imports),
                            iterfield=['realigned_file'],
                            name='makecompcorrfilter')
    createfilter2.inputs.num_components = num_components
    wf.connect(filter1, 'out_res', createfilter2, 'realigned_file')
    wf.connect(masktransform, 'transformed_file', createfilter2, 'mask_file')

    # Filter noise components
    filter2 = MapNode(fsl.GLM(out_res_name='timeseries_cleaned.nii.gz',
                              demean=True),
                      iterfield=['in_file', 'design'],
                      name='filtercompcorr')
    wf.connect(filter1, 'out_res', filter2, 'in_file')
    wf.connect(createfilter2, 'out_files', filter2, 'design')
    wf.connect(masktransform, 'transformed_file', filter2, 'mask')

    # Smoothing using surface and volume smoothing
    smooth = MapNode(freesurfer.Smooth(),
                     iterfield=['in_file'],
                     name='smooth')
    smooth.inputs.proj_frac_avg = (0.1, 0.9, 0.1)
    if surf_fwhm is None:
        surf_fwhm = 5 * slice_thickness
    smooth.inputs.surface_fwhm = surf_fwhm
    if vol_fwhm is None:
        vol_fwhm = 2 * slice_thickness
    smooth.inputs.vol_fwhm = vol_fwhm
    wf.connect(filter2, 'out_res', smooth, 'in_file')
    wf.connect(register, 'out_reg_file', smooth, 'reg_file')

    # Bandpass filter the data; a negative cut-off disables that side (sigma -1).
    bandpass = MapNode(fsl.TemporalFilter(),
                       iterfield=['in_file'],
                       name='bandpassfilter')
    if highpass_freq < 0:
        bandpass.inputs.highpass_sigma = -1
    else:
        bandpass.inputs.highpass_sigma = 1. / (2 * TR * highpass_freq)
    if lowpass_freq < 0:
        bandpass.inputs.lowpass_sigma = -1
    else:
        bandpass.inputs.lowpass_sigma = 1. / (2 * TR * lowpass_freq)
    wf.connect(smooth, 'smoothed_file', bandpass, 'in_file')

    # Convert aparc to subject functional space
    aparctransform = wmcsftransform.clone("aparctransform")
    wf.connect(ref_node, ref_field, aparctransform, 'source_file')
    wf.connect(register, 'out_reg_file', aparctransform, 'reg_file')
    wf.connect(fssource, ('aparc_aseg', get_aparc_aseg),
               aparctransform, 'target_file')

    # Label ids shared by the aparc sampling below and the MNI-space ROI
    # extraction ('getsubcortts').  Built with list(range(...)): on Python 3,
    # ``list + range`` raises TypeError.
    subcortical_ids = ([8] + list(range(10, 14)) + [17, 18, 26, 47] +
                       list(range(49, 55)) + [58])

    # Sample the average time series in aparc ROIs
    sampleaparc = MapNode(freesurfer.SegStats(avgwf_txt_file=True,
                                              default_color_table=True),
                          iterfield=['in_file'],
                          name='aparc_ts')
    sampleaparc.inputs.segment_id = (subcortical_ids +
                                     list(range(1001, 1036)) +
                                     list(range(2001, 2036)))

    wf.connect(aparctransform, 'transformed_file',
               sampleaparc, 'segmentation_file')
    wf.connect(bandpass, 'out_file', sampleaparc, 'in_file')

    # Sample the time series onto the surface of the target surface. Performs
    # sampling into left and right hemisphere
    target = Node(IdentityInterface(fields=['target_subject']), name='target')
    target.iterables = ('target_subject', filename_to_list(target_subject))

    samplerlh = MapNode(freesurfer.SampleToSurface(),
                        iterfield=['source_file'],
                        name='sampler_lh')
    samplerlh.inputs.sampling_method = "average"
    samplerlh.inputs.sampling_range = (0.1, 0.9, 0.1)
    samplerlh.inputs.sampling_units = "frac"
    samplerlh.inputs.interp_method = "trilinear"
    samplerlh.inputs.out_type = 'niigz'
    samplerlh.inputs.subjects_dir = os.environ['SUBJECTS_DIR']

    # Cloned after the shared settings above and before the hemisphere is set.
    samplerrh = samplerlh.clone('sampler_rh')

    samplerlh.inputs.hemi = 'lh'
    wf.connect(bandpass, 'out_file', samplerlh, 'source_file')
    wf.connect(register, 'out_reg_file', samplerlh, 'reg_file')
    wf.connect(target, 'target_subject', samplerlh, 'target_subject')

    samplerrh.set_input('hemi', 'rh')
    wf.connect(bandpass, 'out_file', samplerrh, 'source_file')
    wf.connect(register, 'out_reg_file', samplerrh, 'reg_file')
    wf.connect(target, 'target_subject', samplerrh, 'target_subject')

    # Combine left and right hemisphere to text file
    combiner = MapNode(Function(input_names=['left', 'right'],
                                output_names=['out_file'],
                                function=combine_hemi,
                                imports=imports),
                       iterfield=['left', 'right'],
                       name="combiner")
    wf.connect(samplerlh, 'out_file', combiner, 'left')
    wf.connect(samplerrh, 'out_file', combiner, 'right')

    # Compute registration between the subject's structural and MNI template
    # This is currently set to perform a very quick registration. However, the
    # registration can be made significantly more accurate for cortical
    # structures by increasing the number of iterations (for example
    # [[10000, 111110, 11110]] * 3 + [[100, 50, 30]]).
    # All parameters are set using the example from:
    # https://github.com/stnava/ANTs/blob/master/Scripts/newAntsExample.sh
    reg = Node(ants.Registration(), name='antsRegister')
    reg.inputs.output_transform_prefix = "output_"
    reg.inputs.transforms = ['Translation', 'Rigid', 'Affine', 'SyN']
    reg.inputs.transform_parameters = [(0.1,), (0.1,), (0.1,), (0.2, 3.0, 0.0)]
    reg.inputs.number_of_iterations = [[100, 100, 100]] * 3 + [[100, 20, 10]]
    reg.inputs.dimension = 3
    reg.inputs.write_composite_transform = True
    reg.inputs.collapse_output_transforms = False
    reg.inputs.metric = ['Mattes'] * 3 + [['Mattes', 'CC']]
    reg.inputs.metric_weight = [1] * 3 + [[0.5, 0.5]]
    reg.inputs.radius_or_number_of_bins = [32] * 3 + [[32, 4]]
    reg.inputs.sampling_strategy = ['Regular'] * 3 + [[None, None]]
    reg.inputs.sampling_percentage = [0.3] * 3 + [[None, None]]
    reg.inputs.convergence_threshold = [1.e-8] * 3 + [-0.01]
    reg.inputs.convergence_window_size = [20] * 3 + [5]
    reg.inputs.smoothing_sigmas = [[4, 2, 1]] * 3 + [[1, 0.5, 0]]
    reg.inputs.sigma_units = ['vox'] * 4
    reg.inputs.shrink_factors = [[6, 4, 2]] + [[3, 2, 1]]*2 + [[4, 2, 1]]
    reg.inputs.use_estimate_learning_rate_once = [True] * 4
    reg.inputs.use_histogram_matching = [False] * 3 + [True]
    reg.inputs.output_warped_image = 'output_warped_image.nii.gz'
    # Relative to the current working directory at call time.
    reg.inputs.fixed_image = \
        os.path.abspath('OASIS-30_Atropos_template_in_MNI152_2mm.nii.gz')
    reg.inputs.num_threads = 4
    reg.plugin_args = {'qsub_args': '-l nodes=1:ppn=4'}

    # Convert T1.mgz to nifti for using with ANTS
    convert = Node(freesurfer.MRIConvert(out_type='niigz'), name='convert2nii')
    wf.connect(fssource, 'T1', convert, 'in_file')

    # Mask the T1.mgz file with the brain mask computed earlier
    maskT1 = Node(fsl.BinaryMaths(operation='mul'), name='maskT1')
    wf.connect(mask, 'binary_file', maskT1, 'operand_file')
    wf.connect(convert, 'out_file', maskT1, 'in_file')
    wf.connect(maskT1, 'out_file', reg, 'moving_image')

    # Convert the BBRegister transformation to ANTS ITK format
    convert2itk = MapNode(C3dAffineTool(),
                          iterfield=['transform_file', 'source_file'],
                          name='convert2itk')
    convert2itk.inputs.fsl2ras = True
    convert2itk.inputs.itk_transform = True
    wf.connect(register, 'out_fsl_file', convert2itk, 'transform_file')
    wf.connect(ref_node, ref_field, convert2itk, 'source_file')
    wf.connect(convert, 'out_file', convert2itk, 'reference_file')

    # Concatenate the affine and ants transforms into a list.  ``pickfirst``
    # stays a stand-alone one-line lambda: nipype serializes connection
    # functions from their source line.
    pickfirst = lambda x: x[0]
    merge = MapNode(Merge(2), iterfield=['in2'], name='mergexfm')
    wf.connect(convert2itk, 'itk_transform', merge, 'in2')
    wf.connect(reg, ('composite_transform', pickfirst), merge, 'in1')

    # Apply the combined transform to the time series file
    sample2mni = MapNode(ants.ApplyTransforms(),
                         iterfield=['input_image', 'transforms'],
                         name='sample2mni')
    sample2mni.inputs.input_image_type = 3
    sample2mni.inputs.interpolation = 'BSpline'
    sample2mni.inputs.invert_transform_flags = [False, False]
    sample2mni.inputs.reference_image = \
        os.path.abspath('OASIS-30_Atropos_template_in_MNI152_2mm.nii.gz')
    sample2mni.inputs.terminal_output = 'file'
    wf.connect(bandpass, 'out_file', sample2mni, 'input_image')
    wf.connect(merge, 'out', sample2mni, 'transforms')

    # Sample the time series file for each subcortical roi
    ts2txt = MapNode(Function(input_names=['timeseries_file', 'label_file',
                                           'indices'],
                              output_names=['out_file'],
                              function=extract_subrois,
                              imports=imports),
                     iterfield=['timeseries_file'],
                     name='getsubcortts')
    ts2txt.inputs.indices = list(subcortical_ids)
    ts2txt.inputs.label_file = \
        os.path.abspath(('OASIS-TRT-20_jointfusion_DKT31_CMA_labels_in_MNI152_'
                         '2mm.nii.gz'))
    wf.connect(sample2mni, 'output_image', ts2txt, 'timeseries_file')

    # Save the relevant data into an output directory
    datasink = Node(interface=DataSink(), name="datasink")
    datasink.inputs.base_directory = sink_directory
    datasink.inputs.container = subject_id
    datasink.inputs.substitutions = [('_target_subject_', '')]
    datasink.inputs.regexp_substitutions = (r'(/_.*(\d+/))', r'/run\2')
    wf.connect(despiker, 'out_file', datasink, 'resting.qa.despike')
    wf.connect(realign, 'par_file', datasink, 'resting.qa.motion')
    wf.connect(tsnr, 'tsnr_file', datasink, 'resting.qa.tsnr')
    wf.connect(tsnr, 'mean_file', datasink, 'resting.qa.tsnr.@mean')
    wf.connect(tsnr, 'stddev_file', datasink, 'resting.qa.@tsnr_stddev')
    wf.connect(ref_node, ref_field, datasink, 'resting.reference')
    wf.connect(art, 'norm_files', datasink, 'resting.qa.art.@norm')
    wf.connect(art, 'intensity_files', datasink, 'resting.qa.art.@intensity')
    wf.connect(art, 'outlier_files', datasink, 'resting.qa.art.@outlier_files')
    wf.connect(mask, 'binary_file', datasink, 'resting.mask')
    wf.connect(masktransform, 'transformed_file',
               datasink, 'resting.mask.@transformed_file')
    wf.connect(register, 'out_reg_file', datasink, 'resting.registration.bbreg')
    wf.connect(reg, ('composite_transform', pickfirst),
               datasink, 'resting.registration.ants')
    wf.connect(register, 'min_cost_file',
               datasink, 'resting.qa.bbreg.@mincost')
    wf.connect(smooth, 'smoothed_file', datasink, 'resting.timeseries.fullpass')
    wf.connect(bandpass, 'out_file', datasink, 'resting.timeseries.bandpassed')
    wf.connect(sample2mni, 'output_image', datasink, 'resting.timeseries.mni')
    wf.connect(createfilter1, 'out_files',
               datasink, 'resting.regress.@regressors')
    wf.connect(createfilter2, 'out_files',
               datasink, 'resting.regress.@compcorr')
    wf.connect(sampleaparc, 'summary_file',
               datasink, 'resting.parcellations.aparc')
    wf.connect(sampleaparc, 'avgwf_txt_file',
               datasink, 'resting.parcellations.aparc.@avgwf')
    wf.connect(ts2txt, 'out_file',
               datasink, 'resting.parcellations.grayo.@subcortical')

    # Separate sink for the per-target-subject surface time series.
    datasink2 = Node(interface=DataSink(), name="datasink2")
    datasink2.inputs.base_directory = sink_directory
    datasink2.inputs.container = subject_id
    datasink2.inputs.substitutions = [('_target_subject_', '')]
    datasink2.inputs.regexp_substitutions = (r'(/_.*(\d+/))', r'/run\2')
    wf.connect(combiner, 'out_file',
               datasink2, 'resting.parcellations.grayo.@surface')
    return wf
Exemple #13
0
# Run FreeSurfer recon-all for every subject id in ``sids`` (defined elsewhere
# in this script) and submit the jobs through the LSF plugin.

# DataGrabber template arguments: the single output field ``T1`` is filled
# from the ``subject_id`` input.
info = dict(T1=[['subject_id']])

# Iterate the whole pipeline over the subject ids.
infosource = Node(IdentityInterface(fields=['subject_id']), name='infosource')
infosource.iterables = ('subject_id', sids)

# Create a datasource node to get the T1 file(s) of each subject.
# ``outfields`` receives a real list: on Python 3 ``info.keys()`` is a view
# object rather than a list of field names.
datasource = Node(DataGrabber(infields=['subject_id'], outfields=list(info)),
                  name='datasource')
datasource.inputs.template = '%s/%s'
datasource.inputs.base_directory = os.path.abspath('/home/data/madlab/data/mri/seqtrd/')
# Field-specific template for T1: <subject>/anatomy/T1_*.nii.gz
datasource.inputs.field_template = dict(T1='%s/anatomy/T1_*.nii.gz')
datasource.inputs.template_args = info
datasource.inputs.sort_filelist = True

# recon-all with 2 OpenMP threads; the per-node LSF request asks for 2 slots.
reconall_node = Node(ReconAll(), name='reconall_node')
reconall_node.inputs.openmp = 2
reconall_node.inputs.subjects_dir = os.environ['SUBJECTS_DIR']
reconall_node.inputs.terminal_output = 'allatonce'
reconall_node.plugin_args = {'bsub_args': '-q PQ_madlab -n 2', 'overwrite': True}

wf = Workflow(name='fsrecon')

wf.connect(infosource, 'subject_id', datasource, 'subject_id')
wf.connect(infosource, 'subject_id', reconall_node, 'subject_id')
wf.connect(datasource, 'T1', reconall_node, 'T1_files')

wf.base_dir = os.path.abspath('/scratch/madlab/surfaces/seqtrd')
#wf.config['execution']['job_finished_timeout'] = 65

wf.run(plugin='LSF', plugin_args={'bsub_args': '-q PQ_madlab'})

Exemple #14
0
# Create a datasource node to get the T1 file(s) of each subject.
# ``info`` and ``infosource`` are not defined in this block; they are expected
# from earlier in the script.  ``outfields`` receives a real list: on Python 3
# ``info.keys()`` is a view object rather than a list of field names.
datasource = Node(DataGrabber(infields=['subject_id'], outfields=list(info)),
                  name='datasource')
datasource.inputs.template = '%s/%s'
datasource.inputs.base_directory = os.path.abspath(
    '/home/data/madlab/data/mri/seqtrd/')
# Field-specific template for T1: <subject>/anatomy/T1_*.nii.gz
datasource.inputs.field_template = dict(T1='%s/anatomy/T1_*.nii.gz')
datasource.inputs.template_args = info
datasource.inputs.sort_filelist = True

# recon-all with 2 OpenMP threads; the per-node LSF request asks for 2 slots.
reconall_node = Node(ReconAll(), name='reconall_node')
reconall_node.inputs.openmp = 2
reconall_node.inputs.subjects_dir = os.environ['SUBJECTS_DIR']
reconall_node.inputs.terminal_output = 'allatonce'
reconall_node.plugin_args = {
    'bsub_args': '-q PQ_madlab -n 2',
    'overwrite': True,
}

wf = Workflow(name='fsrecon')

wf.connect(infosource, 'subject_id', datasource, 'subject_id')
wf.connect(infosource, 'subject_id', reconall_node, 'subject_id')
wf.connect(datasource, 'T1', reconall_node, 'T1_files')

wf.base_dir = os.path.abspath('/scratch/madlab/surfaces/seqtrd')
#wf.config['execution']['job_finished_timeout'] = 65

wf.run(plugin='LSF', plugin_args={'bsub_args': '-q PQ_madlab'})
Exemple #15
0
def create_machine_learning_workflow(
    name="CreateEdgeProbabilityMap", resample=True, plugin_args=None
):
    """
    Assemble a workflow that predicts gray- and white-matter edge
    probability maps with a ``PredictEdgeProbability`` node.

    The returned workflow exposes an ``input_spec`` identity node with the
    fields ``rho``, ``phi``, ``theta``, ``posteriors``, ``t1_file``,
    ``acpc_transform``, ``gm_classifier_file`` and ``wm_classifier_file``,
    and an ``output_spec`` identity node with the fields
    ``gm_probability_map`` and ``wm_probability_map``.

    :param name: Name given to the created Workflow.
    :param resample: When true, a ``CollectFeatureFiles`` node (with
        ``inverse_transform`` enabled) receives rho/phi/theta/posteriors,
        uses ``t1_file`` as its reference and ``acpc_transform`` as its
        transform, and its ``feature_files`` are fed to the predictor as
        ``additional_files``. When false, that variant is not implemented:
        a message is printed and ``None`` is returned.
    :param plugin_args: Optional execution-plugin arguments; when truthy
        they are attached to the ``PredictEdgeProbability`` node.
    :return: The assembled Workflow, or ``None`` when ``resample`` is false.
    """
    wf = Workflow(name)

    input_fields = [
        "rho",
        "phi",
        "theta",
        "posteriors",
        "t1_file",
        "acpc_transform",
        "gm_classifier_file",
        "wm_classifier_file",
    ]
    inputnode = Node(IdentityInterface(input_fields), name="input_spec")

    predictor = Node(PredictEdgeProbability(), name="PredictEdgeProbability")
    if plugin_args:
        predictor.plugin_args = plugin_args

    # The T1 image and both tissue classifiers are passed straight through
    # to same-named predictor inputs.
    passthrough = [
        (field, field)
        for field in ("t1_file", "gm_classifier_file", "wm_classifier_file")
    ]
    wf.connect([(inputnode, predictor, passthrough)])

    if not resample:
        print("workflow not yet created")
        # TODO: create workflow that does not resample the input images
        return None

    collector = Node(CollectFeatureFiles(), name="CollectFeatureFiles")
    collector.inputs.inverse_transform = True
    feature_links = [
        ("rho", "rho"),
        ("phi", "phi"),
        ("theta", "theta"),
        ("posteriors", "posterior_files"),
        ("t1_file", "reference_file"),
        ("acpc_transform", "transform_file"),
    ]
    wf.connect([(inputnode, collector, feature_links)])
    wf.connect(
        [(collector, predictor, [("feature_files", "additional_files")])]
    )

    outputnode = Node(
        IdentityInterface(["gm_probability_map", "wm_probability_map"]),
        name="output_spec",
    )
    result_links = [
        ("gm_edge_probability", "gm_probability_map"),
        ("wm_edge_probability", "wm_probability_map"),
    ]
    wf.connect([(predictor, outputnode, result_links)])

    return wf