Exemplo n.º 1
0
def test_MRISPreproc_outputs():
    """Yield one nose-style equality check per expected output-trait metadata item.

    ``out_file`` is listed with an empty metadata mapping, so this generator
    currently yields no checks.
    """
    expected = {'out_file': {}}
    spec = MRISPreproc.output_spec()

    for trait_name, trait_meta in expected.items():
        for meta_name, wanted in trait_meta.items():
            yield assert_equal, getattr(spec.traits()[trait_name], meta_name), wanted
Exemplo n.º 2
0
def test_MRISPreproc_outputs():
    """Check MRISPreproc's output spec metadata (nose yield-style test).

    Each yielded tuple is ``(assert_equal, actual, expected)``; with the
    current empty metadata for ``out_file`` nothing is yielded.
    """
    output_map = {
        'out_file': {},
    }
    outputs = MRISPreproc.output_spec()

    for key in output_map:
        metadata = output_map[key]
        for metakey in metadata:
            trait = outputs.traits()[key]
            yield assert_equal, getattr(trait, metakey), metadata[metakey]
Exemplo n.º 3
0
def test_MRISPreproc_inputs():
    """Yield one nose-style check per expected MRISPreproc input-trait metadata item.

    Each yielded tuple is ``(assert_equal, actual, expected)`` where
    ``actual`` is read from ``MRISPreproc.input_spec().traits()``.
    """
    # Mutually exclusive groups shared by several traits.
    subject_sources = ('subjects', 'fsgd_file', 'subject_file')
    surf_sources = ('surf_measure', 'surf_measure_file', 'surf_area')

    expected = {
        'surf_dir': {'argstr': '--surfdir %s'},
        'vol_measure_file': {'argstr': '--iv %s %s...'},
        'fsgd_file': {'xor': subject_sources, 'argstr': '--fsgd %s'},
        'fwhm': {'xor': ['num_iters'], 'argstr': '--fwhm %f'},
        'fwhm_source': {'xor': ['num_iters_source'], 'argstr': '--fwhm-src %f'},
        'surf_measure': {'xor': surf_sources, 'argstr': '--meas %s'},
        'subject_file': {'xor': subject_sources, 'argstr': '--f %s'},
        'surf_measure_file': {'xor': surf_sources, 'argstr': '--is %s...'},
        'source_format': {'argstr': '--srcfmt %s'},
        'subjects': {'xor': subject_sources, 'argstr': '--s %s...'},
        'ignore_exception': {'nohash': True, 'usedefault': True},
        'hemi': {'mandatory': True, 'argstr': '--hemi %s'},
        'surf_area': {'xor': surf_sources, 'argstr': '--area %s'},
        'args': {'argstr': '%s'},
        'terminal_output': {'mandatory': True, 'nohash': True},
        'num_iters_source': {'xor': ['fwhm_source'], 'argstr': '--niterssrc %d'},
        'smooth_cortex_only': {'argstr': '--smooth-cortex-only'},
        'subjects_dir': {},
        'num_iters': {'xor': ['fwhm'], 'argstr': '--niters %d'},
        'proj_frac': {'argstr': '--projfrac %s'},
        'target': {'mandatory': True, 'argstr': '--target %s'},
        'out_file': {'argstr': '--out %s', 'genfile': True},
        'environ': {'nohash': True, 'usedefault': True},
    }
    spec = MRISPreproc.input_spec()

    for trait_name, trait_meta in expected.items():
        for meta_name, wanted in trait_meta.items():
            actual = getattr(spec.traits()[trait_name], meta_name)
            yield assert_equal, actual, wanted
