Example #1
0
def cale(input_dir, output_dir):
    """Build and threshold a cALE map from every ``*.nii.gz`` in a folder.

    Runs four FSL steps in sequence, all writing into ``output_dir``:

    1. ``Merge`` all input images along time into ``cALE.nii.gz``.
    2. ``MeanImage`` across time, overwriting ``cALE.nii.gz``.
    3. Multiply that mean by the number of inputs (again in place), which
       turns the mean back into a voxel-wise sum of the input maps.
    4. ``Threshold`` the sum with ``direction='below'`` at
       ``floor(n_inputs / 2)``, writing ``cALE_thresh-<value>.nii.gz``.

    Parameters
    ----------
    input_dir : str
        Directory containing the input ``*.nii.gz`` maps.
    output_dir : str
        Directory that receives the summed and thresholded images.

    Returns
    -------
    str
        Path of the thresholded image.

    Raises
    ------
    ValueError
        If ``input_dir`` contains no ``*.nii.gz`` files.
    """
    # Sorted so the 4D merge order is reproducible across filesystems.
    fns = sorted(glob(op.join(input_dir, '*.nii.gz')))
    if not fns:
        raise ValueError(
            'No *.nii.gz files found in {0!r}'.format(input_dir))

    # Local import: MultiImageMaths requires ``operand_files`` and so cannot
    # multiply by a bare scalar, whereas ImageMaths takes a plain fslmaths
    # op string such as '-mul 5'.
    from nipype.interfaces.fsl import ImageMaths

    n_maps = len(fns)
    thresh_value = np.floor(n_maps / 2)
    cale_file = op.join(output_dir, 'cALE.nii.gz')
    thresh_file = op.join(
        output_dir, 'cALE_thresh-{0}.nii.gz'.format(thresh_value))

    # 1. Stack all input maps into one 4D image.
    merger = Merge()
    merger.inputs.in_files = fns
    merger.inputs.dimension = 't'
    merger.inputs.merged_file = cale_file
    merger.run()

    # 2. Mean over time. The input trait is ``dimension`` (singular); the
    #    previous ``dimensions`` is not an input of MeanImage.
    meanimg = MeanImage()
    meanimg.inputs.in_file = cale_file
    meanimg.inputs.dimension = 'T'
    meanimg.inputs.out_file = cale_file
    meanimg.run()

    # 3. mean * N == voxel-wise sum over the N input maps.
    maths = ImageMaths()
    maths.inputs.in_file = cale_file
    maths.inputs.op_string = '-mul {0}'.format(n_maps)
    maths.inputs.out_file = cale_file
    maths.run()

    # 4. Keep only voxels reaching the half-of-inputs threshold.
    thresh = Threshold()
    thresh.inputs.in_file = cale_file
    thresh.inputs.thresh = thresh_value
    thresh.inputs.direction = 'below'
    thresh.inputs.out_file = thresh_file
    thresh.run()

    return thresh_file
def test_Threshold_outputs():
    """Check the metadata declared on Threshold's output traits.

    Uses plain ``assert`` statements instead of nose-style ``yield`` checks:
    pytest removed support for yield tests in 4.0 and ignores them, so the
    generator form never actually asserted anything.

    NOTE(review): a second ``test_Threshold_outputs`` defined later in this
    file replaces this one when the module is imported.
    """
    output_map = dict(out_file=dict(), )
    outputs = Threshold.output_spec()
    trait_map = outputs.traits()

    for key, metadata in output_map.items():
        # ``out_file`` has no metadata to compare, so at least make sure the
        # output itself is declared.
        assert key in trait_map, 'Threshold has no output {0!r}'.format(key)
        for metakey, value in metadata.items():
            actual = getattr(trait_map[key], metakey)
            assert actual == value, (
                '{0}.{1}: expected {2!r}, got {3!r}'.format(
                    key, metakey, value, actual))
def test_Threshold_outputs():
    """Check the metadata declared on Threshold's output traits.

    Uses plain ``assert`` statements instead of nose-style ``yield`` checks:
    pytest removed support for yield tests in 4.0 and ignores them, so the
    generator form never actually asserted anything.

    NOTE(review): this redefines an identically named test earlier in this
    file; only this definition survives module import.
    """
    output_map = dict(out_file=dict(), )
    outputs = Threshold.output_spec()
    trait_map = outputs.traits()

    for key, metadata in output_map.items():
        # ``out_file`` has no metadata to compare, so at least make sure the
        # output itself is declared.
        assert key in trait_map, 'Threshold has no output {0!r}'.format(key)
        for metakey, value in metadata.items():
            actual = getattr(trait_map[key], metakey)
            assert actual == value, (
                '{0}.{1}: expected {2!r}, got {3!r}'.format(
                    key, metakey, value, actual))
def test_Threshold_inputs():
    """Check the metadata declared on Threshold's input traits.

    Each entry of ``input_map`` maps an input name to the trait metadata it
    is expected to carry (argstr, position, mandatory, ...). Plain
    ``assert`` statements replace nose-style ``yield`` checks, which pytest
    removed in 4.0 and silently ignores.

    NOTE(review): another ``test_Threshold_inputs`` is defined later in this
    file and replaces this one when the module is imported.
    """
    input_map = dict(
        args=dict(argstr='%s', ),
        direction=dict(usedefault=True, ),
        environ=dict(
            nohash=True,
            usedefault=True,
        ),
        ignore_exception=dict(
            nohash=True,
            usedefault=True,
        ),
        in_file=dict(
            argstr='%s',
            mandatory=True,
            position=2,
        ),
        internal_datatype=dict(
            argstr='-dt %s',
            position=1,
        ),
        nan2zeros=dict(
            argstr='-nan',
            position=3,
        ),
        out_file=dict(
            argstr='%s',
            genfile=True,
            hash_files=False,
            position=-2,
        ),
        output_datatype=dict(
            argstr='-odt %s',
            position=-1,
        ),
        output_type=dict(),
        terminal_output=dict(
            mandatory=True,
            nohash=True,
        ),
        thresh=dict(
            argstr='%s',
            mandatory=True,
            position=4,
        ),
        use_nonzero_voxels=dict(requires=['use_robust_range'], ),
        use_robust_range=dict(),
    )
    inputs = Threshold.input_spec()
    # Fetch the trait dictionary once instead of on every comparison.
    trait_map = inputs.traits()

    for key, metadata in input_map.items():
        # Report a missing input clearly rather than via a bare KeyError.
        assert key in trait_map, 'Threshold has no input {0!r}'.format(key)
        for metakey, value in metadata.items():
            actual = getattr(trait_map[key], metakey)
            assert actual == value, (
                '{0}.{1}: expected {2!r}, got {3!r}'.format(
                    key, metakey, value, actual))
def test_Threshold_inputs():
    """Check the metadata declared on Threshold's input traits.

    Each entry of ``input_map`` maps an input name to the trait metadata it
    is expected to carry (argstr, position, mandatory, ...). Plain
    ``assert`` statements replace nose-style ``yield`` checks, which pytest
    removed in 4.0 and silently ignores.

    NOTE(review): this redefines an identically named test earlier in this
    file; only this definition survives module import.
    """
    input_map = dict(
        args=dict(argstr='%s', ),
        direction=dict(usedefault=True, ),
        environ=dict(
            nohash=True,
            usedefault=True,
        ),
        ignore_exception=dict(
            nohash=True,
            usedefault=True,
        ),
        in_file=dict(
            argstr='%s',
            mandatory=True,
            position=2,
        ),
        internal_datatype=dict(
            argstr='-dt %s',
            position=1,
        ),
        nan2zeros=dict(
            argstr='-nan',
            position=3,
        ),
        out_file=dict(
            argstr='%s',
            genfile=True,
            hash_files=False,
            position=-2,
        ),
        output_datatype=dict(
            argstr='-odt %s',
            position=-1,
        ),
        output_type=dict(),
        terminal_output=dict(
            mandatory=True,
            nohash=True,
        ),
        thresh=dict(
            argstr='%s',
            mandatory=True,
            position=4,
        ),
        use_nonzero_voxels=dict(requires=['use_robust_range'], ),
        use_robust_range=dict(),
    )
    inputs = Threshold.input_spec()
    # Fetch the trait dictionary once instead of on every comparison.
    trait_map = inputs.traits()

    for key, metadata in input_map.items():
        # Report a missing input clearly rather than via a bare KeyError.
        assert key in trait_map, 'Threshold has no input {0!r}'.format(key)
        for metakey, value in metadata.items():
            actual = getattr(trait_map[key], metakey)
            assert actual == value, (
                '{0}.{1}: expected {2!r}, got {3!r}'.format(
                    key, metakey, value, actual))
