예제 #1
0
def create_workflow(files,
                    target_file,
                    subject_id,
                    TR,
                    slice_times,
                    norm_threshold=1,
                    num_components=5,
                    vol_fwhm=None,
                    surf_fwhm=None,
                    lowpass_freq=-1,
                    highpass_freq=-1,
                    subjects_dir=None,
                    sink_directory=None,
                    target_subject=None,
                    name='resting'):
    """Build the resting-state preprocessing workflow.

    The workflow renames the input runs, runs joint slice-timing/motion
    correction (nipy ``SpaceTimeRealigner``), computes TSNR and a median
    image, registers via ``create_reg_workflow``, removes motion/ART and
    aCompCor confounds with FSL GLMs, band-pass filters and smooths the data,
    warps it to ``target_file`` and samples it into FreeSurfer ROIs and onto
    the surfaces of each ``target_subject``. Outputs are written by two
    ``DataSink`` nodes rooted at ``sink_directory`` with container
    ``subject_id``.

    Parameters
    ----------
    files : list of str
        Functional runs; each is renamed to ``rest_%02d`` (1-based run index).
    target_file : str
        Target-space image; used as the registration target and as the
        reference image of the final ANTs warp.
    subject_id : str
        FreeSurfer subject id; also the datasink container name.
    TR : float
        Repetition time, passed to the realigner and used as ``1. / TR`` for
        the band-pass sampling rate (presumably seconds -- confirm).
    slice_times : list of float
        Slice acquisition times passed to the realigner.
    norm_threshold : float
        Composite-norm threshold for ART outlier detection.
    num_components : int
        Number of aCompCor noise components.
    vol_fwhm : float
        FWHM for ``fsl.IsotropicSmooth``.
    surf_fwhm : float
        Value passed to ``SampleToSurface.smooth_surf``.
    lowpass_freq, highpass_freq : float
        Cutoffs forwarded to ``bandpass_filter``; the meaning of the default
        ``-1`` is defined by that function.
    subjects_dir : str, optional
        FreeSurfer subjects directory.
    sink_directory : str, optional
        Base output directory. Defaults to the current working directory at
        the time this function is called.
    target_subject : str or list of str, optional
        FreeSurfer subject(s) to sample onto; each value becomes an iterable
        of the ``target`` node. Defaults to ``['fsaverage3', 'fsaverage4']``.
    name : str
        Name of the returned workflow.

    Returns
    -------
    Workflow
        The fully connected (not yet executed) workflow.
    """
    # Resolve defaults at call time: ``os.getcwd()`` as a default would be
    # frozen when the module is imported, and a list default would be shared
    # between calls.
    if sink_directory is None:
        sink_directory = os.getcwd()
    if target_subject is None:
        target_subject = ['fsaverage3', 'fsaverage4']

    wf = Workflow(name=name)

    # Rename files in case they are named identically
    name_unique = MapNode(Rename(format_string='rest_%(run)02d'),
                          iterfield=['in_file', 'run'],
                          name='rename')
    name_unique.inputs.keep_ext = True
    name_unique.inputs.run = list(range(1, len(files) + 1))
    name_unique.inputs.in_file = files

    realign = Node(nipy.SpaceTimeRealigner(), name="spacetime_realign")
    realign.inputs.slice_times = slice_times
    realign.inputs.tr = TR
    realign.inputs.slice_info = 2
    realign.plugin_args = {'sbatch_args': '-c%d' % 4}

    # Compute TSNR on realigned data regressing polynomials up to order 2
    tsnr = MapNode(TSNR(regress_poly=2), iterfield=['in_file'], name='tsnr')
    wf.connect(realign, "out_file", tsnr, "in_file")

    # Compute the median image across runs
    calc_median = Node(CalculateMedian(), name='median')
    wf.connect(tsnr, 'detrended_file', calc_median, 'in_files')

    # Segment and register the median image
    registration = create_reg_workflow(name='registration')
    wf.connect(calc_median, 'median_file', registration, 'inputspec.mean_image')
    registration.inputs.inputspec.subject_id = subject_id
    registration.inputs.inputspec.subjects_dir = subjects_dir
    registration.inputs.inputspec.target_image = target_file

    # Quantify TSNR in each FreeSurfer ROI
    get_roi_tsnr = MapNode(fs.SegStats(default_color_table=True),
                           iterfield=['in_file'], name='get_aparc_tsnr')
    get_roi_tsnr.inputs.avgwf_txt_file = True
    wf.connect(tsnr, 'tsnr_file', get_roi_tsnr, 'in_file')
    wf.connect(registration, 'outputspec.aparc', get_roi_tsnr, 'segmentation_file')

    # Use nipype.algorithms.rapidart to determine which images in the
    # functional series are outliers based on deviations in intensity or
    # movement.
    art = Node(interface=ArtifactDetect(), name="art")
    art.inputs.use_differences = [True, True]
    art.inputs.use_norm = True
    art.inputs.norm_threshold = norm_threshold
    art.inputs.zintensity_threshold = 9
    art.inputs.mask_type = 'spm_global'
    art.inputs.parameter_source = 'NiPy'

    # Renamed runs -> space-time realignment -> artifact detection
    wf.connect([(name_unique, realign, [('out_file', 'in_file')]),
                (realign, art, [('out_file', 'realigned_files')]),
                (realign, art, [('par_file', 'realignment_parameters')]),
                ])

    def selectindex(files, idx):
        # Connection function: nipype re-executes its source in isolation,
        # so its imports must live inside the body.
        import numpy as np
        from nipype.utils.filemanip import filename_to_list, list_to_filename
        return list_to_filename(np.array(filename_to_list(files))[idx].tolist())

    # Brain mask of the median image
    mask = Node(fsl.BET(), name='getmask')
    mask.inputs.mask = True
    wf.connect(calc_median, 'median_file', mask, 'in_file')

    # Compute motion regressors
    motreg = Node(Function(input_names=['motion_params', 'order',
                                        'derivatives'],
                           output_names=['out_files'],
                           function=motion_regressors,
                           imports=imports),
                  name='getmotionregress')
    wf.connect(realign, 'par_file', motreg, 'motion_params')

    # Create a filter to remove motion and art confounds
    createfilter1 = Node(Function(input_names=['motion_params', 'comp_norm',
                                               'outliers', 'detrend_poly'],
                                  output_names=['out_files'],
                                  function=build_filter1,
                                  imports=imports),
                         name='makemotionbasedfilter')
    createfilter1.inputs.detrend_poly = 2
    wf.connect(motreg, 'out_files', createfilter1, 'motion_params')
    wf.connect(art, 'norm_files', createfilter1, 'comp_norm')
    wf.connect(art, 'outlier_files', createfilter1, 'outliers')

    filter1 = MapNode(fsl.GLM(out_f_name='F_mcart.nii.gz',
                              out_pf_name='pF_mcart.nii.gz',
                              demean=True),
                      iterfield=['in_file', 'design', 'out_res_name'],
                      name='filtermotion')

    wf.connect(realign, 'out_file', filter1, 'in_file')
    wf.connect(realign, ('out_file', rename, '_filtermotart'),
               filter1, 'out_res_name')
    wf.connect(createfilter1, 'out_files', filter1, 'design')

    createfilter2 = MapNode(ACompCor(),
                            iterfield=['realigned_file', 'extra_regressors'],
                            name='makecompcorrfilter')
    createfilter2.inputs.components_file = 'noise_components.txt'
    createfilter2.inputs.num_components = num_components

    wf.connect(createfilter1, 'out_files', createfilter2, 'extra_regressors')
    wf.connect(filter1, 'out_res', createfilter2, 'realigned_file')
    # Segmentation files 0 and 2 are used as the aCompCor noise masks
    wf.connect(registration, ('outputspec.segmentation_files', selectindex, [0, 2]),
               createfilter2, 'mask_file')

    filter2 = MapNode(fsl.GLM(out_f_name='F.nii.gz',
                              out_pf_name='pF.nii.gz',
                              demean=True),
                      iterfield=['in_file', 'design', 'out_res_name'],
                      name='filter_noise_nosmooth')
    wf.connect(filter1, 'out_res', filter2, 'in_file')
    wf.connect(filter1, ('out_res', rename, '_cleaned'),
               filter2, 'out_res_name')
    wf.connect(createfilter2, 'components_file', filter2, 'design')
    wf.connect(mask, 'mask_file', filter2, 'mask')

    bandpass = Node(Function(input_names=['files', 'lowpass_freq',
                                          'highpass_freq', 'fs'],
                             output_names=['out_files'],
                             function=bandpass_filter,
                             imports=imports),
                    name='bandpass_unsmooth')
    bandpass.inputs.fs = 1. / TR
    bandpass.inputs.highpass_freq = highpass_freq
    bandpass.inputs.lowpass_freq = lowpass_freq
    wf.connect(filter2, 'out_res', bandpass, 'files')

    # Smooth the functional data using fsl.IsotropicSmooth
    smooth = MapNode(interface=fsl.IsotropicSmooth(), name="smooth", iterfield=["in_file"])
    smooth.inputs.fwhm = vol_fwhm

    wf.connect(bandpass, 'out_files', smooth, 'in_file')

    # Both the smoothed and the unsmoothed streams are processed further
    collector = Node(Merge(2), name='collect_streams')
    wf.connect(smooth, 'out_file', collector, 'in1')
    wf.connect(bandpass, 'out_files', collector, 'in2')

    # Transform the remaining images, first to anatomical and then to target
    warpall = MapNode(ants.ApplyTransforms(), iterfield=['input_image'],
                      name='warpall')
    warpall.inputs.input_image_type = 3
    warpall.inputs.interpolation = 'Linear'
    warpall.inputs.invert_transform_flags = [False, False]
    warpall.terminal_output = 'file'
    warpall.inputs.reference_image = target_file
    warpall.inputs.args = '--float'
    warpall.inputs.num_threads = 2
    warpall.plugin_args = {'sbatch_args': '-c%d' % 2}

    # transform to target
    wf.connect(collector, 'out', warpall, 'input_image')
    wf.connect(registration, 'outputspec.transforms', warpall, 'transforms')

    mask_target = Node(fsl.ImageMaths(op_string='-bin'), name='target_mask')

    wf.connect(registration, 'outputspec.anat2target', mask_target, 'in_file')

    maskts = MapNode(fsl.ApplyMask(), iterfield=['in_file'], name='ts_masker')
    wf.connect(warpall, 'output_image', maskts, 'in_file')
    wf.connect(mask_target, 'out_file', maskts, 'mask_file')

    # Remaining steps:
    # - map to surface
    # - extract aparc+aseg ROIs
    # - extract subcortical ROIs
    # - extract target space ROIs
    # - combine subcortical and cortical rois into a single cifti file

    # Sample the average time series in aparc ROIs
    sampleaparc = MapNode(freesurfer.SegStats(default_color_table=True),
                          iterfield=['in_file', 'summary_file',
                                     'avgwf_txt_file'],
                          name='aparc_ts')
    sampleaparc.inputs.segment_id = ([8] + list(range(10, 14)) + [17, 18, 26, 47] +
                                     list(range(49, 55)) + [58] + list(range(1001, 1036)) +
                                     list(range(2001, 2036)))

    wf.connect(registration, 'outputspec.aparc',
               sampleaparc, 'segmentation_file')
    wf.connect(collector, 'out', sampleaparc, 'in_file')

    def get_names(files, suffix):
        """Generate appropriate names for output files

        Each input path keeps its directory and base name; the extension is
        replaced by ``suffix``.
        """
        from nipype.utils.filemanip import split_filename, list_to_filename
        import os
        out_names = []
        for filename in files:
            path, name, _ = split_filename(filename)
            out_names.append(os.path.join(path, name + suffix))
        return list_to_filename(out_names)

    wf.connect(collector, ('out', get_names, '_avgwf.txt'),
               sampleaparc, 'avgwf_txt_file')
    wf.connect(collector, ('out', get_names, '_summary.stats'),
               sampleaparc, 'summary_file')

    # Sample the time series onto the surface of the target surface. Performs
    # sampling into left and right hemisphere
    target = Node(IdentityInterface(fields=['target_subject']), name='target')
    target.iterables = ('target_subject', filename_to_list(target_subject))

    samplerlh = MapNode(freesurfer.SampleToSurface(),
                        iterfield=['source_file'],
                        name='sampler_lh')
    samplerlh.inputs.sampling_method = "average"
    samplerlh.inputs.sampling_range = (0.1, 0.9, 0.1)
    samplerlh.inputs.sampling_units = "frac"
    samplerlh.inputs.interp_method = "trilinear"
    samplerlh.inputs.smooth_surf = surf_fwhm
    # samplerlh.inputs.cortex_mask = True
    samplerlh.inputs.out_type = 'niigz'
    samplerlh.inputs.subjects_dir = subjects_dir

    # Clone before setting the hemisphere so each sampler gets its own value
    samplerrh = samplerlh.clone('sampler_rh')

    samplerlh.inputs.hemi = 'lh'
    wf.connect(collector, 'out', samplerlh, 'source_file')
    wf.connect(registration, 'outputspec.out_reg_file', samplerlh, 'reg_file')
    wf.connect(target, 'target_subject', samplerlh, 'target_subject')

    samplerrh.set_input('hemi', 'rh')
    wf.connect(collector, 'out', samplerrh, 'source_file')
    wf.connect(registration, 'outputspec.out_reg_file', samplerrh, 'reg_file')
    wf.connect(target, 'target_subject', samplerrh, 'target_subject')

    # Combine left and right hemisphere to text file
    combiner = MapNode(Function(input_names=['left', 'right'],
                                output_names=['out_file'],
                                function=combine_hemi,
                                imports=imports),
                       iterfield=['left', 'right'],
                       name="combiner")
    wf.connect(samplerlh, 'out_file', combiner, 'left')
    wf.connect(samplerrh, 'out_file', combiner, 'right')

    # Sample the time series file for each subcortical roi
    ts2txt = MapNode(Function(input_names=['timeseries_file', 'label_file',
                                           'indices'],
                              output_names=['out_file'],
                              function=extract_subrois,
                              imports=imports),
                     iterfield=['timeseries_file'],
                     name='getsubcortts')
    ts2txt.inputs.indices = [8] + list(range(10, 14)) + [17, 18, 26, 47] +\
        list(range(49, 55)) + [58]
    # NOTE(review): resolved against the working directory at build time
    ts2txt.inputs.label_file = \
        os.path.abspath(('OASIS-TRT-20_jointfusion_DKT31_CMA_labels_in_MNI152_'
                         '2mm_v2.nii.gz'))
    wf.connect(maskts, 'out_file', ts2txt, 'timeseries_file')

    # DataSink path rewriting
    substitutions = [('_target_subject_', ''),
                     ('_filtermotart_cleaned_bp_trans_masked', ''),
                     ('_filtermotart_cleaned_bp', ''),
                     ]
    # Drop MapNode index directories. Substitutions are applied in order as
    # plain string replacements, so indices run high-to-low to replace e.g.
    # "_smooth10" before "_smooth1". Prefixes must match the node names above
    # (the aCompCor node is "makecompcorrfilter").
    for prefix in ['_smooth', '_ts_masker', '_getsubcortts', '_combiner',
                   '_filtermotion', '_filter_noise_nosmooth',
                   '_makecompcorrfilter']:
        substitutions += [("%s%d" % (prefix, i), "") for i in range(11)[::-1]]
    substitutions += [("_get_aparc_tsnr%d/" % i, "run%d_" % (i + 1))
                      for i in range(11)[::-1]]

    substitutions += [("T1_out_brain_pve_0_maths_warped", "compcor_csf"),
                      ("T1_out_brain_pve_1_maths_warped", "compcor_gm"),
                      ("T1_out_brain_pve_2_maths_warped", "compcor_wm"),
                      ("output_warped_image_maths", "target_brain_mask"),
                      ("median_brain_mask", "native_brain_mask"),
                      ("corr_", "")]

    regex_subs = [('_combiner.*/sar', '/smooth/'),
                  ('_combiner.*/ar', '/unsmooth/'),
                  ('_aparc_ts.*/sar', '/smooth/'),
                  ('_aparc_ts.*/ar', '/unsmooth/'),
                  ('_getsubcortts.*/sar', '/smooth/'),
                  ('_getsubcortts.*/ar', '/unsmooth/'),
                  ('series/sar', 'series/smooth/'),
                  ('series/ar', 'series/unsmooth/'),
                  ('_inverse_transform./', ''),
                  ]
    # Save the relevant data into an output directory
    datasink = Node(interface=DataSink(), name="datasink")
    datasink.inputs.base_directory = sink_directory
    datasink.inputs.container = subject_id
    datasink.inputs.substitutions = substitutions
    datasink.inputs.regexp_substitutions = regex_subs
    wf.connect(realign, 'par_file', datasink, 'resting.qa.motion')
    wf.connect(art, 'norm_files', datasink, 'resting.qa.art.@norm')
    wf.connect(art, 'intensity_files', datasink, 'resting.qa.art.@intensity')
    wf.connect(art, 'outlier_files', datasink, 'resting.qa.art.@outlier_files')
    wf.connect(registration, 'outputspec.segmentation_files', datasink, 'resting.mask_files')
    wf.connect(registration, 'outputspec.anat2target', datasink, 'resting.qa.ants')
    wf.connect(mask, 'mask_file', datasink, 'resting.mask_files.@brainmask')
    wf.connect(mask_target, 'out_file', datasink, 'resting.mask_files.target')
    wf.connect(filter1, 'out_f', datasink, 'resting.qa.compmaps.@mc_F')
    wf.connect(filter1, 'out_pf', datasink, 'resting.qa.compmaps.@mc_pF')
    wf.connect(filter2, 'out_f', datasink, 'resting.qa.compmaps')
    wf.connect(filter2, 'out_pf', datasink, 'resting.qa.compmaps.@p')
    wf.connect(registration, 'outputspec.min_cost_file', datasink, 'resting.qa.mincost')
    wf.connect(tsnr, 'tsnr_file', datasink, 'resting.qa.tsnr.@map')
    wf.connect([(get_roi_tsnr, datasink, [('avgwf_txt_file', 'resting.qa.tsnr'),
                                          ('summary_file', 'resting.qa.tsnr.@summary')])])

    wf.connect(bandpass, 'out_files', datasink, 'resting.timeseries.@bandpassed')
    wf.connect(smooth, 'out_file', datasink, 'resting.timeseries.@smoothed')
    wf.connect(createfilter1, 'out_files',
               datasink, 'resting.regress.@regressors')
    wf.connect(createfilter2, 'components_file',
               datasink, 'resting.regress.@compcorr')
    wf.connect(maskts, 'out_file', datasink, 'resting.timeseries.target')
    wf.connect(sampleaparc, 'summary_file',
               datasink, 'resting.parcellations.aparc')
    wf.connect(sampleaparc, 'avgwf_txt_file',
               datasink, 'resting.parcellations.aparc.@avgwf')
    wf.connect(ts2txt, 'out_file',
               datasink, 'resting.parcellations.grayo.@subcortical')

    # A second sink, with the same settings, receives the surface time series
    datasink2 = Node(interface=DataSink(), name="datasink2")
    datasink2.inputs.base_directory = sink_directory
    datasink2.inputs.container = subject_id
    datasink2.inputs.substitutions = substitutions
    datasink2.inputs.regexp_substitutions = regex_subs
    wf.connect(combiner, 'out_file',
               datasink2, 'resting.parcellations.grayo.@surface')
    return wf