Exemplo n.º 4
0
def test_MRISPreproc_inputs():
    """Check MRISPreproc's input spec metadata (nose yield-style test).

    Traits are listed alphabetically; for every metadata key a tuple
    ``(assert_equal, actual, expected)`` is yielded.
    """
    subj_group = ('subjects', 'fsgd_file', 'subject_file')
    meas_group = ('surf_measure', 'surf_measure_file', 'surf_area')

    input_map = {
        'args': {'argstr': '%s'},
        'environ': {'nohash': True, 'usedefault': True},
        'fsgd_file': {'argstr': '--fsgd %s', 'xor': subj_group},
        'fwhm': {'argstr': '--fwhm %f', 'xor': ['num_iters']},
        'fwhm_source': {'argstr': '--fwhm-src %f', 'xor': ['num_iters_source']},
        'hemi': {'argstr': '--hemi %s', 'mandatory': True},
        'ignore_exception': {'nohash': True, 'usedefault': True},
        'num_iters': {'argstr': '--niters %d', 'xor': ['fwhm']},
        'num_iters_source': {'argstr': '--niterssrc %d', 'xor': ['fwhm_source']},
        'out_file': {'argstr': '--out %s', 'genfile': True},
        'proj_frac': {'argstr': '--projfrac %s'},
        'smooth_cortex_only': {'argstr': '--smooth-cortex-only'},
        'source_format': {'argstr': '--srcfmt %s'},
        'subject_file': {'argstr': '--f %s', 'xor': subj_group},
        'subjects': {'argstr': '--s %s...', 'xor': subj_group},
        'subjects_dir': {},
        'surf_area': {'argstr': '--area %s', 'xor': meas_group},
        'surf_dir': {'argstr': '--surfdir %s'},
        'surf_measure': {'argstr': '--meas %s', 'xor': meas_group},
        'surf_measure_file': {'argstr': '--is %s...', 'xor': meas_group},
        'target': {'argstr': '--target %s', 'mandatory': True},
        'terminal_output': {'mandatory': True, 'nohash': True},
        'vol_measure_file': {'argstr': '--iv %s %s...'},
    }
    inputs = MRISPreproc.input_spec()

    # Flatten (trait, metadata key, expected value) lazily so traits() is
    # still queried once per check, in the same order as before.
    checks = ((key, metakey, value)
              for key, metadata in input_map.items()
              for metakey, value in metadata.items())
    for key, metakey, value in checks:
        yield assert_equal, getattr(inputs.traits()[key], metakey), value
Exemplo n.º 5
0
def _connect_template_surfaces(wf, surf_list_node, plot_node, surf_dir):
    """Feed template surface paths into a plot node's ``lh_surf``/``rh_surf``.

    For each hemisphere, the ``surf`` output of ``surf_list_node`` is passed
    through ``prependString`` with the prefix ``<surf_dir>/<hemi>.`` and
    connected to ``plot_node``'s ``<hemi>_surf`` input.
    """
    for hemi in ('lh', 'rh'):
        wf.connect(surf_list_node,
                   ('surf', prependString, op.join(surf_dir, hemi + '.')),
                   plot_node, '{}_surf'.format(hemi))


def _make_surf_plot_node(name, function, input_names, iterfield, bg_maps):
    """Create a MapNode wrapping a surface-plotting ``function``.

    The node exposes a single output named ``out_file`` and has its
    ``lh_bg_map``/``rh_bg_map`` inputs preset from ``bg_maps`` (a dict keyed
    by 'lh'/'rh').
    """
    node = MapNode(Function(input_names=input_names,
                            output_names=['out_file'],
                            function=function),
                   iterfield=list(iterfield),
                   name=name)
    node.inputs.lh_bg_map = bg_maps['lh']
    node.inputs.rh_bg_map = bg_maps['rh']
    return node