Example #6
0
def Lesion_extractor(
    name='Lesion_Extractor',
    wf_name='Test',
    base_dir='/homes_unix/alaurent/',
    input_dir=None,
    subjects=None,
    main=None,
    acc=None,
    atlas='/homes_unix/alaurent/cbstools-public-master/atlases/brain-segmentation-prior3.0/brain-atlas-quant-3.0.8.txt'
):

    wf = Workflow(wf_name)
    wf.base_dir = base_dir

    #file = open(subjects,"r")
    #subjects = file.read().split("\n")
    #file.close()

    # Subject List
    subjectList = Node(IdentityInterface(fields=['subject_id'],
                                         mandatory_inputs=True),
                       name="subList")
    subjectList.iterables = ('subject_id', [
        sub for sub in subjects if sub != '' and sub != '\n'
    ])

    # T1w and FLAIR
    scanList = Node(DataGrabber(infields=['subject_id'],
                                outfields=['T1', 'FLAIR']),
                    name="scanList")
    scanList.inputs.base_directory = input_dir
    scanList.inputs.ignore_exception = False
    scanList.inputs.raise_on_empty = True
    scanList.inputs.sort_filelist = True
    #scanList.inputs.template = '%s/%s.nii'
    #scanList.inputs.template_args = {'T1': [['subject_id','T1*']],
    #                                 'FLAIR': [['subject_id','FLAIR*']]}
    scanList.inputs.template = '%s/anat/%s'
    scanList.inputs.template_args = {
        'T1': [['subject_id', '*_T1w.nii.gz']],
        'FLAIR': [['subject_id', '*_FLAIR.nii.gz']]
    }
    wf.connect(subjectList, "subject_id", scanList, "subject_id")

    #     # T1w and FLAIR
    #     dg = Node(DataGrabber(outfields=['T1', 'FLAIR']), name="T1wFLAIR")
    #     dg.inputs.base_directory = "/homes_unix/alaurent/LesionPipeline"
    #     dg.inputs.template = "%s/NIFTI/*.nii.gz"
    #     dg.inputs.template_args['T1']=[['7']]
    #     dg.inputs.template_args['FLAIR']=[['9']]
    #     dg.inputs.sort_filelist=True

    # Reorient Volume
    T1Conv = Node(Reorient2Std(), name="ReorientVolume")
    T1Conv.inputs.ignore_exception = False
    T1Conv.inputs.terminal_output = 'none'
    T1Conv.inputs.out_file = "T1_reoriented.nii.gz"
    wf.connect(scanList, "T1", T1Conv, "in_file")

    # Reorient Volume (2)
    T2flairConv = Node(Reorient2Std(), name="ReorientVolume2")
    T2flairConv.inputs.ignore_exception = False
    T2flairConv.inputs.terminal_output = 'none'
    T2flairConv.inputs.out_file = "FLAIR_reoriented.nii.gz"
    wf.connect(scanList, "FLAIR", T2flairConv, "in_file")

    # N3 Correction
    T1NUC = Node(N4BiasFieldCorrection(), name="N3Correction")
    T1NUC.inputs.dimension = 3
    T1NUC.inputs.environ = {'NSLOTS': '1'}
    T1NUC.inputs.ignore_exception = False
    T1NUC.inputs.num_threads = 1
    T1NUC.inputs.save_bias = False
    T1NUC.inputs.terminal_output = 'none'
    wf.connect(T1Conv, "out_file", T1NUC, "input_image")

    # N3 Correction (2)
    T2flairNUC = Node(N4BiasFieldCorrection(), name="N3Correction2")
    T2flairNUC.inputs.dimension = 3
    T2flairNUC.inputs.environ = {'NSLOTS': '1'}
    T2flairNUC.inputs.ignore_exception = False
    T2flairNUC.inputs.num_threads = 1
    T2flairNUC.inputs.save_bias = False
    T2flairNUC.inputs.terminal_output = 'none'
    wf.connect(T2flairConv, "out_file", T2flairNUC, "input_image")
    '''
    #####################
    ### PRE-NORMALIZE ###
    #####################
    To make sure there's no outlier values (negative, or really high) to offset the initialization steps
    '''

    # Intensity Range Normalization
    getMaxT1NUC = Node(ImageStats(op_string='-r'), name="getMaxT1NUC")
    wf.connect(T1NUC, 'output_image', getMaxT1NUC, 'in_file')

    T1NUCirn = Node(AbcImageMaths(), name="IntensityNormalization")
    T1NUCirn.inputs.op_string = "-div"
    T1NUCirn.inputs.out_file = "normT1.nii.gz"
    wf.connect(T1NUC, 'output_image', T1NUCirn, 'in_file')
    wf.connect(getMaxT1NUC, ('out_stat', getElementFromList, 1), T1NUCirn,
               "op_value")

    # Intensity Range Normalization (2)
    getMaxT2NUC = Node(ImageStats(op_string='-r'), name="getMaxT2")
    wf.connect(T2flairNUC, 'output_image', getMaxT2NUC, 'in_file')

    T2NUCirn = Node(AbcImageMaths(), name="IntensityNormalization2")
    T2NUCirn.inputs.op_string = "-div"
    T2NUCirn.inputs.out_file = "normT2.nii.gz"
    wf.connect(T2flairNUC, 'output_image', T2NUCirn, 'in_file')
    wf.connect(getMaxT2NUC, ('out_stat', getElementFromList, 1), T2NUCirn,
               "op_value")
    '''
    ########################
    #### COREGISTRATION ####
    ########################
    '''

    # Optimized Automated Registration
    T2flairCoreg = Node(FLIRT(), name="OptimizedAutomatedRegistration")
    T2flairCoreg.inputs.output_type = 'NIFTI_GZ'
    wf.connect(T2NUCirn, "out_file", T2flairCoreg, "in_file")
    wf.connect(T1NUCirn, "out_file", T2flairCoreg, "reference")
    '''    
    #########################
    #### SKULL-STRIPPING ####
    #########################
    '''

    # SPECTRE
    T1ss = Node(BET(), name="SPECTRE")
    T1ss.inputs.frac = 0.45  #0.4
    T1ss.inputs.mask = True
    T1ss.inputs.outline = True
    T1ss.inputs.robust = True
    wf.connect(T1NUCirn, "out_file", T1ss, "in_file")

    # Image Calculator
    T2ss = Node(ApplyMask(), name="ImageCalculator")
    wf.connect(T1ss, "mask_file", T2ss, "mask_file")
    wf.connect(T2flairCoreg, "out_file", T2ss, "in_file")
    '''
    ####################################
    #### 2nd LAYER OF N3 CORRECTION ####
    ####################################
    This time without the skull: there were some significant amounts of inhomogeneities leftover.
    '''

    # N3 Correction (3)
    T1ssNUC = Node(N4BiasFieldCorrection(), name="N3Correction3")
    T1ssNUC.inputs.dimension = 3
    T1ssNUC.inputs.environ = {'NSLOTS': '1'}
    T1ssNUC.inputs.ignore_exception = False
    T1ssNUC.inputs.num_threads = 1
    T1ssNUC.inputs.save_bias = False
    T1ssNUC.inputs.terminal_output = 'none'
    wf.connect(T1ss, "out_file", T1ssNUC, "input_image")

    # N3 Correction (4)
    T2ssNUC = Node(N4BiasFieldCorrection(), name="N3Correction4")
    T2ssNUC.inputs.dimension = 3
    T2ssNUC.inputs.environ = {'NSLOTS': '1'}
    T2ssNUC.inputs.ignore_exception = False
    T2ssNUC.inputs.num_threads = 1
    T2ssNUC.inputs.save_bias = False
    T2ssNUC.inputs.terminal_output = 'none'
    wf.connect(T2ss, "out_file", T2ssNUC, "input_image")
    '''
    ####################################
    ####    NORMALIZE FOR MGDM      ####
    ####################################
    This normalization is a bit aggressive: only useful to have a 
    cropped dynamic range into MGDM, but possibly harmful to further 
    processing, so the unprocessed images are passed to the subsequent steps.
    '''

    # Intensity Range Normalization
    getMaxT1ssNUC = Node(ImageStats(op_string='-r'), name="getMaxT1ssNUC")
    wf.connect(T1ssNUC, 'output_image', getMaxT1ssNUC, 'in_file')

    T1ssNUCirn = Node(AbcImageMaths(), name="IntensityNormalization3")
    T1ssNUCirn.inputs.op_string = "-div"
    T1ssNUCirn.inputs.out_file = "normT1ss.nii.gz"
    wf.connect(T1ssNUC, 'output_image', T1ssNUCirn, 'in_file')
    wf.connect(getMaxT1ssNUC, ('out_stat', getElementFromList, 1), T1ssNUCirn,
               "op_value")

    # Intensity Range Normalization (2)
    getMaxT2ssNUC = Node(ImageStats(op_string='-r'), name="getMaxT2ssNUC")
    wf.connect(T2ssNUC, 'output_image', getMaxT2ssNUC, 'in_file')

    T2ssNUCirn = Node(AbcImageMaths(), name="IntensityNormalization4")
    T2ssNUCirn.inputs.op_string = "-div"
    T2ssNUCirn.inputs.out_file = "normT2ss.nii.gz"
    wf.connect(T2ssNUC, 'output_image', T2ssNUCirn, 'in_file')
    wf.connect(getMaxT2ssNUC, ('out_stat', getElementFromList, 1), T2ssNUCirn,
               "op_value")
    '''
    ####################################
    ####      ESTIMATE CSF PV       ####
    ####################################
    Here we try to get a better handle on CSF voxels to help the segmentation step
    '''

    # Recursive Ridge Diffusion
    CSF_pv = Node(RecursiveRidgeDiffusion(), name='estimate_CSF_pv')
    CSF_pv.plugin_args = {'sbatch_args': '--mem 6000'}
    CSF_pv.inputs.ridge_intensities = "dark"
    CSF_pv.inputs.ridge_filter = "2D"
    CSF_pv.inputs.orientation = "undefined"
    CSF_pv.inputs.ang_factor = 1.0
    CSF_pv.inputs.min_scale = 0
    CSF_pv.inputs.max_scale = 3
    CSF_pv.inputs.propagation_model = "diffusion"
    CSF_pv.inputs.diffusion_factor = 0.5
    CSF_pv.inputs.similarity_scale = 0.1
    CSF_pv.inputs.neighborhood_size = 4
    CSF_pv.inputs.max_iter = 100
    CSF_pv.inputs.max_diff = 0.001
    CSF_pv.inputs.save_data = True
    wf.connect(
        subjectList,
        ('subject_id', createOutputDir, wf.base_dir, wf.name, CSF_pv.name),
        CSF_pv, 'output_dir')
    wf.connect(T1ssNUCirn, 'out_file', CSF_pv, 'input_image')
    '''
    ####################################
    ####            MGDM            ####
    ####################################
    '''

    # Multi-contrast Brain Segmentation
    MGDM = Node(MGDMSegmentation(), name='MGDM')
    MGDM.plugin_args = {'sbatch_args': '--mem 7000'}
    MGDM.inputs.contrast_type1 = "Mprage3T"
    MGDM.inputs.contrast_type2 = "FLAIR3T"
    MGDM.inputs.contrast_type3 = "PVDURA"
    MGDM.inputs.save_data = True
    MGDM.inputs.atlas_file = atlas
    wf.connect(
        subjectList,
        ('subject_id', createOutputDir, wf.base_dir, wf.name, MGDM.name), MGDM,
        'output_dir')
    wf.connect(T1ssNUCirn, 'out_file', MGDM, 'contrast_image1')
    wf.connect(T2ssNUCirn, 'out_file', MGDM, 'contrast_image2')
    wf.connect(CSF_pv, 'ridge_pv', MGDM, 'contrast_image3')

    # Enhance Region Contrast
    ERC = Node(EnhanceRegionContrast(), name='ERC')
    ERC.plugin_args = {'sbatch_args': '--mem 7000'}
    ERC.inputs.enhanced_region = "crwm"
    ERC.inputs.contrast_background = "crgm"
    ERC.inputs.partial_voluming_distance = 2.0
    ERC.inputs.save_data = True
    ERC.inputs.atlas_file = atlas
    wf.connect(subjectList,
               ('subject_id', createOutputDir, wf.base_dir, wf.name, ERC.name),
               ERC, 'output_dir')
    wf.connect(T1ssNUC, 'output_image', ERC, 'intensity_image')
    wf.connect(MGDM, 'segmentation', ERC, 'segmentation_image')
    wf.connect(MGDM, 'distance', ERC, 'levelset_boundary_image')

    # Enhance Region Contrast (2)
    ERC2 = Node(EnhanceRegionContrast(), name='ERC2')
    ERC2.plugin_args = {'sbatch_args': '--mem 7000'}
    ERC2.inputs.enhanced_region = "crwm"
    ERC2.inputs.contrast_background = "crgm"
    ERC2.inputs.partial_voluming_distance = 2.0
    ERC2.inputs.save_data = True
    ERC2.inputs.atlas_file = atlas
    wf.connect(
        subjectList,
        ('subject_id', createOutputDir, wf.base_dir, wf.name, ERC2.name), ERC2,
        'output_dir')
    wf.connect(T2ssNUC, 'output_image', ERC2, 'intensity_image')
    wf.connect(MGDM, 'segmentation', ERC2, 'segmentation_image')
    wf.connect(MGDM, 'distance', ERC2, 'levelset_boundary_image')

    # Define Multi-Region Priors
    DMRP = Node(DefineMultiRegionPriors(), name='DefineMultRegPriors')
    DMRP.plugin_args = {'sbatch_args': '--mem 6000'}
    #DMRP.inputs.defined_region = "ventricle-horns"
    #DMRP.inputs.definition_method = "closest-distance"
    DMRP.inputs.distance_offset = 3.0
    DMRP.inputs.save_data = True
    DMRP.inputs.atlas_file = atlas
    wf.connect(
        subjectList,
        ('subject_id', createOutputDir, wf.base_dir, wf.name, DMRP.name), DMRP,
        'output_dir')
    wf.connect(MGDM, 'segmentation', DMRP, 'segmentation_image')
    wf.connect(MGDM, 'distance', DMRP, 'levelset_boundary_image')
    '''
    ###############################################
    ####      REMOVE VENTRICLE POSTERIOR       ####
    ###############################################
    Due to topology constraints, the ventricles are often not fully segmented:
    here add back all ventricle voxels from the posterior probability (without the topology constraints)
    '''

    # Posterior label
    PostLabel = Node(Split(), name='PosteriorLabel')
    PostLabel.inputs.dimension = "t"
    wf.connect(MGDM, 'labels', PostLabel, 'in_file')

    # Posterior proba
    PostProba = Node(Split(), name='PosteriorProba')
    PostProba.inputs.dimension = "t"
    wf.connect(MGDM, 'memberships', PostProba, 'in_file')

    # Threshold binary mask : ventricle label part 1
    VentLabel1 = Node(Threshold(), name="VentricleLabel1")
    VentLabel1.inputs.thresh = 10.5
    VentLabel1.inputs.direction = "below"
    wf.connect(PostLabel, ("out_files", getFirstElement), VentLabel1,
               "in_file")

    # Threshold binary mask : ventricle label part 2
    VentLabel2 = Node(Threshold(), name="VentricleLabel2")
    VentLabel2.inputs.thresh = 13.5
    VentLabel2.inputs.direction = "above"
    wf.connect(VentLabel1, "out_file", VentLabel2, "in_file")

    # Image calculator : ventricle proba
    VentProba = Node(ImageMaths(), name="VentricleProba")
    VentProba.inputs.op_string = "-mul"
    VentProba.inputs.out_file = "ventproba.nii.gz"
    wf.connect(PostProba, ("out_files", getFirstElement), VentProba, "in_file")
    wf.connect(VentLabel2, "out_file", VentProba, "in_file2")

    # Image calculator : remove inter ventricles
    RmInterVent = Node(ImageMaths(), name="RemoveInterVent")
    RmInterVent.inputs.op_string = "-sub"
    RmInterVent.inputs.out_file = "rmintervent.nii.gz"
    wf.connect(ERC, "region_pv", RmInterVent, "in_file")
    wf.connect(DMRP, "inter_ventricular_pv", RmInterVent, "in_file2")

    # Image calculator : add horns
    AddHorns = Node(ImageMaths(), name="AddHorns")
    AddHorns.inputs.op_string = "-add"
    AddHorns.inputs.out_file = "rmvent.nii.gz"
    wf.connect(RmInterVent, "out_file", AddHorns, "in_file")
    wf.connect(DMRP, "ventricular_horns_pv", AddHorns, "in_file2")

    # Image calculator : remove ventricles
    RmVent = Node(ImageMaths(), name="RemoveVentricles")
    RmVent.inputs.op_string = "-sub"
    RmVent.inputs.out_file = "rmvent.nii.gz"
    wf.connect(AddHorns, "out_file", RmVent, "in_file")
    wf.connect(VentProba, "out_file", RmVent, "in_file2")

    # Image calculator : remove internal capsule
    RmIC = Node(ImageMaths(), name="RemoveInternalCap")
    RmIC.inputs.op_string = "-sub"
    RmIC.inputs.out_file = "rmic.nii.gz"
    wf.connect(RmVent, "out_file", RmIC, "in_file")
    wf.connect(DMRP, "internal_capsule_pv", RmIC, "in_file2")

    # Intensity Range Normalization (3)
    getMaxRmIC = Node(ImageStats(op_string='-r'), name="getMaxRmIC")
    wf.connect(RmIC, 'out_file', getMaxRmIC, 'in_file')

    RmICirn = Node(AbcImageMaths(), name="IntensityNormalization5")
    RmICirn.inputs.op_string = "-div"
    RmICirn.inputs.out_file = "normRmIC.nii.gz"
    wf.connect(RmIC, 'out_file', RmICirn, 'in_file')
    wf.connect(getMaxRmIC, ('out_stat', getElementFromList, 1), RmICirn,
               "op_value")

    # Probability To Levelset : WM orientation
    WM_Orient = Node(ProbabilityToLevelset(), name='WM_Orientation')
    WM_Orient.plugin_args = {'sbatch_args': '--mem 6000'}
    WM_Orient.inputs.save_data = True
    wf.connect(
        subjectList,
        ('subject_id', createOutputDir, wf.base_dir, wf.name, WM_Orient.name),
        WM_Orient, 'output_dir')
    wf.connect(RmICirn, 'out_file', WM_Orient, 'probability_image')

    # Recursive Ridge Diffusion : PVS in WM only
    WM_pvs = Node(RecursiveRidgeDiffusion(), name='PVS_in_WM')
    WM_pvs.plugin_args = {'sbatch_args': '--mem 6000'}
    WM_pvs.inputs.ridge_intensities = "bright"
    WM_pvs.inputs.ridge_filter = "1D"
    WM_pvs.inputs.orientation = "orthogonal"
    WM_pvs.inputs.ang_factor = 1.0
    WM_pvs.inputs.min_scale = 0
    WM_pvs.inputs.max_scale = 3
    WM_pvs.inputs.propagation_model = "diffusion"
    WM_pvs.inputs.diffusion_factor = 1.0
    WM_pvs.inputs.similarity_scale = 1.0
    WM_pvs.inputs.neighborhood_size = 2
    WM_pvs.inputs.max_iter = 100
    WM_pvs.inputs.max_diff = 0.001
    WM_pvs.inputs.save_data = True
    wf.connect(
        subjectList,
        ('subject_id', createOutputDir, wf.base_dir, wf.name, WM_pvs.name),
        WM_pvs, 'output_dir')
    wf.connect(ERC, 'background_proba', WM_pvs, 'input_image')
    wf.connect(WM_Orient, 'levelset', WM_pvs, 'surface_levelset')
    wf.connect(RmICirn, 'out_file', WM_pvs, 'loc_prior')

    # Extract Lesions : extract WM PVS
    extract_WM_pvs = Node(LesionExtraction(), name='ExtractPVSfromWM')
    extract_WM_pvs.plugin_args = {'sbatch_args': '--mem 6000'}
    extract_WM_pvs.inputs.gm_boundary_partial_vol_dist = 1.0
    extract_WM_pvs.inputs.csf_boundary_partial_vol_dist = 3.0
    extract_WM_pvs.inputs.lesion_clust_dist = 1.0
    extract_WM_pvs.inputs.prob_min_thresh = 0.1
    extract_WM_pvs.inputs.prob_max_thresh = 0.33
    extract_WM_pvs.inputs.small_lesion_size = 4.0
    extract_WM_pvs.inputs.save_data = True
    extract_WM_pvs.inputs.atlas_file = atlas
    wf.connect(subjectList, ('subject_id', createOutputDir, wf.base_dir,
                             wf.name, extract_WM_pvs.name), extract_WM_pvs,
               'output_dir')
    wf.connect(WM_pvs, 'propagation', extract_WM_pvs, 'probability_image')
    wf.connect(MGDM, 'segmentation', extract_WM_pvs, 'segmentation_image')
    wf.connect(MGDM, 'distance', extract_WM_pvs, 'levelset_boundary_image')
    wf.connect(RmICirn, 'out_file', extract_WM_pvs, 'location_prior_image')
    '''
    2nd branch
    '''

    # Image calculator : internal capsule witout ventricules
    ICwoVent = Node(ImageMaths(), name="ICWithoutVentricules")
    ICwoVent.inputs.op_string = "-sub"
    ICwoVent.inputs.out_file = "icwovent.nii.gz"
    wf.connect(DMRP, "internal_capsule_pv", ICwoVent, "in_file")
    wf.connect(DMRP, "inter_ventricular_pv", ICwoVent, "in_file2")

    # Image calculator : remove ventricles IC
    RmVentIC = Node(ImageMaths(), name="RmVentIC")
    RmVentIC.inputs.op_string = "-sub"
    RmVentIC.inputs.out_file = "RmVentIC.nii.gz"
    wf.connect(ICwoVent, "out_file", RmVentIC, "in_file")
    wf.connect(VentProba, "out_file", RmVentIC, "in_file2")

    # Intensity Range Normalization (4)
    getMaxRmVentIC = Node(ImageStats(op_string='-r'), name="getMaxRmVentIC")
    wf.connect(RmVentIC, 'out_file', getMaxRmVentIC, 'in_file')

    RmVentICirn = Node(AbcImageMaths(), name="IntensityNormalization6")
    RmVentICirn.inputs.op_string = "-div"
    RmVentICirn.inputs.out_file = "normRmVentIC.nii.gz"
    wf.connect(RmVentIC, 'out_file', RmVentICirn, 'in_file')
    wf.connect(getMaxRmVentIC, ('out_stat', getElementFromList, 1),
               RmVentICirn, "op_value")

    # Probability To Levelset : IC orientation
    IC_Orient = Node(ProbabilityToLevelset(), name='IC_Orientation')
    IC_Orient.plugin_args = {'sbatch_args': '--mem 6000'}
    IC_Orient.inputs.save_data = True
    wf.connect(
        subjectList,
        ('subject_id', createOutputDir, wf.base_dir, wf.name, IC_Orient.name),
        IC_Orient, 'output_dir')
    wf.connect(RmVentICirn, 'out_file', IC_Orient, 'probability_image')

    # Recursive Ridge Diffusion : PVS in IC only
    IC_pvs = Node(RecursiveRidgeDiffusion(), name='RecursiveRidgeDiffusion2')
    IC_pvs.plugin_args = {'sbatch_args': '--mem 6000'}
    IC_pvs.inputs.ridge_intensities = "bright"
    IC_pvs.inputs.ridge_filter = "1D"
    IC_pvs.inputs.orientation = "undefined"
    IC_pvs.inputs.ang_factor = 1.0
    IC_pvs.inputs.min_scale = 0
    IC_pvs.inputs.max_scale = 3
    IC_pvs.inputs.propagation_model = "diffusion"
    IC_pvs.inputs.diffusion_factor = 1.0
    IC_pvs.inputs.similarity_scale = 1.0
    IC_pvs.inputs.neighborhood_size = 2
    IC_pvs.inputs.max_iter = 100
    IC_pvs.inputs.max_diff = 0.001
    IC_pvs.inputs.save_data = True
    wf.connect(
        subjectList,
        ('subject_id', createOutputDir, wf.base_dir, wf.name, IC_pvs.name),
        IC_pvs, 'output_dir')
    wf.connect(ERC, 'background_proba', IC_pvs, 'input_image')
    wf.connect(IC_Orient, 'levelset', IC_pvs, 'surface_levelset')
    wf.connect(RmVentICirn, 'out_file', IC_pvs, 'loc_prior')

    # Extract Lesions : extract IC PVS
    extract_IC_pvs = Node(LesionExtraction(), name='ExtractPVSfromIC')
    extract_IC_pvs.plugin_args = {'sbatch_args': '--mem 6000'}
    extract_IC_pvs.inputs.gm_boundary_partial_vol_dist = 1.0
    extract_IC_pvs.inputs.csf_boundary_partial_vol_dist = 4.0
    extract_IC_pvs.inputs.lesion_clust_dist = 1.0
    extract_IC_pvs.inputs.prob_min_thresh = 0.25
    extract_IC_pvs.inputs.prob_max_thresh = 0.5
    extract_IC_pvs.inputs.small_lesion_size = 4.0
    extract_IC_pvs.inputs.save_data = True
    extract_IC_pvs.inputs.atlas_file = atlas
    wf.connect(subjectList, ('subject_id', createOutputDir, wf.base_dir,
                             wf.name, extract_IC_pvs.name), extract_IC_pvs,
               'output_dir')
    wf.connect(IC_pvs, 'propagation', extract_IC_pvs, 'probability_image')
    wf.connect(MGDM, 'segmentation', extract_IC_pvs, 'segmentation_image')
    wf.connect(MGDM, 'distance', extract_IC_pvs, 'levelset_boundary_image')
    wf.connect(RmVentICirn, 'out_file', extract_IC_pvs, 'location_prior_image')
    '''
    3rd branch
    '''

    # Image calculator :
    RmInter = Node(ImageMaths(), name="RemoveInterVentricules")
    RmInter.inputs.op_string = "-sub"
    RmInter.inputs.out_file = "rminter.nii.gz"
    wf.connect(ERC2, 'region_pv', RmInter, "in_file")
    wf.connect(DMRP, "inter_ventricular_pv", RmInter, "in_file2")

    # Image calculator :
    AddVentHorns = Node(ImageMaths(), name="AddVentHorns")
    AddVentHorns.inputs.op_string = "-add"
    AddVentHorns.inputs.out_file = "rminter.nii.gz"
    wf.connect(RmInter, 'out_file', AddVentHorns, "in_file")
    wf.connect(DMRP, "ventricular_horns_pv", AddVentHorns, "in_file2")

    # Intensity Range Normalization (5)
    getMaxAddVentHorns = Node(ImageStats(op_string='-r'),
                              name="getMaxAddVentHorns")
    wf.connect(AddVentHorns, 'out_file', getMaxAddVentHorns, 'in_file')

    AddVentHornsirn = Node(AbcImageMaths(), name="IntensityNormalization7")
    AddVentHornsirn.inputs.op_string = "-div"
    AddVentHornsirn.inputs.out_file = "normAddVentHorns.nii.gz"
    wf.connect(AddVentHorns, 'out_file', AddVentHornsirn, 'in_file')
    wf.connect(getMaxAddVentHorns, ('out_stat', getElementFromList, 1),
               AddVentHornsirn, "op_value")

    # Extract Lesions : extract White Matter Hyperintensities
    extract_WMH = Node(LesionExtraction(), name='Extract_WMH')
    extract_WMH.plugin_args = {'sbatch_args': '--mem 6000'}
    extract_WMH.inputs.gm_boundary_partial_vol_dist = 1.0
    extract_WMH.inputs.csf_boundary_partial_vol_dist = 2.0
    extract_WMH.inputs.lesion_clust_dist = 1.0
    extract_WMH.inputs.prob_min_thresh = 0.84
    extract_WMH.inputs.prob_max_thresh = 0.84
    extract_WMH.inputs.small_lesion_size = 4.0
    extract_WMH.inputs.save_data = True
    extract_WMH.inputs.atlas_file = atlas
    wf.connect(subjectList, ('subject_id', createOutputDir, wf.base_dir,
                             wf.name, extract_WMH.name), extract_WMH,
               'output_dir')
    wf.connect(ERC2, 'background_proba', extract_WMH, 'probability_image')
    wf.connect(MGDM, 'segmentation', extract_WMH, 'segmentation_image')
    wf.connect(MGDM, 'distance', extract_WMH, 'levelset_boundary_image')
    wf.connect(AddVentHornsirn, 'out_file', extract_WMH,
               'location_prior_image')

    #===========================================================================
    # extract_WMH2 = extract_WMH.clone(name='Extract_WMH2')
    # extract_WMH2.inputs.gm_boundary_partial_vol_dist = 2.0
    # wf.connect(subjectList,('subject_id',createOutputDir,wf.base_dir,wf.name,extract_WMH2.name),extract_WMH2,'output_dir')
    # wf.connect(ERC2,'background_proba',extract_WMH2,'probability_image')
    # wf.connect(MGDM,'segmentation',extract_WMH2,'segmentation_image')
    # wf.connect(MGDM,'distance',extract_WMH2,'levelset_boundary_image')
    # wf.connect(AddVentHornsirn,'out_file',extract_WMH2,'location_prior_image')
    #
    # extract_WMH3 = extract_WMH.clone(name='Extract_WMH3')
    # extract_WMH3.inputs.gm_boundary_partial_vol_dist = 3.0
    # wf.connect(subjectList,('subject_id',createOutputDir,wf.base_dir,wf.name,extract_WMH3.name),extract_WMH3,'output_dir')
    # wf.connect(ERC2,'background_proba',extract_WMH3,'probability_image')
    # wf.connect(MGDM,'segmentation',extract_WMH3,'segmentation_image')
    # wf.connect(MGDM,'distance',extract_WMH3,'levelset_boundary_image')
    # wf.connect(AddVentHornsirn,'out_file',extract_WMH3,'location_prior_image')
    #===========================================================================
    '''
    ####################################
    ####     FINDING SMALL WMHs     ####
    ####################################
    Small round WMHs near the cortex are often missed by the main algorithm, 
    so we're adding this one that takes care of them.
    '''

    # Recursive Ridge Diffusion : round WMH detection
    round_WMH = Node(RecursiveRidgeDiffusion(), name='round_WMH')
    round_WMH.plugin_args = {'sbatch_args': '--mem 6000'}
    round_WMH.inputs.ridge_intensities = "bright"
    round_WMH.inputs.ridge_filter = "0D"
    round_WMH.inputs.orientation = "undefined"
    round_WMH.inputs.ang_factor = 1.0
    round_WMH.inputs.min_scale = 1
    round_WMH.inputs.max_scale = 4
    round_WMH.inputs.propagation_model = "none"
    round_WMH.inputs.diffusion_factor = 1.0
    round_WMH.inputs.similarity_scale = 0.1
    round_WMH.inputs.neighborhood_size = 4
    round_WMH.inputs.max_iter = 100
    round_WMH.inputs.max_diff = 0.001
    round_WMH.inputs.save_data = True
    wf.connect(
        subjectList,
        ('subject_id', createOutputDir, wf.base_dir, wf.name, round_WMH.name),
        round_WMH, 'output_dir')
    wf.connect(ERC2, 'background_proba', round_WMH, 'input_image')
    wf.connect(AddVentHornsirn, 'out_file', round_WMH, 'loc_prior')

    # Extract Lesions : extract round WMH
    extract_round_WMH = Node(LesionExtraction(), name='Extract_round_WMH')
    extract_round_WMH.plugin_args = {'sbatch_args': '--mem 6000'}
    extract_round_WMH.inputs.gm_boundary_partial_vol_dist = 1.0
    extract_round_WMH.inputs.csf_boundary_partial_vol_dist = 2.0
    extract_round_WMH.inputs.lesion_clust_dist = 1.0
    extract_round_WMH.inputs.prob_min_thresh = 0.33
    extract_round_WMH.inputs.prob_max_thresh = 0.33
    extract_round_WMH.inputs.small_lesion_size = 6.0
    extract_round_WMH.inputs.save_data = True
    extract_round_WMH.inputs.atlas_file = atlas
    wf.connect(subjectList, ('subject_id', createOutputDir, wf.base_dir,
                             wf.name, extract_round_WMH.name),
               extract_round_WMH, 'output_dir')
    wf.connect(round_WMH, 'ridge_pv', extract_round_WMH, 'probability_image')
    wf.connect(MGDM, 'segmentation', extract_round_WMH, 'segmentation_image')
    wf.connect(MGDM, 'distance', extract_round_WMH, 'levelset_boundary_image')
    wf.connect(AddVentHornsirn, 'out_file', extract_round_WMH,
               'location_prior_image')

    #===========================================================================
    # extract_round_WMH2 = extract_round_WMH.clone(name='Extract_round_WMH2')
    # extract_round_WMH2.inputs.gm_boundary_partial_vol_dist = 2.0
    # wf.connect(subjectList,('subject_id',createOutputDir,wf.base_dir,wf.name,extract_round_WMH2.name),extract_round_WMH2,'output_dir')
    # wf.connect(round_WMH,'ridge_pv',extract_round_WMH2,'probability_image')
    # wf.connect(MGDM,'segmentation',extract_round_WMH2,'segmentation_image')
    # wf.connect(MGDM,'distance',extract_round_WMH2,'levelset_boundary_image')
    # wf.connect(AddVentHornsirn,'out_file',extract_round_WMH2,'location_prior_image')
    #
    # extract_round_WMH3 = extract_round_WMH.clone(name='Extract_round_WMH3')
    # extract_round_WMH3.inputs.gm_boundary_partial_vol_dist = 2.0
    # wf.connect(subjectList,('subject_id',createOutputDir,wf.base_dir,wf.name,extract_round_WMH3.name),extract_round_WMH3,'output_dir')
    # wf.connect(round_WMH,'ridge_pv',extract_round_WMH3,'probability_image')
    # wf.connect(MGDM,'segmentation',extract_round_WMH3,'segmentation_image')
    # wf.connect(MGDM,'distance',extract_round_WMH3,'levelset_boundary_image')
    # wf.connect(AddVentHornsirn,'out_file',extract_round_WMH3,'location_prior_image')
    #===========================================================================
    '''
    ####################################
    ####     COMBINE BOTH TYPES     ####
    ####################################
    Small round WMHs and regular WMH together before thresholding
    +
    PVS from white matter and internal capsule
    '''

    # Image calculator : WM + IC DVRS
    DVRS = Node(ImageMaths(), name="DVRS")
    DVRS.inputs.op_string = "-max"
    DVRS.inputs.out_file = "DVRS_map.nii.gz"
    wf.connect(extract_WM_pvs, 'lesion_score', DVRS, "in_file")
    wf.connect(extract_IC_pvs, "lesion_score", DVRS, "in_file2")

    # Image calculator : WMH + round
    WMH = Node(ImageMaths(), name="WMH")
    WMH.inputs.op_string = "-max"
    WMH.inputs.out_file = "WMH_map.nii.gz"
    wf.connect(extract_WMH, 'lesion_score', WMH, "in_file")
    wf.connect(extract_round_WMH, "lesion_score", WMH, "in_file2")

    #===========================================================================
    # WMH2 = Node(ImageMaths(), name="WMH2")
    # WMH2.inputs.op_string = "-max"
    # WMH2.inputs.out_file = "WMH2_map.nii.gz"
    # wf.connect(extract_WMH2,'lesion_score',WMH2,"in_file")
    # wf.connect(extract_round_WMH2,"lesion_score", WMH2, "in_file2")
    #
    # WMH3 = Node(ImageMaths(), name="WMH3")
    # WMH3.inputs.op_string = "-max"
    # WMH3.inputs.out_file = "WMH3_map.nii.gz"
    # wf.connect(extract_WMH3,'lesion_score',WMH3,"in_file")
    # wf.connect(extract_round_WMH3,"lesion_score", WMH3, "in_file2")
    #===========================================================================

    # Image calculator : multiply by boundnary partial volume
    WMH_mul = Node(ImageMaths(), name="WMH_mul")
    WMH_mul.inputs.op_string = "-mul"
    WMH_mul.inputs.out_file = "final_mask.nii.gz"
    wf.connect(WMH, "out_file", WMH_mul, "in_file")
    wf.connect(MGDM, "distance", WMH_mul, "in_file2")

    #===========================================================================
    # WMH2_mul = Node(ImageMaths(), name="WMH2_mul")
    # WMH2_mul.inputs.op_string = "-mul"
    # WMH2_mul.inputs.out_file = "final_mask.nii.gz"
    # wf.connect(WMH2,"out_file", WMH2_mul,"in_file")
    # wf.connect(MGDM,"distance", WMH2_mul, "in_file2")
    #
    # WMH3_mul = Node(ImageMaths(), name="WMH3_mul")
    # WMH3_mul.inputs.op_string = "-mul"
    # WMH3_mul.inputs.out_file = "final_mask.nii.gz"
    # wf.connect(WMH3,"out_file", WMH3_mul,"in_file")
    # wf.connect(MGDM,"distance", WMH3_mul, "in_file2")
    #===========================================================================
    '''
    ##########################################
    ####      SEGMENTATION THRESHOLD      ####
    ##########################################
    A threshold of 0.5 is very conservative, because the final lesion score is the product of two probabilities.
    This needs to be optimized to a value between 0.25 and 0.5 to balance false negatives 
    (dominant at 0.5) and false positives (dominant at low values).
    '''

    # Threshold binary mask :
    DVRS_mask = Node(Threshold(), name="DVRS_mask")
    DVRS_mask.inputs.thresh = 0.25
    DVRS_mask.inputs.direction = "below"
    wf.connect(DVRS, "out_file", DVRS_mask, "in_file")

    # Threshold binary mask : 025
    WMH1_025 = Node(Threshold(), name="WMH1_025")
    WMH1_025.inputs.thresh = 0.25
    WMH1_025.inputs.direction = "below"
    wf.connect(WMH_mul, "out_file", WMH1_025, "in_file")

    #===========================================================================
    # WMH2_025 = Node(Threshold(), name="WMH2_025")
    # WMH2_025.inputs.thresh = 0.25
    # WMH2_025.inputs.direction = "below"
    # wf.connect(WMH2_mul,"out_file", WMH2_025, "in_file")
    #
    # WMH3_025 = Node(Threshold(), name="WMH3_025")
    # WMH3_025.inputs.thresh = 0.25
    # WMH3_025.inputs.direction = "below"
    # wf.connect(WMH3_mul,"out_file", WMH3_025, "in_file")
    #===========================================================================

    # Threshold binary mask : 050
    WMH1_050 = Node(Threshold(), name="WMH1_050")
    WMH1_050.inputs.thresh = 0.50
    WMH1_050.inputs.direction = "below"
    wf.connect(WMH_mul, "out_file", WMH1_050, "in_file")

    #===========================================================================
    # WMH2_050 = Node(Threshold(), name="WMH2_050")
    # WMH2_050.inputs.thresh = 0.50
    # WMH2_050.inputs.direction = "below"
    # wf.connect(WMH2_mul,"out_file", WMH2_050, "in_file")
    #
    # WMH3_050 = Node(Threshold(), name="WMH3_050")
    # WMH3_050.inputs.thresh = 0.50
    # WMH3_050.inputs.direction = "below"
    # wf.connect(WMH3_mul,"out_file", WMH3_050, "in_file")
    #===========================================================================

    # Threshold binary mask : 075
    WMH1_075 = Node(Threshold(), name="WMH1_075")
    WMH1_075.inputs.thresh = 0.75
    WMH1_075.inputs.direction = "below"
    wf.connect(WMH_mul, "out_file", WMH1_075, "in_file")

    #===========================================================================
    # WMH2_075 = Node(Threshold(), name="WMH2_075")
    # WMH2_075.inputs.thresh = 0.75
    # WMH2_075.inputs.direction = "below"
    # wf.connect(WMH2_mul,"out_file", WMH2_075, "in_file")
    #
    # WMH3_075 = Node(Threshold(), name="WMH3_075")
    # WMH3_075.inputs.thresh = 0.75
    # WMH3_075.inputs.direction = "below"
    # wf.connect(WMH3_mul,"out_file", WMH3_075, "in_file")
    #===========================================================================

    ## Outputs

    DVRS_Output = Node(IdentityInterface(fields=[
        'mask', 'region', 'lesion_size', 'lesion_proba', 'boundary', 'label',
        'score'
    ]),
                       name='DVRS_Output')
    wf.connect(DVRS_mask, 'out_file', DVRS_Output, 'mask')

    WMH_output = Node(IdentityInterface(fields=[
        'mask1025', 'mask1050', 'mask1075', 'mask2025', 'mask2050', 'mask2075',
        'mask3025', 'mask3050', 'mask3075'
    ]),
                      name='WMH_output')
    wf.connect(WMH1_025, 'out_file', WMH_output, 'mask1025')
    #wf.connect(WMH2_025,'out_file',WMH_output,'mask2025')
    #wf.connect(WMH3_025,'out_file',WMH_output,'mask3025')
    wf.connect(WMH1_050, 'out_file', WMH_output, 'mask1050')
    #wf.connect(WMH2_050,'out_file',WMH_output,'mask2050')
    #wf.connect(WMH3_050,'out_file',WMH_output,'mask3050')
    wf.connect(WMH1_075, 'out_file', WMH_output, 'mask1075')
    #wf.connect(WMH2_075,'out_file',WMH_output,'mask2070')
    #wf.connect(WMH3_075,'out_file',WMH_output,'mask3075')

    return wf