예제 #2
0
def analyze_openfmri_dataset(data_dir,
                             subject=None,
                             model_id=None,
                             task_id=None,
                             output_dir=None,
                             subj_prefix='*',
                             hpcutoff=120.,
                             use_derivatives=True,
                             fwhm=6.0,
                             subjects_dir=None,
                             target=None):
    """Analyzes an open fmri dataset

    Parameters
    ----------

    data_dir : str
        Path to the base data directory

    work_dir : str
        Nipype working directory (defaults to cwd)
    """
    """
    Load nipype workflows
    """

    preproc = create_featreg_preproc(whichvol='first')
    modelfit = create_modelfit_workflow()
    fixed_fx = create_fixed_effects_flow()
    if subjects_dir:
        registration = create_fs_reg_workflow()
    else:
        registration = create_reg_workflow()
    """
    Remove the plotting connection so that plot iterables don't propagate
    to the model stage
    """

    preproc.disconnect(preproc.get_node('plot_motion'), 'out_file',
                       preproc.get_node('outputspec'), 'motion_plots')
    """
    Set up openfmri data specific components
    """

    subjects = sorted([
        path.split(os.path.sep)[-1]
        for path in glob(os.path.join(data_dir, subj_prefix))
    ])

    infosource = pe.Node(
        niu.IdentityInterface(fields=['subject_id', 'model_id', 'task_id']),
        name='infosource')
    if len(subject) == 0:
        infosource.iterables = [('subject_id', subjects),
                                ('model_id', [model_id]), ('task_id', task_id)]
    else:
        infosource.iterables = [
            ('subject_id',
             [subjects[subjects.index(subj)] for subj in subject]),
            ('model_id', [model_id]), ('task_id', task_id)
        ]

    subjinfo = pe.Node(niu.Function(
        input_names=['subject_id', 'base_dir', 'task_id', 'model_id'],
        output_names=['run_id', 'conds', 'TR'],
        function=get_subjectinfo),
                       name='subjectinfo')
    subjinfo.inputs.base_dir = data_dir
    """
    Return data components as anat, bold and behav
    """

    contrast_file = os.path.join(data_dir, 'models', 'model%03d' % model_id,
                                 'task_contrasts.txt')
    has_contrast = os.path.exists(contrast_file)
    if has_contrast:
        datasource = pe.Node(nio.DataGrabber(
            infields=['subject_id', 'run_id', 'task_id', 'model_id'],
            outfields=['anat', 'bold', 'behav', 'contrasts']),
                             name='datasource')
    else:
        datasource = pe.Node(nio.DataGrabber(
            infields=['subject_id', 'run_id', 'task_id', 'model_id'],
            outfields=['anat', 'bold', 'behav']),
                             name='datasource')
    datasource.inputs.base_directory = data_dir
    datasource.inputs.template = '*'

    if has_contrast:
        datasource.inputs.field_template = {
            'anat': '%s/anatomy/T1_001.nii.gz',
            'bold': '%s/BOLD/task%03d_r*/bold.nii.gz',
            'behav': ('%s/model/model%03d/onsets/task%03d_'
                      'run%03d/cond*.txt'),
            'contrasts': ('models/model%03d/'
                          'task_contrasts.txt')
        }
        datasource.inputs.template_args = {
            'anat': [['subject_id']],
            'bold': [['subject_id', 'task_id']],
            'behav': [['subject_id', 'model_id', 'task_id', 'run_id']],
            'contrasts': [['model_id']]
        }
    else:
        datasource.inputs.field_template = {
            'anat': '%s/anatomy/T1_001.nii.gz',
            'bold': '%s/BOLD/task%03d_r*/bold.nii.gz',
            'behav': ('%s/model/model%03d/onsets/task%03d_'
                      'run%03d/cond*.txt')
        }
        datasource.inputs.template_args = {
            'anat': [['subject_id']],
            'bold': [['subject_id', 'task_id']],
            'behav': [['subject_id', 'model_id', 'task_id', 'run_id']]
        }

    datasource.inputs.sort_filelist = True
    """
    Create meta workflow
    """

    wf = pe.Workflow(name='openfmri')
    wf.connect(infosource, 'subject_id', subjinfo, 'subject_id')
    wf.connect(infosource, 'model_id', subjinfo, 'model_id')
    wf.connect(infosource, 'task_id', subjinfo, 'task_id')
    wf.connect(infosource, 'subject_id', datasource, 'subject_id')
    wf.connect(infosource, 'model_id', datasource, 'model_id')
    wf.connect(infosource, 'task_id', datasource, 'task_id')
    wf.connect(subjinfo, 'run_id', datasource, 'run_id')
    wf.connect([
        (datasource, preproc, [('bold', 'inputspec.func')]),
    ])

    def get_highpass(TR, hpcutoff):
        return hpcutoff / (2. * TR)

    gethighpass = pe.Node(niu.Function(input_names=['TR', 'hpcutoff'],
                                       output_names=['highpass'],
                                       function=get_highpass),
                          name='gethighpass')
    wf.connect(subjinfo, 'TR', gethighpass, 'TR')
    wf.connect(gethighpass, 'highpass', preproc, 'inputspec.highpass')
    """
    Setup a basic set of contrasts, a t-test per condition
    """

    def get_contrasts(contrast_file, task_id, conds):
        import numpy as np
        import os
        contrast_def = []
        if os.path.exists(contrast_file):
            with open(contrast_file, 'rt') as fp:
                contrast_def.extend([
                    np.array(row.split()) for row in fp.readlines()
                    if row.strip()
                ])
        contrasts = []
        for row in contrast_def:
            if row[0] != 'task%03d' % task_id:
                continue
            con = [
                row[1], 'T', ['cond%03d' % (i + 1) for i in range(len(conds))],
                row[2:].astype(float).tolist()
            ]
            contrasts.append(con)
        # add auto contrasts for each column
        for i, cond in enumerate(conds):
            con = [cond, 'T', ['cond%03d' % (i + 1)], [1]]
            contrasts.append(con)
        return contrasts

    contrastgen = pe.Node(niu.Function(
        input_names=['contrast_file', 'task_id', 'conds'],
        output_names=['contrasts'],
        function=get_contrasts),
                          name='contrastgen')

    art = pe.MapNode(
        interface=ra.ArtifactDetect(use_differences=[True, False],
                                    use_norm=True,
                                    norm_threshold=1,
                                    zintensity_threshold=3,
                                    parameter_source='FSL',
                                    mask_type='file'),
        iterfield=['realigned_files', 'realignment_parameters', 'mask_file'],
        name="art")

    modelspec = pe.Node(interface=model.SpecifyModel(), name="modelspec")
    modelspec.inputs.input_units = 'secs'

    def check_behav_list(behav, run_id, conds):
        import numpy as np
        num_conds = len(conds)
        if isinstance(behav, (str, bytes)):
            behav = [behav]
        behav_array = np.array(behav).flatten()
        num_elements = behav_array.shape[0]
        return behav_array.reshape(int(num_elements / num_conds),
                                   num_conds).tolist()

    reshape_behav = pe.Node(niu.Function(
        input_names=['behav', 'run_id', 'conds'],
        output_names=['behav'],
        function=check_behav_list),
                            name='reshape_behav')

    wf.connect(subjinfo, 'TR', modelspec, 'time_repetition')
    wf.connect(datasource, 'behav', reshape_behav, 'behav')
    wf.connect(subjinfo, 'run_id', reshape_behav, 'run_id')
    wf.connect(subjinfo, 'conds', reshape_behav, 'conds')
    wf.connect(reshape_behav, 'behav', modelspec, 'event_files')

    wf.connect(subjinfo, 'TR', modelfit, 'inputspec.interscan_interval')
    wf.connect(subjinfo, 'conds', contrastgen, 'conds')
    if has_contrast:
        wf.connect(datasource, 'contrasts', contrastgen, 'contrast_file')
    else:
        contrastgen.inputs.contrast_file = ''
    wf.connect(infosource, 'task_id', contrastgen, 'task_id')
    wf.connect(contrastgen, 'contrasts', modelfit, 'inputspec.contrasts')

    wf.connect([(preproc, art,
                 [('outputspec.motion_parameters', 'realignment_parameters'),
                  ('outputspec.realigned_files', 'realigned_files'),
                  ('outputspec.mask', 'mask_file')]),
                (preproc, modelspec,
                 [('outputspec.highpassed_files', 'functional_runs'),
                  ('outputspec.motion_parameters', 'realignment_parameters')]),
                (art, modelspec, [('outlier_files', 'outlier_files')]),
                (modelspec, modelfit, [('session_info',
                                        'inputspec.session_info')]),
                (preproc, modelfit, [('outputspec.highpassed_files',
                                      'inputspec.functional_data')])])

    # Comute TSNR on realigned data regressing polynomials upto order 2
    tsnr = MapNode(TSNR(regress_poly=2), iterfield=['in_file'], name='tsnr')
    wf.connect(preproc, "outputspec.realigned_files", tsnr, "in_file")

    # Compute the median image across runs
    calc_median = Node(CalculateMedian(), name='median')
    wf.connect(tsnr, 'detrended_file', calc_median, 'in_files')
    """
    Reorder the copes so that now it combines across runs
    """

    def sort_copes(copes, varcopes, contrasts):
        import numpy as np
        if not isinstance(copes, list):
            copes = [copes]
            varcopes = [varcopes]
        num_copes = len(contrasts)
        n_runs = len(copes)
        all_copes = np.array(copes).flatten()
        all_varcopes = np.array(varcopes).flatten()
        outcopes = all_copes.reshape(int(len(all_copes) / num_copes),
                                     num_copes).T.tolist()
        outvarcopes = all_varcopes.reshape(int(len(all_varcopes) / num_copes),
                                           num_copes).T.tolist()
        return outcopes, outvarcopes, n_runs

    cope_sorter = pe.Node(niu.Function(
        input_names=['copes', 'varcopes', 'contrasts'],
        output_names=['copes', 'varcopes', 'n_runs'],
        function=sort_copes),
                          name='cope_sorter')

    pickfirst = lambda x: x[0]

    wf.connect(contrastgen, 'contrasts', cope_sorter, 'contrasts')
    wf.connect([(preproc, fixed_fx, [(('outputspec.mask', pickfirst),
                                      'flameo.mask_file')]),
                (modelfit, cope_sorter, [('outputspec.copes', 'copes')]),
                (modelfit, cope_sorter, [('outputspec.varcopes', 'varcopes')]),
                (cope_sorter, fixed_fx, [('copes', 'inputspec.copes'),
                                         ('varcopes', 'inputspec.varcopes'),
                                         ('n_runs', 'l2model.num_copes')]),
                (modelfit, fixed_fx, [
                    ('outputspec.dof_file', 'inputspec.dof_files'),
                ])])

    wf.connect(calc_median, 'median_file', registration,
               'inputspec.mean_image')
    if subjects_dir:
        wf.connect(infosource, 'subject_id', registration,
                   'inputspec.subject_id')
        registration.inputs.inputspec.subjects_dir = subjects_dir
        registration.inputs.inputspec.target_image = fsl.Info.standard_image(
            'MNI152_T1_2mm_brain.nii.gz')
        if target:
            registration.inputs.inputspec.target_image = target
    else:
        wf.connect(datasource, 'anat', registration,
                   'inputspec.anatomical_image')
        registration.inputs.inputspec.target_image = fsl.Info.standard_image(
            'MNI152_T1_2mm.nii.gz')
        registration.inputs.inputspec.target_image_brain = fsl.Info.standard_image(
            'MNI152_T1_2mm_brain.nii.gz')
        registration.inputs.inputspec.config_file = 'T1_2_MNI152_2mm'

    def merge_files(copes, varcopes, zstats):
        out_files = []
        splits = []
        out_files.extend(copes)
        splits.append(len(copes))
        out_files.extend(varcopes)
        splits.append(len(varcopes))
        out_files.extend(zstats)
        splits.append(len(zstats))
        return out_files, splits

    mergefunc = pe.Node(niu.Function(
        input_names=['copes', 'varcopes', 'zstats'],
        output_names=['out_files', 'splits'],
        function=merge_files),
                        name='merge_files')
    wf.connect([(fixed_fx.get_node('outputspec'), mergefunc, [
        ('copes', 'copes'),
        ('varcopes', 'varcopes'),
        ('zstats', 'zstats'),
    ])])
    wf.connect(mergefunc, 'out_files', registration, 'inputspec.source_files')

    def split_files(in_files, splits):
        copes = in_files[:splits[0]]
        varcopes = in_files[splits[0]:(splits[0] + splits[1])]
        zstats = in_files[(splits[0] + splits[1]):]
        return copes, varcopes, zstats

    splitfunc = pe.Node(niu.Function(
        input_names=['in_files', 'splits'],
        output_names=['copes', 'varcopes', 'zstats'],
        function=split_files),
                        name='split_files')
    wf.connect(mergefunc, 'splits', splitfunc, 'splits')
    wf.connect(registration, 'outputspec.transformed_files', splitfunc,
               'in_files')

    if subjects_dir:
        get_roi_mean = pe.MapNode(fs.SegStats(default_color_table=True),
                                  iterfield=['in_file'],
                                  name='get_aparc_means')
        get_roi_mean.inputs.avgwf_txt_file = True
        wf.connect(fixed_fx.get_node('outputspec'), 'copes', get_roi_mean,
                   'in_file')
        wf.connect(registration, 'outputspec.aparc', get_roi_mean,
                   'segmentation_file')

        get_roi_tsnr = pe.MapNode(fs.SegStats(default_color_table=True),
                                  iterfield=['in_file'],
                                  name='get_aparc_tsnr')
        get_roi_tsnr.inputs.avgwf_txt_file = True
        wf.connect(tsnr, 'tsnr_file', get_roi_tsnr, 'in_file')
        wf.connect(registration, 'outputspec.aparc', get_roi_tsnr,
                   'segmentation_file')
    """
    Connect to a datasink
    """

    def get_subs(subject_id, conds, run_id, model_id, task_id):
        """Build the ordered (pattern, replacement) pairs for the datasink.

        Parameters
        ----------
        subject_id : str
            Subject label; the ``_subject_id_<id>_`` fragment is removed.
        conds : sequence
            Contrasts; only ``len(conds)`` is used, to number the
            per-contrast outputs (``cope01``, ``varcope01``, ...).
        run_id : iterable of int
            Run numbers, paired positionally with the map-node index of the
            run-level nodes and rendered as ``/runNN_``.
        model_id : int
            Model number, rendered as ``model%03d``.
        task_id : int
            Task number, rendered as ``task%03d``.

        Returns
        -------
        list of tuple(str, str)
            Substitution pairs in a fixed order. Keep the order stable: the
            patterns overlap, and they are presumably applied in sequence
            (confirm against the DataSink docs).
        """
        n_conds = len(conds)

        subs = [
            ('_subject_id_%s_' % subject_id, ''),
            ('_model_id_%d' % model_id, 'model%03d' % model_id),
            ('task_id_%d/' % task_id, '/task%03d_' % task_id),
            ('bold_dtype_mcf_mask_smooth_mask_gms_tempfilt_mean_warp', 'mean'),
            ('bold_dtype_mcf_mask_smooth_mask_gms_tempfilt_mean_flirt',
             'affine'),
        ]

        for idx in range(n_conds):
            num = idx + 1
            # Fixed-effects outputs: one _flameo<idx> directory per contrast.
            for stat in ('cope', 'varcope', 'zstat', 'tstat'):
                subs.append(('_flameo%d/%s1.' % (idx, stat),
                             '%s%02d.' % (stat, num)))
            subs.append(('_flameo%d/res4d.' % idx, 'res4d%02d.' % num))
            # Warped outputs are indexed over the concatenated
            # copes/varcopes/zstats list, so each statistic's block is
            # offset by n_conds from the previous one.
            for suffix in ('warp', 'trans'):
                for offset, stat in enumerate(('cope', 'varcope', 'zstat')):
                    node_idx = offset * n_conds + idx
                    subs.append(('_warpall%d/%s1_%s.' % (node_idx, stat, suffix),
                                 '%s%02d.' % (stat, num)))
            subs.append(('__get_aparc_means%d/' % idx, '/cope%02d_' % num))

        # Run-level map nodes are renamed after the actual run number.
        run_nodes = ('get_aparc_tsnr', 'art', 'dilatemask', 'realign',
                     'modelgen')
        for idx, run_num in enumerate(run_id):
            subs.extend(('__%s%d/' % (node, idx), '/run%02d_' % run_num)
                        for node in run_nodes)

        subs += [
            ('/model%03d/task%03d/' % (model_id, task_id), '/'),
            ('/model%03d/task%03d_' % (model_id, task_id), '/'),
            ('_bold_dtype_mcf_bet_thresh_dil', '_mask'),
            ('_output_warped_image', '_anat2target'),
            ('median_flirt_brain_mask', 'median_brain_mask'),
            ('median_bbreg_brain_mask', 'median_brain_mask'),
        ]
        return subs

    subsgen = pe.Node(niu.Function(
        input_names=['subject_id', 'conds', 'run_id', 'model_id', 'task_id'],
        output_names=['substitutions'],
        function=get_subs),
                      name='subsgen')
    wf.connect(subjinfo, 'run_id', subsgen, 'run_id')

    datasink = pe.Node(interface=nio.DataSink(), name="datasink")
    wf.connect(infosource, 'subject_id', datasink, 'container')
    wf.connect(infosource, 'subject_id', subsgen, 'subject_id')
    wf.connect(infosource, 'model_id', subsgen, 'model_id')
    wf.connect(infosource, 'task_id', subsgen, 'task_id')
    wf.connect(contrastgen, 'contrasts', subsgen, 'conds')
    wf.connect(subsgen, 'substitutions', datasink, 'substitutions')
    wf.connect([(fixed_fx.get_node('outputspec'), datasink,
                 [('res4d', 'res4d'), ('copes', 'copes'),
                  ('varcopes', 'varcopes'), ('zstats', 'zstats'),
                  ('tstats', 'tstats')])])
    wf.connect([(modelfit.get_node('modelgen'), datasink, [
        ('design_cov', 'qa.model'),
        ('design_image', 'qa.model.@matrix_image'),
        ('design_file', 'qa.model.@matrix'),
    ])])
    wf.connect([(preproc, datasink,
                 [('outputspec.motion_parameters', 'qa.motion'),
                  ('outputspec.motion_plots', 'qa.motion.plots'),
                  ('outputspec.mask', 'qa.mask')])])
    wf.connect(registration, 'outputspec.mean2anat_mask', datasink,
               'qa.mask.mean2anat')
    wf.connect(art, 'norm_files', datasink, 'qa.art.@norm')
    wf.connect(art, 'intensity_files', datasink, 'qa.art.@intensity')
    wf.connect(art, 'outlier_files', datasink, 'qa.art.@outlier_files')
    wf.connect(registration, 'outputspec.anat2target', datasink,
               'qa.anat2target')
    wf.connect(tsnr, 'tsnr_file', datasink, 'qa.tsnr.@map')
    if subjects_dir:
        wf.connect(registration, 'outputspec.min_cost_file', datasink,
                   'qa.mincost')
        wf.connect([(get_roi_tsnr, datasink, [('avgwf_txt_file', 'qa.tsnr'),
                                              ('summary_file',
                                               'qa.tsnr.@summary')])])
        wf.connect([(get_roi_mean, datasink, [('avgwf_txt_file', 'copes.roi'),
                                              ('summary_file',
                                               'copes.roi.@summary')])])
    wf.connect([(splitfunc, datasink, [
        ('copes', 'copes.mni'),
        ('varcopes', 'varcopes.mni'),
        ('zstats', 'zstats.mni'),
    ])])
    wf.connect(calc_median, 'median_file', datasink, 'mean')
    wf.connect(registration, 'outputspec.transformed_mean', datasink,
               'mean.mni')
    wf.connect(registration, 'outputspec.func2anat_transform', datasink,
               'xfm.mean2anat')
    wf.connect(registration, 'outputspec.anat2target_transform', datasink,
               'xfm.anat2target')
    """
    Set processing parameters
    """

    preproc.inputs.inputspec.fwhm = fwhm
    gethighpass.inputs.hpcutoff = hpcutoff
    modelspec.inputs.high_pass_filter_cutoff = hpcutoff
    modelfit.inputs.inputspec.bases = {'dgamma': {'derivs': use_derivatives}}
    modelfit.inputs.inputspec.model_serial_correlations = True
    modelfit.inputs.inputspec.film_threshold = 1000

    datasink.inputs.base_directory = output_dir
    return wf