def genFreesurferSBMglmWF(name='fsSBM',
                          base_dir=None,
                          group_sublist=None,
                          model_dir=None,
                          model_info=None,
                          design_input='fsgd',
                          fs_subjects_dir='/data/analyses/work_in_progress/freesurfer/fsmrishare-flair6.0/',
                          fwhm=None,
                          measure_list=None,
                          target_atlas='fsaverage',
                          target_atlas_surfreg='sphere.reg',
                          correction_method='FDR'):
    """Build a nipype workflow for a FreeSurfer surface-based GLM analysis.

    The workflow iterates over every model (``model_info`` keys), surface
    measure (``measure_list``), smoothing kernel (``fwhm``) and plotting
    surface ('inflated', 'pial'). For each combination it does the following:

    1. Grabs the contrast (``*.mtx``), contrast-sign (``*.mdtx``) and design
       (``*.fsgd`` or ``X.mat``) files from ``<model_dir>/<model_name>/``.
    2. Runs ``MRISPreproc`` per hemisphere twice. One run writes
       ``stacked.<hemi>.<target_atlas>.mgh`` and the other (``--mean``) writes
       ``mean.<hemi>.<target_atlas>.mgh``.
    3. Smooths the mean map with every non-zero ``fwhm`` (``SurfaceSmooth``).
    4. Fits ``GLMFit`` on each hemisphere's stacked map.
    5. Applies FDR correction (fdr=0.05) to both hemispheres in one ``FDR``
       node. It then masks the gamma maps with the corrected significance maps.
    6. Plots the mean maps, the uncorrected gamma/p maps and the corrected
       gamma/p maps on the template surfaces.

    Parameters
    ----------
    name : str
        Workflow name.
    base_dir : str or None
        Workflow base directory. Defaults to the current working directory
        at call time.
    group_sublist : list of str or None
        Subjects passed to ``MRISPreproc``'s ``subjects`` input. Defaults to [].
    model_dir : str or None
        ``DataGrabber`` base directory with one sub-directory per model.
    model_info : dict or None
        Maps each model name to the value paired with its FSGD file.
        Presumably this is the FreeSurfer design type ('dods'/'doss').
        Defaults to {'Model1': 'dods'}.
    design_input : {'fsgd', 'design_mat'}
        Whether the design comes from an FSGD file or an ``X.mat`` matrix.
    fs_subjects_dir : str
        FreeSurfer subjects directory. Template surfaces are read from
        ``<fs_subjects_dir>/<target_atlas>/surf``.
    fwhm : list of float or None
        Smoothing values iterated over for ``GLMFit``. Defaults to [0.0, 10.0].
    measure_list : list of str or None
        Surface measures iterated over. Defaults to ['thickness', 'area'].
    target_atlas : str
        Template subject used for resampling, fitting and plotting.
    target_atlas_surfreg : str
        Currently unused by this function.
    correction_method : {'FDR', 'perm'}
        Multiple-comparison correction. Only 'FDR' is implemented.

    Returns
    -------
    Workflow
        The assembled (not yet run) workflow.

    Raises
    ------
    ValueError
        If ``design_input`` or ``correction_method`` is not a supported value.
    NotImplementedError
        If ``correction_method == 'perm'``.
    """
    # Resolve defaults at call time. List/dict defaults would otherwise be
    # shared between calls, and op.abspath('.') in the signature froze the
    # working directory at import time.
    if base_dir is None:
        base_dir = op.abspath('.')
    if group_sublist is None:
        group_sublist = []
    if model_info is None:
        model_info = {'Model1': 'dods'}
    if fwhm is None:
        fwhm = [0.0, 10.0]
    if measure_list is None:
        measure_list = ['thickness', 'area']

    # Validate options before building anything. Previously, an unknown
    # correction_method failed later with an unbound ``save_res``, and an
    # unknown design_input silently produced a GLM without a design.
    if design_input not in ('fsgd', 'design_mat'):
        raise ValueError("design_input must be 'fsgd' or 'design_mat', "
                         "got {!r}".format(design_input))
    if correction_method not in ('FDR', 'perm'):
        raise ValueError("correction_method must be 'FDR' or 'perm', "
                         "got {!r}".format(correction_method))
    if correction_method == 'perm':
        # TODO: permutation correction (a GLMFitSim MapNode iterating over
        # 'glm_dir'/'permutation' with spaces='2spaces') is not wired yet.
        raise NotImplementedError(
            "correction_method='perm' is not implemented yet")

    hemis = ('lh', 'rh')
    surf_dir = op.join(fs_subjects_dir, target_atlas, 'surf')
    bg_maps = {hemi: op.join(surf_dir, '{}.sulc'.format(hemi)) for hemi in hemis}
    non_zero_fwhm = [val for val in fwhm if val != 0.0]
    # Residuals are only needed for permutation testing (not supported yet),
    # so this is False on every path that reaches here.
    save_res = correction_method == 'perm'

    wf = Workflow(name)
    wf.base_dir = base_dir

    # Node: model List
    modelList = Node(IdentityInterface(fields=['model_name'], mandatory_inputs=True),
                     name='modelList')
    modelList.iterables = ('model_name', list(model_info.keys()))

    # Grab fsgd or design mat and contrast mtx files from model_dir
    fileList_temp_args = {'contrast_files': [['model_name', '*.mtx']],
                          'contrast_sign_files': [['model_name', '*.mdtx']]}
    if design_input == 'fsgd':
        fileList_temp_args['fsgd_file'] = [['model_name', '*.fsgd']]
    else:  # 'design_mat'
        fileList_temp_args['design_mat'] = [['model_name', 'X.mat']]

    fileList = Node(DataGrabber(infields=['model_name'],
                                outfields=list(fileList_temp_args.keys())),
                    name="fileList")
    fileList.inputs.base_directory = model_dir
    fileList.inputs.ignore_exception = False
    fileList.inputs.raise_on_empty = True
    fileList.inputs.sort_filelist = True
    fileList.inputs.template = '%s/%s'
    fileList.inputs.template_args = fileList_temp_args
    wf.connect(modelList, "model_name", fileList, "model_name")

    # Other iterables: surface measure, smoothing kernel, plotting surface.
    measList = Node(IdentityInterface(fields=['measure'],
                                      mandatory_inputs=True),
                    name='measList')
    measList.iterables = ('measure', measure_list)

    smoothList = Node(IdentityInterface(fields=['fwhm'],
                                        mandatory_inputs=True),
                      name='smoothList')
    smoothList.iterables = ('fwhm', fwhm)

    plotSurfList = Node(IdentityInterface(fields=['surf']),
                        name='plotSurfList')
    plotSurfList.iterables = ('surf', ['inflated', 'pial'])

    # MRIS_preproc per hemisphere. The first run produces the concatenated
    # ("stacked") file for glmfit, and the second ('--mean') a mean map.
    preproc = {}
    for hemi in hemis:
        node = MapNode(MRISPreproc(),
                       name='{}SBMpreproc'.format(hemi),
                       iterfield=['args', 'out_file'])
        node.inputs.subjects_dir = fs_subjects_dir
        node.inputs.target = target_atlas
        node.inputs.hemi = hemi
        node.inputs.args = ['', '--mean']
        node.inputs.out_file = ['{}.{}.{}.mgh'.format(out_name, hemi, target_atlas)
                                for out_name in ['stacked', 'mean']]
        node.inputs.subjects = group_sublist
        wf.connect(measList, "measure", node, "surf_measure")
        preproc[hemi] = node

    # Smoothed mean maps, one per non-zero fwhm, from the '--mean' output (index 1).
    smooth_mean = {}
    for hemi in hemis:
        node = MapNode(SurfaceSmooth(),
                       name='{}SmoothMean'.format(hemi),
                       iterfield=['fwhm', 'out_file'])
        node.inputs.subject_id = target_atlas
        node.inputs.hemi = hemi
        node.inputs.subjects_dir = fs_subjects_dir
        node.inputs.fwhm = non_zero_fwhm
        node.inputs.cortex = True
        node.inputs.out_file = ['mean.{}.fwhm{}.{}.mgh'.format(hemi, int(val), target_atlas)
                                for val in non_zero_fwhm]
        wf.connect(preproc[hemi], ('out_file', getElementFromList, 1), node, 'in_file')
        smooth_mean[hemi] = node

    # Pair the FSGD file with its model's design value for GLMFit's 'fsgd' input.
    if design_input == 'fsgd':
        fsgdInput = Node(Function(input_names=['item1', 'item2'],
                                  output_names=['out_tuple'],
                                  function=createTuple2),
                         name='fsgdInput')
        wf.connect(fileList, 'fsgd_file', fsgdInput, 'item1')
        wf.connect(modelList, ('model_name', getValFromDict, model_info),
                   fsgdInput, 'item2')

    # GLM fit on the stacked output (index 0) of each hemisphere.
    glmfit = {}
    for hemi in hemis:
        node = Node(GLMFit(), name='{}SBMglmfit'.format(hemi))
        node.inputs.subjects_dir = fs_subjects_dir
        node.inputs.surf = True
        node.inputs.subject_id = target_atlas
        node.inputs.hemi = hemi
        node.inputs.cortex = True
        node.inputs.save_residual = save_res
        wf.connect(smoothList, 'fwhm', node, 'fwhm')
        wf.connect(preproc[hemi], ('out_file', getElementFromList, 0), node, 'in_file')
        if design_input == 'fsgd':
            wf.connect(fsgdInput, 'out_tuple', node, 'fsgd')
        else:  # 'design_mat'
            wf.connect(fileList, 'design_mat', node, 'design')
        wf.connect(fileList, 'contrast_files', node, 'contrast')
        glmfit[hemi] = node

    # FDR correction over both hemispheres (index 1 = lh, 2 = rh). Only FDR
    # reaches this point (validated above).
    mriFDR = MapNode(FDR(),
                     iterfield=['in_file1', 'in_file2', 'fdr_sign'],
                     name='mriFDR')
    mriFDR.inputs.fdr = 0.05
    mriFDR.inputs.out_thr_file = 'fdr_threshold.txt'
    mriFDR.inputs.out_file1 = 'lh.sig_corr.mgh'
    mriFDR.inputs.out_file2 = 'rh.sig_corr.mgh'
    for idx, hemi in enumerate(hemis, start=1):
        wf.connect(glmfit[hemi], 'sig_file', mriFDR, 'in_file{}'.format(idx))
        wf.connect(glmfit[hemi], 'mask_file', mriFDR, 'in_mask{}'.format(idx))
    wf.connect(fileList, ('contrast_sign_files', getElementsFromTxtList),
               mriFDR, 'fdr_sign')

    # Gamma maps masked by the FDR-corrected significance maps.
    mask_gamma = {}
    for idx, hemi in enumerate(hemis, start=1):
        node = MapNode(MRIsCalc(),
                       iterfield=['in_file1', 'in_file2'],
                       name='{}MaskGamma'.format(hemi))
        node.inputs.action = 'masked'
        node.inputs.out_file = '{}.masked_gamma.mgh'.format(hemi)
        wf.connect(glmfit[hemi], 'gamma_file', node, 'in_file1')
        wf.connect(mriFDR, 'out_file{}'.format(idx), node, 'in_file2')
        mask_gamma[hemi] = node

    ### Plotting ###
    # Mean maps
    plotMeanMaps = _make_surf_plot_node(
        'plotMeanMaps', plot_surf_map,
        ['lh_surf', 'lh_surf_map', 'lh_bg_map',
         'rh_surf', 'rh_surf_map', 'rh_bg_map',
         'out_fname'],
        ['lh_surf_map', 'rh_surf_map', 'out_fname'],
        bg_maps)
    plotMeanMaps.inputs.out_fname = ['mean_fwhm{}.png'.format(s) for s in non_zero_fwhm]
    _connect_template_surfaces(wf, plotSurfList, plotMeanMaps, surf_dir)
    for hemi in hemis:
        wf.connect(smooth_mean[hemi], 'out_file', plotMeanMaps, '{}_surf_map'.format(hemi))

    plot_stat_inputs = ['lh_surf', 'lh_stat_map', 'lh_bg_map',
                        'rh_surf', 'rh_stat_map', 'rh_bg_map',
                        'out_fname', 'cmap', 'upper_lim', 'threshold']
    stat_iterfield = ['lh_stat_map', 'rh_stat_map', 'out_fname']

    # Uncorrected gamma maps
    plotUncorrectedG = _make_surf_plot_node('plotUncorrectedG', plot_surf_stat,
                                            plot_stat_inputs, stat_iterfield, bg_maps)
    plotUncorrectedG.inputs.cmap = 'jet'
    _connect_template_surfaces(wf, plotSurfList, plotUncorrectedG, surf_dir)
    wf.connect(fileList, ('contrast_files', makeFStringElementFromFnameList, '.mtx', '_uncorrected_gamma_map.png', True),
               plotUncorrectedG, 'out_fname')
    for hemi in hemis:
        wf.connect(glmfit[hemi], 'gamma_file', plotUncorrectedG, '{}_stat_map'.format(hemi))

    # Uncorrected p maps
    plotUncorrectedP = _make_surf_plot_node('plotUncorrectedP', plot_surf_stat,
                                            plot_stat_inputs, stat_iterfield, bg_maps)
    plotUncorrectedP.inputs.upper_lim = 10.0
    _connect_template_surfaces(wf, plotSurfList, plotUncorrectedP, surf_dir)
    wf.connect(fileList, ('contrast_files', makeFStringElementFromFnameList, '.mtx', '_uncorrected_p_map.png', True),
               plotUncorrectedP, 'out_fname')
    for hemi in hemis:
        wf.connect(glmfit[hemi], 'sig_file', plotUncorrectedP, '{}_stat_map'.format(hemi))

    # Masked (corrected) gamma maps
    plotCorrectedG = _make_surf_plot_node('plotCorrectedG', plot_surf_stat,
                                          plot_stat_inputs, stat_iterfield, bg_maps)
    plotCorrectedG.inputs.cmap = 'jet'
    _connect_template_surfaces(wf, plotSurfList, plotCorrectedG, surf_dir)
    wf.connect(fileList, ('contrast_files', makeFStringElementFromFnameList, '.mtx', '_masked_gamma_map.png', True),
               plotCorrectedG, 'out_fname')
    for hemi in hemis:
        wf.connect(mask_gamma[hemi], 'out_file', plotCorrectedG, '{}_stat_map'.format(hemi))

    # p maps thresholded at the FDR threshold
    plotCorrectedP = _make_surf_plot_node('plotCorrectedP', plot_surf_stat,
                                          plot_stat_inputs,
                                          ['lh_stat_map', 'rh_stat_map', 'threshold', 'out_fname'],
                                          bg_maps)
    plotCorrectedP.inputs.upper_lim = 10.0
    _connect_template_surfaces(wf, plotSurfList, plotCorrectedP, surf_dir)
    for hemi in hemis:
        wf.connect(glmfit[hemi], 'sig_file', plotCorrectedP, '{}_stat_map'.format(hemi))
    wf.connect(fileList, ('contrast_files', makeFStringElementFromFnameList, '.mtx', '_corrected_p_map.png', True),
               plotCorrectedP, 'out_fname')
    wf.connect(mriFDR, 'out_thr_file', plotCorrectedP, 'threshold')

    # TODO: no DataSink is wired yet. Candidate outputs are the GLM
    # gamma/gamma_var/sig/ftest files, the FDR-corrected sig maps, the masked
    # gamma maps and the 'out_file' of every plot node.
    return wf