Exemple #7
0
    def segmentation_pipeline(self, **kwargs):  # @UnusedVariable @IgnorePep8
        """Build the pipeline that segments the registered first-echo UTE image.

        Runs SPM12 ``NewSegment`` on ``ute1_registered`` using six tissue
        frames taken from ``self.tpm_path``. The entries at index 3 and 5 of
        ``native_class_images`` are selected. Each is thresholded with FSL
        (``direction='below'``, 0.2 for bones, 0.1 for air) and then
        binarised (``UnaryMaths`` operation ``'bin'``). The results become
        the ``bones_mask`` and ``air_mask`` outputs.

        :param kwargs: forwarded unchanged to ``self.create_pipeline``.
        :returns: the fully connected pipeline object.
        """
        pipeline = self.create_pipeline(
            name='ute1_segmentation',
            inputs=[DatasetSpec('ute1_registered', nifti_format)],
            outputs=[DatasetSpec('air_mask', nifti_gz_format),
                     DatasetSpec('bones_mask', nifti_gz_format)],
            desc="Segmentation of the first echo UTE image",
            version=1,
            citations=(spm_cite, matlab_cite),
            **kwargs)

        # --- SPM12 unified segmentation ---------------------------------
        segmentation = pipeline.create_node(
            NewSegment(), name='ute1_registered_segmentation',
            requirements=[matlab2015_req, spm12_req], wall_time=480)
        pipeline.connect_input('ute1_registered', segmentation,
                               'channel_files')
        segmentation.inputs.affine_regularization = 'none'

        # One tuple per TPM frame (1..6), each with its own Gaussian count.
        # Presumably ((tpm, frame), n_gaussians, (native, dartel),
        # (unmodulated, modulated)) as in nipype's NewSegment; confirm.
        gaussians_per_frame = (1, 1, 2, 3, 4, 3)
        segmentation.inputs.tissues = [
            ((self.tpm_path, frame), n_gauss, (True, False), (False, False))
            for frame, n_gauss in enumerate(gaussians_per_frame, start=1)
        ]

        def select_class(node_name, class_index):
            # Pick one entry of the segmentation's native class images.
            selector = pipeline.create_node(
                Select(), name=node_name, requirements=[], wall_time=5)
            pipeline.connect(segmentation, 'native_class_images', selector,
                             'inlist')
            selector.inputs.index = class_index
            return selector

        def binary_mask(selector, thresh_name, bin_name, cutoff):
            # Threshold the selected map below `cutoff`, then binarise it.
            thresholder = pipeline.create_node(
                Threshold(), name=thresh_name, requirements=[fsl5_req],
                wall_time=5)
            pipeline.connect(selector, 'out', thresholder, 'in_file')
            thresholder.inputs.output_type = "NIFTI_GZ"
            thresholder.inputs.direction = 'below'
            thresholder.inputs.thresh = cutoff

            binariser = pipeline.create_node(
                UnaryMaths(), name=bin_name, requirements=[fsl5_req],
                wall_time=5)
            pipeline.connect(thresholder, 'out_file', binariser, 'in_file')
            binariser.inputs.output_type = "NIFTI_GZ"
            binariser.inputs.operation = 'bin'
            return binariser

        bones_selector = select_class(
            'select_bones_pm_from_SPM_new_segmentation', 3)
        air_selector = select_class(
            'select_air_pm_from_SPM_new_segmentation', 5)

        bones_binary = binary_mask(
            bones_selector,
            'bones_probabilistic_map_thresholding',
            'bones_probabilistic_map_binarization',
            0.2)
        air_binary = binary_mask(
            air_selector,
            'air_probabilistic_maps_thresholding',
            'air_probabilistic_map_binarization',
            0.1)

        pipeline.connect_output('bones_mask', bones_binary, 'out_file')
        pipeline.connect_output('air_mask', air_binary, 'out_file')
        pipeline.assert_connected()
        return pipeline
Exemple #8
0
                           curfunc.split('task-')[1].split('_')[0] +
                           '*' +
                           curfunc.split('run-')[1].split('_')[0] +
                           '*' + '*AROMAnoiseICs.csv')[0],
                 delimiter=',').astype('int')),
         in_file=curfunc,
         mask=curmask,
         out_file=tmpAROMA).run()
 if not os.path.isfile(tmpAROMAconf):
     if not os.path.isfile(tmpAROMAwm) or not os.path.isfile(
             tmpAROMAcsf):
         from nipype.interfaces.fsl.maths import Threshold
         from nipype.interfaces.fsl.utils import ImageMeants
         Threshold(
             in_file=cursegm,
             thresh=2.5,
             out_file=tmpAROMAwm,
             args=' -uthr 3.5 -kernel sphere 4 -ero -bin').run()
         Threshold(
             in_file=cursegm,
             thresh=0.5,
             out_file=tmpAROMAcsf,
             args=' -uthr 1.5 -kernel sphere 2 -ero -bin').run()
     wmts = NiftiLabelsMasker(
         labels_img=tmpAROMAwm, detrend=False,
         standardize=False).fit_transform(tmpAROMA)
     csfts = NiftiLabelsMasker(
         labels_img=tmpAROMAcsf, detrend=False,
         standardize=False).fit_transform(tmpAROMA)
     gsts = NiftiLabelsMasker(
         labels_img=curmask, detrend=False,
Exemple #9
0
def sdc_t2b(name='SDC_T2B', icorr=True, num_threads=1):
    """
    The T2w-registration based method (T2B) implements an SDC by nonlinear
    registration of the anatomically correct *T2w* image to the *b0* image
    of the *dMRI* dataset. The implementation here tries to reproduce the one
    included in ExploreDTI `(Leemans et al., 2009)
    <http://www.exploredti.com/ref/ExploreDTI_ISMRM_2009.pdf>`_, which is
    also used by `(Irfanoglu et al., 2012)
    <http://dx.doi.org/10.1016/j.neuroimage.2012.02.054>`_.

    :param str name: a unique name for the workflow.
    :param bool icorr: if True, each unwarped DWI volume is multiplied by the
      masked Jacobian determinant map of the estimated displacement field
      before masking. If False, the unwarped volumes are used directly.
    :param int num_threads: number of threads given to the Elastix
      registration node.

    :inputs:

        * inputnode.in_dwi: the dMRI dataset (split along the time axis)
        * inputnode.in_bval: the b-values file (used to build the b0 average)
        * inputnode.in_t2w: the reference T2w image
        * inputnode.dwi_mask: mask of the dMRI dataset
        * inputnode.t2w_mask: mask of the T2w image
        * inputnode.in_param: parameters file read by ``JSONFileGrabber``
          (``enc_dir`` defaults to ``'y'``)
        * inputnode.in_surf: point files to be warped (one per MapNode item)

    :outputs:

        * outputnode.dwi: the dMRI image after correction
        * outputnode.dwi_mask: the unwarped mask after dilation/erosion
        * outputnode.jacobian: Jacobian determinant map of the displacement
          field
        * outputnode.out_surf: the warped point files


    Example::

    >>> t2b = sdc_t2b()
    >>> t2b.inputs.inputnode.in_dwi = 'dwi_brain.nii'
    >>> t2b.inputs.inputnode.in_bval = 'dwi.bval'
    >>> t2b.inputs.inputnode.dwi_mask = 'b0_mask.nii'
    >>> t2b.inputs.inputnode.in_t2w = 't2w_brain.nii'
    >>> t2b.inputs.inputnode.in_param = 'parameters.txt'
    >>> t2b.run() # doctest: +SKIP

    """
    inputnode = pe.Node(niu.IdentityInterface(
        fields=['in_dwi', 'in_bval', 'in_t2w', 'dwi_mask', 't2w_mask',
                'in_param', 'in_surf']), name='inputnode')
    # 'jacobian' is declared here because DisplFieldAnalysis.jacdet_map is
    # connected to it below. Without it, that connection targets a field
    # this interface does not declare.
    outputnode = pe.Node(niu.IdentityInterface(
        fields=['dwi', 'dwi_mask', 'jacobian', 'out_surf']),
        name='outputnode')

    # Reference b0 image (built by b0_average), then N4 bias correction of
    # both the b0 reference and the T2w image before registration.
    avg_b0 = pe.Node(niu.Function(
        input_names=['in_dwi', 'in_bval'], output_names=['out_file'],
        function=b0_average), name='AverageB0')
    n4_b0 = pe.Node(N4BiasFieldCorrection(dimension=3), name='BiasB0')
    n4_t2 = pe.Node(N4BiasFieldCorrection(dimension=3), name='BiasT2')

    getparam = pe.Node(nio.JSONFileGrabber(defaults={'enc_dir': 'y'}),
                       name='GetEncDir')
    # Use the caller's thread count (this was previously hard-coded to 1).
    reg = pe.Node(nex.Registration(num_threads=num_threads), name='Elastix')
    tfx_b0 = pe.Node(nex.EditTransform(), name='tfm_b0')
    split_dwi = pe.Node(fsl.utils.Split(dimension='t'), name='split_dwi')
    warp = pe.MapNode(nex.ApplyWarp(), iterfield=['moving_image'],
                      name='UnwarpDWIs')
    warp_prop = pe.Node(nex.AnalyzeWarp(), name='DisplFieldAnalysis')
    # Joins the icorr / non-icorr branches: whichever is wired below feeds
    # 'unwarped'.
    warpbuff = pe.Node(niu.IdentityInterface(fields=['unwarped']),
                       name='UnwarpedCache')
    mskdwis = pe.MapNode(fs.ApplyMask(), iterfield=['in_file'],
                         name='MaskDWIs')
    # Threshold each masked volume at 0.0 (node name: RemoveNegs).
    thres = pe.MapNode(Threshold(thresh=0.0), iterfield=['in_file'],
                       name='RemoveNegs')
    merge_dwi = pe.Node(fsl.utils.Merge(dimension='t'), name='merge_dwis')
    # The mask transform uses nearest-neighbour interpolation and an
    # unsigned char output type.
    tfx_msk = pe.Node(nex.EditTransform(
        interpolation='nearest', output_type='unsigned char'),
        name='MSKInterpolator')
    corr_msk = pe.Node(nex.ApplyWarp(), name='UnwarpMsk')
    # Dilate (sphere 3), then erode (sphere 2) the unwarped mask; NaNs -> 0.
    closmsk = pe.Node(fsl.maths.MathsCommand(
        nan2zeros=True, args='-kernel sphere 3 -dilM -kernel sphere 2 -ero'),
        name='MaskClosing')

    swarp = pe.MapNode(nex.PointsWarp(), iterfield=['points_file'],
                       name='UnwarpSurfs')

    wf = pe.Workflow(name=name)
    wf.connect([
        (inputnode,     avg_b0, [('in_dwi', 'in_dwi'),
                                 ('in_bval', 'in_bval')]),
        (inputnode,   getparam, [('in_param', 'in_file')]),
        (inputnode,  split_dwi, [('in_dwi', 'in_file')]),
        (inputnode,   corr_msk, [('dwi_mask', 'moving_image')]),
        (inputnode,      swarp, [('in_surf', 'points_file')]),
        (inputnode,        reg, [('t2w_mask', 'fixed_mask'),
                                 ('dwi_mask', 'moving_mask')]),
        (inputnode,      n4_t2, [('in_t2w', 'input_image'),
                                 ('t2w_mask', 'mask_image')]),
        (inputnode,      n4_b0, [('dwi_mask', 'mask_image')]),
        (avg_b0,         n4_b0, [('out_file', 'input_image')]),
        (getparam,         reg, [
            (('enc_dir', _default_params), 'parameters')]),
        (n4_t2,            reg, [('output_image', 'fixed_image')]),
        (n4_b0,            reg, [('output_image', 'moving_image')]),
        (reg,           tfx_b0, [
            (('transform', _get_last), 'transform_file')]),
        (avg_b0,        tfx_b0, [('out_file', 'reference_image')]),
        (tfx_b0,     warp_prop, [('output_file', 'transform_file')]),
        (tfx_b0,          warp, [('output_file', 'transform_file')]),
        (split_dwi,       warp, [('out_files', 'moving_image')]),
        (warpbuff,     mskdwis, [('unwarped', 'in_file')]),
        (closmsk,      mskdwis, [('out_file', 'mask_file')]),
        (mskdwis,        thres, [('out_file', 'in_file')]),
        (thres,      merge_dwi, [('out_file', 'in_files')]),
        (reg,          tfx_msk, [
            (('transform', _get_last), 'transform_file')]),
        (tfx_b0,         swarp, [('output_file', 'transform_file')]),
        (avg_b0,       tfx_msk, [('out_file', 'reference_image')]),
        (tfx_msk,     corr_msk, [('output_file', 'transform_file')]),
        (corr_msk,     closmsk, [('warped_file', 'in_file')]),
        (merge_dwi, outputnode, [('merged_file', 'dwi')]),
        (closmsk,   outputnode, [('out_file', 'dwi_mask')]),
        (warp_prop, outputnode, [('jacdet_map', 'jacobian')]),
        (swarp,     outputnode, [('warped_file', 'out_surf')])
    ])

    if icorr:
        # Multiply each unwarped volume by the masked Jacobian determinant.
        jac_mask = pe.Node(fs.ApplyMask(), name='mask_jac')
        mult = pe.MapNode(MultiImageMaths(op_string='-mul %s'),
                          iterfield=['in_file'], name='ModulateDWIs')
        wf.connect([
            (closmsk,      jac_mask, [('out_file', 'mask_file')]),
            (warp_prop,    jac_mask, [('jacdet_map', 'in_file')]),
            (warp,             mult, [('warped_file', 'in_file')]),
            (jac_mask,         mult, [('out_file', 'operand_files')]),
            (mult,         warpbuff, [('out_file', 'unwarped')])
        ])
    else:
        wf.connect([
            (warp,         warpbuff, [('warped_file', 'unwarped')])
        ])

    return wf
                 name='Pre_Merge_Tissues')

# Add the C1, C2 and C3 tissue maps, keep voxels above 0.05 and binarise.
# The two "%s" placeholders are filled from the node's operand files.
merge_tissues = Node(
    MultiImageMaths(op_string="-add %s -add %s -thr 0.05 -bin"),
    name="Merge_C1_C2_C3")

# Fill holes in the merged binary mask.
fill_mask = Node(UnaryMaths(operation="fillh"), name="FillHoles_Mask")

# One ApplyMask node per image to be masked.
apply_mask_t1 = Node(ApplyMask(), name="ApplyMask_T1")
apply_mask_flair = Node(ApplyMask(), name="ApplyMask_FLAIR")
apply_mask_swi = Node(ApplyMask(), name="ApplyMask_SWI")
apply_mask_bct1 = Node(ApplyMask(), name="ApplyMask_BiasCorrect_T1")

###SNR
# Binary masks for tissues 1-3: threshold at 0.1, then "-bin".
con_tissue_mask_1 = Node(Threshold(thresh=0.1, args="-bin"),
                         name="Tissue1_Mask")
con_tissue_mask_2 = Node(Threshold(thresh=0.1, args="-bin"),
                         name="Tissue2_Mask")
con_tissue_mask_3 = Node(Threshold(thresh=0.1, args="-bin"),
                         name="Tissue3_Mask")


def extract_tissue_c12345(c1, c2, c3, c4, c5):
    """Separate the first tissue input from the remaining four.

    :returns: a 2-tuple ``(c1, [c2, c3, c4, c5])``. The second element is a
      new list. Presumably used to feed a node that takes one primary file
      plus a list of operands; confirm with the caller.
    """
    others = list((c2, c3, c4, c5))
    return c1, others