Exemple #1
0
def execute():  #pylint: disable=unused-variable
    """Estimate single-fibre WM, GM and CSF response functions from ``dwi.mif``.

    Stages, in order: validate the shells and the optional ``-lmax`` list;
    erode the brain mask and drop voxels with non-finite / non-positive mean
    signal; compute a signal decay metric (SDM); perform a crude and then a
    refined WM / GM / CSF segmentation; select the final voxels per tissue and
    run ``amp2response`` for each; finally copy the response files (and, if
    requested, a 4D voxel-selection image) to the user's output paths.

    All intermediate images are read and written by relative file name
    (``dwi.mif``, ``mask.mif``, ...); presumably the working directory is the
    script's scratch directory at this point -- confirm against the caller.

    Raises:
        MRtrixError: if fewer than 2 unique b-values are present, or if the
            ``-lmax`` values do not match the number of b-values or are odd or
            negative.
    """
    # b-values below this threshold are written as isotropic (l=0 only) shells
    # of the initial WM response used by the built-in algorithm when no -lmax
    # is given (see `isiso` further down).
    bzero_threshold = float(
        CONFIG['BZeroThreshold']) if 'BZeroThreshold' in CONFIG else 10.0

    # CHECK INPUTS AND OPTIONS
    app.console('-------')

    # Get b-values and number of volumes per b-value.
    bvalues = [
        int(round(float(x)))
        for x in image.mrinfo('dwi.mif', 'shell_bvalues').split()
    ]
    bvolumes = [int(x) for x in image.mrinfo('dwi.mif', 'shell_sizes').split()]
    app.console(
        str(len(bvalues)) + ' unique b-value(s) detected: ' +
        ','.join(map(str, bvalues)) + ' with ' + ','.join(map(str, bvolumes)) +
        ' volumes')
    if len(bvalues) < 2:
        raise MRtrixError('Need at least 2 unique b-values (including b=0).')
    # Shell selection passed to every dwi2fod / amp2response call below.
    bvalues_option = ' -shells ' + ','.join(map(str, bvalues))

    # Get lmax information (if provided).
    sfwm_lmax = []
    if app.ARGS.lmax:
        sfwm_lmax = [int(x.strip()) for x in app.ARGS.lmax.split(',')]
        if not len(sfwm_lmax) == len(bvalues):
            raise MRtrixError('Number of lmax\'s (' + str(len(sfwm_lmax)) +
                              ', as supplied to the -lmax option: ' +
                              ','.join(map(str, sfwm_lmax)) +
                              ') does not match number of unique b-values.')
        for sfl in sfwm_lmax:
            if sfl % 2:
                raise MRtrixError(
                    'Values supplied to the -lmax option must be even.')
            if sfl < 0:
                raise MRtrixError(
                    'Values supplied to the -lmax option must be non-negative.'
                )
    # Only the single-fibre WM amp2response call uses the user's lmax list.
    sfwm_lmax_option = ''
    if sfwm_lmax:
        sfwm_lmax_option = ' -lmax ' + ','.join(map(str, sfwm_lmax))

    # PREPARATION
    app.console('-------')
    app.console('Preparation:')

    # Erode (brain) mask.
    if app.ARGS.erode > 0:
        app.console('* Eroding brain mask by ' + str(app.ARGS.erode) +
                    ' pass(es)...')
        run.command('maskfilter mask.mif erode eroded_mask.mif -npass ' +
                    str(app.ARGS.erode),
                    show=False)
    else:
        app.console('Not eroding brain mask.')
        run.command('mrconvert mask.mif eroded_mask.mif -datatype bit',
                    show=False)
    statmaskcount = image.statistics('mask.mif', mask='mask.mif').count
    statemaskcount = image.statistics('eroded_mask.mif',
                                      mask='eroded_mask.mif').count
    app.console('  [ mask: ' + str(statmaskcount) + ' -> ' +
                str(statemaskcount) + ' ]')

    # Get volumes, compute mean signal and SDM per b-value; compute overall SDM; get rid of erroneous values.
    # Per shell: mean signal with negative values clamped to 0, plus an error
    # mask flagging voxels whose mean is non-finite or <= 0. For every shell
    # after the first, log(mean_first / mean_shell) is that shell's SDM; the
    # overall SDM is the volume-count-weighted average of these. `errcmd` and
    # `fullsdmcmd` accumulate one mrcalc RPN expression each across the loop.
    app.console('* Computing signal decay metric (SDM):')
    totvolumes = 0
    fullsdmcmd = 'mrcalc'
    errcmd = 'mrcalc'
    # Reference shell: the first listed b-value (presumably b=0) -- confirm.
    zeropath = 'mean_b' + str(bvalues[0]) + '.mif'
    for ibv, bval in enumerate(bvalues):
        app.console(' * b=' + str(bval) + '...')
        meanpath = 'mean_b' + str(bval) + '.mif'
        run.command('dwiextract dwi.mif -shells ' + str(bval) +
                    ' - | mrcalc - 0 -max - | mrmath - mean ' + meanpath +
                    ' -axis 3',
                    show=False)
        errpath = 'err_b' + str(bval) + '.mif'
        run.command('mrcalc ' + meanpath + ' -finite ' + meanpath +
                    ' 0 -if 0 -le ' + errpath + ' -datatype bit',
                    show=False)
        errcmd += ' ' + errpath
        if ibv > 0:
            errcmd += ' -add'
            sdmpath = 'sdm_b' + str(bval) + '.mif'
            run.command('mrcalc ' + zeropath + ' ' + meanpath +
                        ' -divide -log ' + sdmpath,
                        show=False)
            totvolumes += bvolumes[ibv]
            fullsdmcmd += ' ' + sdmpath + ' ' + str(bvolumes[ibv]) + ' -mult'
            # The first weighted term has nothing to be added to yet.
            if ibv > 1:
                fullsdmcmd += ' -add'
    fullsdmcmd += ' ' + str(totvolumes) + ' -divide full_sdm.mif'
    run.command(fullsdmcmd, show=False)
    app.console('* Removing erroneous voxels from mask and correcting SDM...')
    run.command(
        'mrcalc full_sdm.mif -finite full_sdm.mif 0 -if 0 -le err_sdm.mif -datatype bit',
        show=False)
    # safe_mask = eroded mask minus any voxel flagged in any error mask.
    errcmd += ' err_sdm.mif -add 0 eroded_mask.mif -if safe_mask.mif -datatype bit'
    run.command(errcmd, show=False)
    # safe_sdm = SDM inside safe_mask (0 elsewhere), capped at 10.
    run.command('mrcalc safe_mask.mif full_sdm.mif 0 -if 10 -min safe_sdm.mif',
                show=False)
    statsmaskcount = image.statistics('safe_mask.mif',
                                      mask='safe_mask.mif').count
    app.console('  [ mask: ' + str(statemaskcount) + ' -> ' +
                str(statsmaskcount) + ' ]')

    # CRUDE SEGMENTATION
    app.console('-------')
    app.console('Crude segmentation:')

    # Compute FA and principal eigenvectors; crude WM versus GM-CSF separation based on FA.
    app.console('* Crude WM versus GM-CSF separation (at FA=' +
                str(app.ARGS.fa) + ')...')
    run.command(
        'dwi2tensor dwi.mif - -mask safe_mask.mif | tensor2metric - -fa safe_fa.mif -vector safe_vecs.mif -modulate none -mask safe_mask.mif',
        show=False)
    run.command('mrcalc safe_mask.mif safe_fa.mif 0 -if ' + str(app.ARGS.fa) +
                ' -gt crude_wm.mif -datatype bit',
                show=False)
    run.command(
        'mrcalc crude_wm.mif 0 safe_mask.mif -if _crudenonwm.mif -datatype bit',
        show=False)
    statcrudewmcount = image.statistics('crude_wm.mif',
                                        mask='crude_wm.mif').count
    statcrudenonwmcount = image.statistics('_crudenonwm.mif',
                                           mask='_crudenonwm.mif').count
    app.console('  [ ' + str(statsmaskcount) + ' -> ' + str(statcrudewmcount) +
                ' (WM) & ' + str(statcrudenonwmcount) + ' (GM-CSF) ]')

    # Crude GM versus CSF separation based on SDM.
    # Within non-WM, SDM is centred on its median and split by mrthreshold's
    # automatic threshold (no explicit value given); the upper part is crude
    # CSF and the remainder crude GM.
    app.console('* Crude GM versus CSF separation...')
    crudenonwmmedian = image.statistics('safe_sdm.mif',
                                        mask='_crudenonwm.mif').median
    run.command(
        'mrcalc _crudenonwm.mif safe_sdm.mif ' + str(crudenonwmmedian) +
        ' -subtract 0 -if - | mrthreshold - - -mask _crudenonwm.mif | mrcalc _crudenonwm.mif - 0 -if crude_csf.mif -datatype bit',
        show=False)
    run.command(
        'mrcalc crude_csf.mif 0 _crudenonwm.mif -if crude_gm.mif -datatype bit',
        show=False)
    statcrudegmcount = image.statistics('crude_gm.mif',
                                        mask='crude_gm.mif').count
    statcrudecsfcount = image.statistics('crude_csf.mif',
                                         mask='crude_csf.mif').count
    app.console('  [ ' + str(statcrudenonwmcount) + ' -> ' +
                str(statcrudegmcount) + ' (GM) & ' + str(statcrudecsfcount) +
                ' (CSF) ]')

    # REFINED SEGMENTATION
    app.console('-------')
    app.console('Refined segmentation:')

    # Refine WM: remove high SDM outliers.
    app.console('* Refining WM...')
    crudewmmedian = image.statistics('safe_sdm.mif',
                                     mask='crude_wm.mif').median
    # Absolute deviation from the WM median; its median (below) is the MAD.
    run.command('mrcalc crude_wm.mif safe_sdm.mif ' + str(crudewmmedian) +
                ' -subtract -abs 0 -if _crudewm_sdmad.mif',
                show=False)
    crudewmmad = image.statistics('_crudewm_sdmad.mif',
                                  mask='crude_wm.mif').median
    # 1.4826 * MAD approximates a standard deviation under a normality
    # assumption, so the cut-off is median + 2 robust SDs.
    crudewmoutlthresh = crudewmmedian + (1.4826 * crudewmmad * 2.0)
    run.command('mrcalc crude_wm.mif safe_sdm.mif 0 -if ' +
                str(crudewmoutlthresh) +
                ' -gt _crudewmoutliers.mif -datatype bit',
                show=False)
    run.command(
        'mrcalc _crudewmoutliers.mif 0 crude_wm.mif -if refined_wm.mif -datatype bit',
        show=False)
    statrefwmcount = image.statistics('refined_wm.mif',
                                      mask='refined_wm.mif').count
    app.console('  [ WM: ' + str(statcrudewmcount) + ' -> ' +
                str(statrefwmcount) + ' ]')

    # Refine GM: separate safer GM from partial volumed voxels.
    # Crude GM is split at its SDM median into a "high" and a "low" half; in
    # each half, the distance from the median is thresholded automatically
    # with -invert, presumably keeping the voxels nearer the median. The two
    # selections are then merged into refined_gm.
    app.console('* Refining GM...')
    crudegmmedian = image.statistics('safe_sdm.mif',
                                     mask='crude_gm.mif').median
    run.command('mrcalc crude_gm.mif safe_sdm.mif 0 -if ' +
                str(crudegmmedian) + ' -gt _crudegmhigh.mif -datatype bit',
                show=False)
    run.command(
        'mrcalc _crudegmhigh.mif 0 crude_gm.mif -if _crudegmlow.mif -datatype bit',
        show=False)
    run.command(
        'mrcalc _crudegmhigh.mif safe_sdm.mif ' + str(crudegmmedian) +
        ' -subtract 0 -if - | mrthreshold - - -mask _crudegmhigh.mif -invert | mrcalc _crudegmhigh.mif - 0 -if _crudegmhighselect.mif -datatype bit',
        show=False)
    run.command(
        'mrcalc _crudegmlow.mif safe_sdm.mif ' + str(crudegmmedian) +
        ' -subtract -neg 0 -if - | mrthreshold - - -mask _crudegmlow.mif -invert | mrcalc _crudegmlow.mif - 0 -if _crudegmlowselect.mif -datatype bit',
        show=False)
    run.command(
        'mrcalc _crudegmhighselect.mif 1 _crudegmlowselect.mif -if refined_gm.mif -datatype bit',
        show=False)
    statrefgmcount = image.statistics('refined_gm.mif',
                                      mask='refined_gm.mif').count
    app.console('  [ GM: ' + str(statcrudegmcount) + ' -> ' +
                str(statrefgmcount) + ' ]')

    # Refine CSF: recover lost CSF from crude WM SDM outliers, separate safer CSF from partial volumed voxels.
    # WM outliers whose SDM exceeds the smallest crude-CSF SDM are added to
    # crude CSF; an automatic threshold on (SDM - that minimum) then keeps
    # the upper part as refined CSF.
    app.console('* Refining CSF...')
    crudecsfmin = image.statistics('safe_sdm.mif', mask='crude_csf.mif').min
    run.command('mrcalc _crudewmoutliers.mif safe_sdm.mif 0 -if ' +
                str(crudecsfmin) +
                ' -gt 1 crude_csf.mif -if _crudecsfextra.mif -datatype bit',
                show=False)
    run.command(
        'mrcalc _crudecsfextra.mif safe_sdm.mif ' + str(crudecsfmin) +
        ' -subtract 0 -if - | mrthreshold - - -mask _crudecsfextra.mif | mrcalc _crudecsfextra.mif - 0 -if refined_csf.mif -datatype bit',
        show=False)
    statrefcsfcount = image.statistics('refined_csf.mif',
                                       mask='refined_csf.mif').count
    app.console('  [ CSF: ' + str(statcrudecsfcount) + ' -> ' +
                str(statrefcsfcount) + ' ]')

    # FINAL VOXEL SELECTION AND RESPONSE FUNCTION ESTIMATION
    app.console('-------')
    app.console('Final voxel selection and response function estimation:')

    # Get final voxels for CSF response function estimation from refined CSF.
    # Keeps the app.ARGS.csf percent of refined-CSF voxels with the highest SDM.
    app.console('* CSF:')
    app.console(' * Selecting final voxels (' + str(app.ARGS.csf) +
                '% of refined CSF)...')
    voxcsfcount = int(round(statrefcsfcount * app.ARGS.csf / 100.0))
    run.command(
        'mrcalc refined_csf.mif safe_sdm.mif 0 -if - | mrthreshold - - -top ' +
        str(voxcsfcount) +
        ' -ignorezero | mrcalc refined_csf.mif - 0 -if - -datatype bit | mrconvert - voxels_csf.mif -axes 0,1,2',
        show=False)
    statvoxcsfcount = image.statistics('voxels_csf.mif',
                                       mask='voxels_csf.mif').count
    app.console('   [ CSF: ' + str(statrefcsfcount) + ' -> ' +
                str(statvoxcsfcount) + ' ]')
    # Estimate CSF response function
    app.console(' * Estimating response function...')
    run.command(
        'amp2response dwi.mif voxels_csf.mif safe_vecs.mif response_csf.txt' +
        bvalues_option + ' -isotropic',
        show=False)

    # Get final voxels for GM response function estimation from refined GM.
    # Keeps the app.ARGS.gm percent of refined-GM voxels whose SDM is closest to
    # the refined-GM median. The "+1" offsets |SDM - median| away from zero,
    # presumably so that -ignorezero does not discard voxels exactly at the
    # median -- confirm.
    app.console('* GM:')
    app.console(' * Selecting final voxels (' + str(app.ARGS.gm) +
                '% of refined GM)...')
    voxgmcount = int(round(statrefgmcount * app.ARGS.gm / 100.0))
    refgmmedian = image.statistics('safe_sdm.mif',
                                   mask='refined_gm.mif').median
    run.command(
        'mrcalc refined_gm.mif safe_sdm.mif ' + str(refgmmedian) +
        ' -subtract -abs 1 -add 0 -if - | mrthreshold - - -bottom ' +
        str(voxgmcount) +
        ' -ignorezero | mrcalc refined_gm.mif - 0 -if - -datatype bit | mrconvert - voxels_gm.mif -axes 0,1,2',
        show=False)
    statvoxgmcount = image.statistics('voxels_gm.mif',
                                      mask='voxels_gm.mif').count
    app.console('   [ GM: ' + str(statrefgmcount) + ' -> ' +
                str(statvoxgmcount) + ' ]')
    # Estimate GM response function
    app.console(' * Estimating response function...')
    run.command(
        'amp2response dwi.mif voxels_gm.mif safe_vecs.mif response_gm.txt' +
        bvalues_option + ' -isotropic',
        show=False)

    # Get final voxels for single-fibre WM response function estimation from refined WM.
    app.console('* Single-fibre WM:')
    app.console(' * Selecting final voxels' +
                ('' if app.ARGS.wm_algo == 'tax' else
                 (' (' + str(app.ARGS.sfwm) + '% of refined WM)')) + '...')
    voxsfwmcount = int(round(statrefwmcount * app.ARGS.sfwm / 100.0))

    if app.ARGS.wm_algo:
        # Delegate single-fibre selection to another dwi2response algorithm,
        # restricted to refined WM and sharing this script's scratch directory.
        recursive_cleanup_option = ''
        if not app.DO_CLEANUP:
            recursive_cleanup_option = ' -nocleanup'
        app.console('   Selecting WM single-fibre voxels using \'' +
                    app.ARGS.wm_algo + '\' algorithm')
        if app.ARGS.wm_algo == 'tax' and app.ARGS.sfwm != 0.5:
            app.warn(
                'Single-fibre WM response function selection algorithm "tax" will not honour requested WM voxel percentage'
            )
        run.command(
            'dwi2response ' + app.ARGS.wm_algo +
            ' dwi.mif _respsfwmss.txt -mask refined_wm.mif -voxels voxels_sfwm.mif'
            + ('' if app.ARGS.wm_algo == 'tax' else
               (' -number ' + str(voxsfwmcount))) + ' -scratch ' +
            path.quote(app.SCRATCH_DIR) + recursive_cleanup_option,
            show=False)
    else:
        app.console(
            '   Selecting WM single-fibre voxels using built-in (Dhollander et al., 2019) algorithm'
        )
        run.command('mrmath dwi.mif mean mean_sig.mif -axis 3', show=False)
        # Median mean signal in refined WM scaled by sqrt(4*pi); presumably
        # this converts a signal amplitude into an l=0 SH coefficient -- confirm.
        refwmcoef = image.statistics('mean_sig.mif',
                                     mask='refined_wm.mif').median * math.sqrt(
                                         4.0 * math.pi)
        # Which shells of the initial WM response are isotropic: lmax == 0 if
        # -lmax was given, otherwise b-value below bzero_threshold.
        if sfwm_lmax:
            isiso = [lm == 0 for lm in sfwm_lmax]
        else:
            isiso = [bv < bzero_threshold for bv in bvalues]
        # Write one line per shell of an initial ("ewm") WM response: l=0 only
        # for isotropic shells, otherwise four coefficients with alternating
        # sign (refwmcoef, -refwmcoef, ...).
        with open('ewmrf.txt', 'w') as ewr:
            for iis in isiso:
                if iis:
                    ewr.write("%s 0 0 0\n" % refwmcoef)
                else:
                    ewr.write("%s -%s %s -%s\n" %
                              (refwmcoef, refwmcoef, refwmcoef, refwmcoef))
        # Pass 1 (lmax 2,0 over refined WM): score each voxel by its single
        # peak amplitude divided by (WM l=0 + CSF) and keep the top 2x target.
        run.command(
            'dwi2fod msmt_csd dwi.mif ewmrf.txt abs_ewm2.mif response_csf.txt abs_csf2.mif -mask refined_wm.mif -lmax 2,0'
            + bvalues_option,
            show=False)
        run.command(
            'mrconvert abs_ewm2.mif - -coord 3 0 | mrcalc - abs_csf2.mif -add abs_sum2.mif',
            show=False)
        run.command(
            'sh2peaks abs_ewm2.mif - -num 1 -mask refined_wm.mif | peaks2amp - - | mrcalc - abs_sum2.mif -divide - | mrconvert - metric_sfwm2.mif -coord 3 0 -axes 0,1,2',
            show=False)
        run.command(
            'mrcalc refined_wm.mif metric_sfwm2.mif 0 -if - | mrthreshold - - -top '
            + str(voxsfwmcount * 2) +
            ' -ignorezero | mrcalc refined_wm.mif - 0 -if - -datatype bit | mrconvert - refined_sfwm.mif -axes 0,1,2',
            show=False)
        # Pass 2 (lmax 6,0 over the pass-1 selection): same metric, keep the
        # top voxsfwmcount voxels as the final single-fibre WM selection.
        run.command(
            'dwi2fod msmt_csd dwi.mif ewmrf.txt abs_ewm6.mif response_csf.txt abs_csf6.mif -mask refined_sfwm.mif -lmax 6,0'
            + bvalues_option,
            show=False)
        run.command(
            'mrconvert abs_ewm6.mif - -coord 3 0 | mrcalc - abs_csf6.mif -add abs_sum6.mif',
            show=False)
        run.command(
            'sh2peaks abs_ewm6.mif - -num 1 -mask refined_sfwm.mif | peaks2amp - - | mrcalc - abs_sum6.mif -divide - | mrconvert - metric_sfwm6.mif -coord 3 0 -axes 0,1,2',
            show=False)
        run.command(
            'mrcalc refined_sfwm.mif metric_sfwm6.mif 0 -if - | mrthreshold - - -top '
            + str(voxsfwmcount) +
            ' -ignorezero | mrcalc refined_sfwm.mif - 0 -if - -datatype bit | mrconvert - voxels_sfwm.mif -axes 0,1,2',
            show=False)

    statvoxsfwmcount = image.statistics('voxels_sfwm.mif',
                                        mask='voxels_sfwm.mif').count
    app.console('   [ WM: ' + str(statrefwmcount) + ' -> ' +
                str(statvoxsfwmcount) + ' (single-fibre) ]')
    # Estimate SF WM response function
    app.console(' * Estimating response function...')
    run.command(
        'amp2response dwi.mif voxels_sfwm.mif safe_vecs.mif response_sfwm.txt'
        + bvalues_option + sfwm_lmax_option,
        show=False)

    # OUTPUT AND SUMMARY
    app.console('-------')
    app.console('Generating outputs...')

    # Generate 4D binary images with voxel selections at major stages in algorithm (RGB: WM=blue, GM=green, CSF=red).
    run.command(
        'mrcat crude_csf.mif crude_gm.mif crude_wm.mif check_crude.mif -axis 3',
        show=False)
    run.command(
        'mrcat refined_csf.mif refined_gm.mif refined_wm.mif check_refined.mif -axis 3',
        show=False)
    run.command(
        'mrcat voxels_csf.mif voxels_gm.mif voxels_sfwm.mif check_voxels.mif -axis 3',
        show=False)

    # Copy results to output files
    run.function(shutil.copyfile,
                 'response_sfwm.txt',
                 path.from_user(app.ARGS.out_sfwm, False),
                 show=False)
    run.function(shutil.copyfile,
                 'response_gm.txt',
                 path.from_user(app.ARGS.out_gm, False),
                 show=False)
    run.function(shutil.copyfile,
                 'response_csf.txt',
                 path.from_user(app.ARGS.out_csf, False),
                 show=False)
    # Optional 4D voxel-selection export, with header key-values taken from
    # the user's input image.
    if app.ARGS.voxels:
        run.command('mrconvert check_voxels.mif ' +
                    path.from_user(app.ARGS.voxels),
                    mrconvert_keyval=path.from_user(app.ARGS.input, False),
                    force=app.FORCE_OVERWRITE,
                    show=False)
    app.console('-------')
Exemple #2
0
 def quote_nonpipe(item):
     """Return *item* unchanged if it is the pipe token '|'; otherwise return path.quote(item)."""
     if item != '|':
         return path.quote(item)
     return item
Exemple #3
0
def execute(): #pylint: disable=unused-variable
  """Estimate WM, GM and CSF response functions using a 5TT image for tissue masks.

  Tissue masks come from 5TT volumes 2 (WM), 0 (GM) and 3 (CSF) after
  regridding onto the FA image. The GM and CSF masks are restricted to FA below
  -fa, and the WM mask is reduced to single-fibre voxels by a nested
  dwi2response call. Files are addressed by relative name; presumably the
  working directory is the scratch directory -- confirm against the caller.

  Raises:
    MRtrixError: if -lmax does not list one even, non-negative value per
      shell, or if any of the WM / GM / CSF masks ends up empty.
  """
  # Ideally want to use the oversampling-based regridding of the 5TT image from the SIFT model, not mrtransform
  # May need to commit 5ttregrid...

  # Verify input 5tt image; 5ttcheck's stderr is inspected whether or not the
  # command itself signalled failure.
  try:
    check_log = run.command('5ttcheck 5tt.mif').stderr
  except run.MRtrixCmdError as check_failure:
    check_log = check_failure.stderr
  if any(flag in check_log for flag in ('WARNING', 'ERROR')):
    app.warn('Command 5ttcheck indicates problems with provided input 5TT image \'' + app.ARGS.in_5tt + '\':')
    for log_line in check_log.splitlines():
      app.warn(log_line)
    app.warn('These may or may not interfere with the dwi2response msmt_5tt script')

  # Get shell information
  raw_bvalues = image.mrinfo('dwi.mif', 'shell_bvalues').split()
  shells = [ int(round(float(raw))) for raw in raw_bvalues ]
  if len(shells) < 3:
    app.warn('Less than three b-values; response functions will not be applicable in resolving three tissues using MSMT-CSD algorithm')

  # Get lmax information (if provided): one even, non-negative value per shell
  wm_lmax = [ ]
  if app.ARGS.lmax:
    wm_lmax = [ int(token.strip()) for token in app.ARGS.lmax.split(',') ]
    if len(wm_lmax) != len(shells):
      raise MRtrixError('Number of manually-defined lmax\'s (' + str(len(wm_lmax)) + ') does not match number of b-values (' + str(len(shells)) + ')')
    for lmax_value in wm_lmax:
      if lmax_value % 2:
        raise MRtrixError('Values for lmax must be even')
      if lmax_value < 0:
        raise MRtrixError('Values for lmax must be non-negative')

  run.command('dwi2tensor dwi.mif - -mask mask.mif | tensor2metric - -fa fa.mif -vector vector.mif')
  if not os.path.exists('dirs.mif'):
    run.function(shutil.copy, 'vector.mif', 'dirs.mif')
  run.command('mrtransform 5tt.mif 5tt_regrid.mif -template fa.mif -interp linear')

  # Basic tissue masks: partial volume above -pvf (and, for GM / CSF, FA below -fa)
  pvf_text = str(app.ARGS.pvf)
  fa_text = str(app.ARGS.fa)
  run.command('mrconvert 5tt_regrid.mif - -coord 3 2 -axes 0,1,2 | mrcalc - ' + pvf_text + ' -gt mask.mif -mult wm_mask.mif')
  for volume_index, mask_name in ((0, 'gm_mask.mif'), (3, 'csf_mask.mif')):
    run.command('mrconvert 5tt_regrid.mif - -coord 3 ' + str(volume_index) + ' -axes 0,1,2 | mrcalc - ' + pvf_text + ' -gt fa.mif ' + fa_text + ' -lt -mult mask.mif -mult ' + mask_name)

  # Revise WM mask to only include single-fibre voxels
  cleanup_suffix = '' if app.DO_CLEANUP else ' -nocleanup'
  scratch_option = ' -scratch ' + path.quote(app.SCRATCH_DIR) + cleanup_suffix
  if app.ARGS.sfwm_fa_threshold:
    app.console('Selecting WM single-fibre voxels using \'fa\' algorithm with a hard FA threshold of ' + str(app.ARGS.sfwm_fa_threshold))
    run.command('dwi2response fa dwi.mif wm_ss_response.txt -mask wm_mask.mif -threshold ' + str(app.ARGS.sfwm_fa_threshold) + ' -voxels wm_sf_mask.mif' + scratch_option)
  else:
    app.console('Selecting WM single-fibre voxels using \'' + app.ARGS.wm_algo + '\' algorithm')
    run.command('dwi2response ' + app.ARGS.wm_algo + ' dwi.mif wm_ss_response.txt -mask wm_mask.mif -voxels wm_sf_mask.mif' + scratch_option)

  # Check for empty masks (counted in WM, GM, CSF order)
  tissue_counts = [ (label, image.statistics(mask_file, mask=mask_file).count)
                    for label, mask_file in (('WM', 'wm_sf_mask.mif'),
                                             ('GM', 'gm_mask.mif'),
                                             ('CSF', 'csf_mask.mif')) ]
  empty_masks = [ label for label, voxel_count in tissue_counts if not voxel_count ]
  if empty_masks:
    plural = 's' if len(empty_masks) > 1 else ''
    raise MRtrixError(','.join(empty_masks) + ' tissue mask' + plural + ' empty; cannot estimate response function' + plural)

  # For each of the three tissues, generate a multi-shell response
  bvalues_option = ' -shells ' + ','.join(str(shell) for shell in shells)
  sfwm_lmax_option = (' -lmax ' + ','.join(map(str, wm_lmax))) if wm_lmax else ''
  run.command('amp2response dwi.mif wm_sf_mask.mif dirs.mif wm.txt' + bvalues_option + sfwm_lmax_option)
  for tissue in ('gm', 'csf'):
    run.command('amp2response dwi.mif ' + tissue + '_mask.mif dirs.mif ' + tissue + '.txt' + bvalues_option + ' -isotropic')
  for tissue, user_path in (('wm', app.ARGS.out_wm), ('gm', app.ARGS.out_gm), ('csf', app.ARGS.out_csf)):
    run.function(shutil.copyfile, tissue + '.txt', path.from_user(user_path, False))

  # Generate output 4D binary image with voxel selections; RGB as in MSMT-CSD paper
  run.command('mrcat csf_mask.mif gm_mask.mif wm_sf_mask.mif voxels.mif -axis 3')
  if app.ARGS.voxels:
    run.command('mrconvert voxels.mif ' + path.from_user(app.ARGS.voxels), mrconvert_keyval=path.from_user(app.ARGS.input, False), force=app.FORCE_OVERWRITE)
Exemple #4
0
def _strip_common_suffix(names):
    """Return *names* with the suffix shared by all of them removed.

    Only the trailing occurrence of the shared suffix is removed, so a name in
    which that suffix also appears earlier is not truncated there. If the
    names share no suffix they are returned unchanged (``str.split('')``
    would raise ``ValueError`` in that case).
    """
    suffix = os.path.commonprefix([name[::-1] for name in names])[::-1]
    return [name[:len(name) - len(suffix)] for name in names]


def execute():  #pylint: disable=unused-variable
    """Intensity-normalise a group of DWIs using a WM mask from an FA template.

    For every image in ``-input_dir`` (paired with a mask in ``-mask_dir`` by
    file name prefix):

    1. compute an FA image (``dwi2tensor`` + ``tensor2metric``);
    2. build an FA population template (rigid, affine, then nonlinear);
    3. threshold the template at ``-fa_threshold`` to get a template WM mask;
    4. warp that mask back into each subject and run
       ``dwinormalise individual`` with it, writing to ``-output_dir``;
    5. export the template WM mask and the FA template.

    Subjects are matched by stripping, in each directory, the suffix common
    to all file names and comparing what remains.

    Raises:
        MRtrixError: if the input or mask directory does not exist, fewer than
            two input images are found, the number of masks differs from the
            number of inputs, or an input image has no matching mask.
    """
    class Input(object):
        """An input image paired with its subject prefix and mask file name."""
        def __init__(self, filename, prefix, mask_filename=''):
            self.filename = filename
            self.prefix = prefix
            self.mask_filename = mask_filename

    input_dir = path.from_user(app.ARGS.input_dir, False)
    if not os.path.exists(input_dir):
        raise MRtrixError('input directory not found')
    in_files = path.all_in_dir(input_dir, dir_path=False)
    if len(in_files) <= 1:
        raise MRtrixError(
            'not enough images found in input directory: more than one image is needed to perform a group-wise intensity normalisation'
        )

    app.console('performing global intensity normalisation on ' +
                str(len(in_files)) + ' input images')

    mask_dir = path.from_user(app.ARGS.mask_dir, False)
    if not os.path.exists(mask_dir):
        raise MRtrixError('mask directory not found')
    mask_files = path.all_in_dir(mask_dir, dir_path=False)
    if len(mask_files) != len(in_files):
        raise MRtrixError(
            'the number of images in the mask directory does not equal the number of images in the input directory'
        )

    # Map subject prefix -> mask file. setdefault keeps the first mask for a
    # duplicated prefix, matching the previous list.index() lookup.
    mask_by_prefix = {}
    for mask_prefix, mask_file in zip(_strip_common_suffix(mask_files),
                                      mask_files):
        mask_by_prefix.setdefault(mask_prefix, mask_file)

    input_list = []
    for in_file, subj_prefix in zip(in_files, _strip_common_suffix(in_files)):
        if subj_prefix not in mask_by_prefix:
            raise MRtrixError(
                'no matching mask image was found for input image ' + in_file)
        image.check_3d_nonunity(os.path.join(input_dir, in_file))
        input_list.append(
            Input(in_file, subj_prefix, mask_by_prefix[subj_prefix]))

    app.make_scratch_dir()
    app.goto_scratch_dir()

    # Scratch paths embed subject prefixes, which are derived from user file
    # names, so they are quoted like the user paths themselves.
    path.make_dir('fa')
    progress = app.ProgressBar('Computing FA images', len(input_list))
    for i in input_list:
        run.command('dwi2tensor ' +
                    path.quote(os.path.join(input_dir, i.filename)) +
                    ' -mask ' +
                    path.quote(os.path.join(mask_dir, i.mask_filename)) +
                    ' - | tensor2metric - -fa ' +
                    path.quote(os.path.join('fa', i.prefix + '.mif')))
        progress.increment()
    progress.done()

    app.console('Generating FA population template')
    # mask_dir is user-supplied: quote it so directories containing spaces do
    # not break the command line (it was previously concatenated raw).
    run.command('population_template fa fa_template.mif' + ' -mask_dir ' +
                path.quote(mask_dir) + ' -type rigid_affine_nonlinear' +
                ' -rigid_scale 0.25,0.5,0.8,1.0' +
                ' -affine_scale 0.7,0.8,1.0,1.0' +
                ' -nl_scale 0.5,0.75,1.0,1.0,1.0' + ' -nl_niter 5,5,5,5,5' +
                ' -warp_dir warps' + ' -linear_no_pause' +
                ' -scratch population_template' +
                ('' if app.DO_CLEANUP else ' -nocleanup'))

    app.console('Generating WM mask in template space')
    # str() so that a numeric fa_threshold works as well as a string one.
    run.command('mrthreshold fa_template.mif -abs ' +
                str(app.ARGS.fa_threshold) + ' template_wm_mask.mif')

    progress = app.ProgressBar('Intensity normalising subject images',
                               len(input_list))
    path.make_dir(path.from_user(app.ARGS.output_dir, False))
    path.make_dir('wm_mask_warped')
    for i in input_list:
        warped_mask = os.path.join('wm_mask_warped', i.prefix + '.mif')
        # Bring the template WM mask into this subject's FA space.
        run.command(
            'mrtransform template_wm_mask.mif -interp nearest -warp_full ' +
            path.quote(os.path.join('warps', i.prefix + '.mif')) + ' ' +
            path.quote(warped_mask) + ' -from 2 -template ' +
            path.quote(os.path.join('fa', i.prefix + '.mif')))
        run.command('dwinormalise individual ' +
                    path.quote(os.path.join(input_dir, i.filename)) + ' ' +
                    path.quote(warped_mask) + ' temp.mif')
        run.command(
            'mrconvert temp.mif ' +
            path.from_user(os.path.join(app.ARGS.output_dir, i.filename)),
            mrconvert_keyval=path.from_user(
                os.path.join(input_dir, i.filename), False),
            force=app.FORCE_OVERWRITE)
        os.remove('temp.mif')
        progress.increment()
    progress.done()

    app.console('Exporting template images to user locations')
    run.command('mrconvert template_wm_mask.mif ' +
                path.from_user(app.ARGS.wm_mask),
                mrconvert_keyval='NULL',
                force=app.FORCE_OVERWRITE)
    run.command('mrconvert fa_template.mif ' +
                path.from_user(app.ARGS.fa_template),
                mrconvert_keyval='NULL',
                force=app.FORCE_OVERWRITE)
def execute(): #pylint: disable=unused-variable
  """Run the 'dhollander' multi-tissue response function estimation algorithm.

  Operates on 'dwi.mif' and 'mask.mif' in the current working directory
  (presumably the scratch directory prepared by the calling dwi2response
  framework -- confirm against the caller). Voxels are segmented into crude,
  then refined, WM / GM / CSF sets using FA and a signal decay metric (SDM).
  Three response functions are then estimated:

    * CSF and GM (isotropic) via amp2response;
    * single-fibre WM via the TOURNIER algorithm followed by amp2response.

  The responses are written to app.ARGS.out_sfwm, app.ARGS.out_gm and
  app.ARGS.out_csf with a 'b-values' header. If app.ARGS.voxels is set, a 4D
  image of the final voxel selections is also exported there.

  Raises:
    MRtrixError: if fewer than 2 unique b-values are present, or if the
      -lmax option supplies a number of values different from the number of
      shells, an odd value, or a negative value.
  """
  from mrtrix3 import app, image, matrix, MRtrixError, path, run


  # CHECK INPUTS AND OPTIONS
  app.console('-------')

  # Get b-values and number of volumes per b-value.
  bvalues = [ int(round(float(x))) for x in image.mrinfo('dwi.mif', 'shell_bvalues').split() ]
  bvolumes = [ int(x) for x in image.mrinfo('dwi.mif', 'shell_sizes').split() ]
  app.console(str(len(bvalues)) + ' unique b-value(s) detected: ' + ','.join(map(str,bvalues)) + ' with ' + ','.join(map(str,bvolumes)) + ' volumes')
  if len(bvalues) < 2:
    raise MRtrixError('Need at least 2 unique b-values (including b=0).')
  # Passed to every amp2response call so all shells are used.
  bvalues_option = ' -shells ' + ','.join(map(str,bvalues))

  # Get lmax information (if provided): one even, non-negative value per shell.
  sfwm_lmax = [ ]
  if app.ARGS.lmax:
    sfwm_lmax = [ int(x.strip()) for x in app.ARGS.lmax.split(',') ]
    if len(sfwm_lmax) != len(bvalues):
      raise MRtrixError('Number of lmax\'s (' + str(len(sfwm_lmax)) + ', as supplied to the -lmax option: ' + ','.join(map(str,sfwm_lmax)) + ') does not match number of unique b-values.')
    for sfl in sfwm_lmax:
      if sfl%2:
        raise MRtrixError('Values supplied to the -lmax option must be even.')
      if sfl<0:
        raise MRtrixError('Values supplied to the -lmax option must be non-negative.')
  sfwm_lmax_option = ''
  if sfwm_lmax:
    sfwm_lmax_option = ' -lmax ' + ','.join(map(str,sfwm_lmax))


  # PREPARATION
  app.console('-------')
  app.console('Preparation:')

  # Erode (brain) mask.
  if app.ARGS.erode > 0:
    app.console('* Eroding brain mask by ' + str(app.ARGS.erode) + ' pass(es)...')
    run.command('maskfilter mask.mif erode eroded_mask.mif -npass ' + str(app.ARGS.erode), show=False)
  else:
    app.console('Not eroding brain mask.')
    run.command('mrconvert mask.mif eroded_mask.mif -datatype bit', show=False)
  statmaskcount = image.statistic('mask.mif', 'count', '-mask mask.mif')
  statemaskcount = image.statistic('eroded_mask.mif', 'count', '-mask eroded_mask.mif')
  app.console('  [ mask: ' + str(statmaskcount) + ' -> ' + str(statemaskcount) + ' ]')

  # Get volumes, compute mean signal and SDM per b-value; compute overall SDM; get rid of erroneous values.
  # The first (lowest) shell is used as the reference signal for the SDM; this
  # assumes it is the b=0 shell (see the "including b=0" error above).
  app.console('* Computing signal decay metric (SDM):')
  totvolumes = 0
  fullsdmcmd = 'mrcalc'
  errcmd = 'mrcalc'
  zeropath = 'mean_b' + str(bvalues[0]) + '.mif'
  for ibv, bval in enumerate(bvalues):
    app.console(' * b=' + str(bval) + '...')
    meanpath = 'mean_b' + str(bval) + '.mif'
    run.command('dwiextract dwi.mif -shells ' + str(bval) + ' - | mrmath - mean ' + meanpath + ' -axis 3', show=False)
    # Flag voxels whose mean signal is non-finite or <= 0; these cannot be used in the log-ratio below.
    errpath = 'err_b' + str(bval) + '.mif'
    run.command('mrcalc ' + meanpath + ' -finite ' + meanpath + ' 0 -if 0 -le ' + errpath + ' -datatype bit', show=False)
    errcmd += ' ' + errpath
    if ibv>0:
      errcmd += ' -add'
      # Per-shell SDM: log(mean reference signal / mean signal of this shell).
      sdmpath = 'sdm_b' + str(bval) + '.mif'
      run.command('mrcalc ' + zeropath + ' ' + meanpath + ' -divide -log ' + sdmpath, show=False)
      # Accumulate a volume-weighted sum of per-shell SDMs (RPN: '-add' only once two terms are on the stack).
      totvolumes += bvolumes[ibv]
      fullsdmcmd += ' ' + sdmpath + ' ' + str(bvolumes[ibv]) + ' -mult'
      if ibv>1:
        fullsdmcmd += ' -add'
  # Overall SDM = volume-weighted mean of the per-shell SDMs.
  fullsdmcmd += ' ' + str(totvolumes) + ' -divide full_sdm.mif'
  run.command(fullsdmcmd, show=False)
  app.console('* Removing erroneous voxels from mask and correcting SDM...')
  run.command('mrcalc full_sdm.mif -finite full_sdm.mif 0 -if 0 -le err_sdm.mif -datatype bit', show=False)
  # safe_mask = eroded mask minus any voxel flagged as erroneous in any shell or in the overall SDM.
  errcmd += ' err_sdm.mif -add 0 eroded_mask.mif -if safe_mask.mif -datatype bit'
  run.command(errcmd, show=False)
  # safe_sdm = SDM inside the safe mask (0 outside), capped at 10.
  run.command('mrcalc safe_mask.mif full_sdm.mif 0 -if 10 -min safe_sdm.mif', show=False)
  statsmaskcount = image.statistic('safe_mask.mif', 'count', '-mask safe_mask.mif')
  app.console('  [ mask: ' + str(statemaskcount) + ' -> ' + str(statsmaskcount) + ' ]')


  # CRUDE SEGMENTATION
  app.console('-------')
  app.console('Crude segmentation:')

  # Compute FA and principal eigenvectors; crude WM versus GM-CSF separation based on FA.
  app.console('* Crude WM versus GM-CSF separation (at FA=' + str(app.ARGS.fa) + ')...')
  run.command('dwi2tensor dwi.mif - -mask safe_mask.mif | tensor2metric - -fa safe_fa.mif -vector safe_vecs.mif -modulate none -mask safe_mask.mif', show=False)
  run.command('mrcalc safe_mask.mif safe_fa.mif 0 -if ' + str(app.ARGS.fa) + ' -gt crude_wm.mif -datatype bit', show=False)
  run.command('mrcalc crude_wm.mif 0 safe_mask.mif -if _crudenonwm.mif -datatype bit', show=False)
  statcrudewmcount = image.statistic('crude_wm.mif', 'count', '-mask crude_wm.mif')
  statcrudenonwmcount = image.statistic('_crudenonwm.mif', 'count', '-mask _crudenonwm.mif')
  app.console('  [ ' + str(statsmaskcount) + ' -> ' + str(statcrudewmcount) + ' (WM) & ' + str(statcrudenonwmcount) + ' (GM-CSF) ]')

  # Crude GM versus CSF separation based on SDM: within non-WM, threshold the
  # median-centred SDM with mrthreshold's automatic method; the upper side is CSF.
  app.console('* Crude GM versus CSF separation...')
  crudenonwmmedian = image.statistic('safe_sdm.mif', 'median', '-mask _crudenonwm.mif')
  run.command('mrcalc _crudenonwm.mif safe_sdm.mif ' + str(crudenonwmmedian) + ' -subtract 0 -if - | mrthreshold - - -mask _crudenonwm.mif | mrcalc _crudenonwm.mif - 0 -if crude_csf.mif -datatype bit', show=False)
  run.command('mrcalc crude_csf.mif 0 _crudenonwm.mif -if crude_gm.mif -datatype bit', show=False)
  statcrudegmcount = image.statistic('crude_gm.mif', 'count', '-mask crude_gm.mif')
  statcrudecsfcount = image.statistic('crude_csf.mif', 'count', '-mask crude_csf.mif')
  app.console('  [ ' + str(statcrudenonwmcount) + ' -> ' + str(statcrudegmcount) + ' (GM) & ' + str(statcrudecsfcount) + ' (CSF) ]')


  # REFINED SEGMENTATION
  app.console('-------')
  app.console('Refined segmentation:')

  # Refine WM: remove high SDM outliers, i.e. SDM above median + 2 * (1.4826 * MAD).
  # 1.4826 * MAD is the usual robust estimate of a standard deviation.
  app.console('* Refining WM...')
  crudewmmedian = image.statistic('safe_sdm.mif', 'median', '-mask crude_wm.mif')
  run.command('mrcalc crude_wm.mif safe_sdm.mif ' + str(crudewmmedian) + ' -subtract -abs 0 -if _crudewm_sdmad.mif', show=False)
  crudewmmad = image.statistic('_crudewm_sdmad.mif', 'median', '-mask crude_wm.mif')
  crudewmoutlthresh = crudewmmedian + (1.4826 * crudewmmad * 2.0)
  run.command('mrcalc crude_wm.mif safe_sdm.mif 0 -if ' + str(crudewmoutlthresh) + ' -gt _crudewmoutliers.mif -datatype bit', show=False)
  run.command('mrcalc _crudewmoutliers.mif 0 crude_wm.mif -if refined_wm.mif -datatype bit', show=False)
  statrefwmcount = image.statistic('refined_wm.mif', 'count', '-mask refined_wm.mif')
  app.console('  [ WM: ' + str(statcrudewmcount) + ' -> ' + str(statrefwmcount) + ' ]')

  # Refine GM: separate safer GM from partial volumed voxels.
  # Crude GM is split at its median SDM; on each side, only voxels below an
  # automatic threshold on |SDM - median| (i.e. closer to the median) are kept.
  app.console('* Refining GM...')
  crudegmmedian = image.statistic('safe_sdm.mif', 'median', '-mask crude_gm.mif')
  run.command('mrcalc crude_gm.mif safe_sdm.mif 0 -if ' + str(crudegmmedian) + ' -gt _crudegmhigh.mif -datatype bit', show=False)
  run.command('mrcalc _crudegmhigh.mif 0 crude_gm.mif -if _crudegmlow.mif -datatype bit', show=False)
  run.command('mrcalc _crudegmhigh.mif safe_sdm.mif ' + str(crudegmmedian) + ' -subtract 0 -if - | mrthreshold - - -mask _crudegmhigh.mif -invert | mrcalc _crudegmhigh.mif - 0 -if _crudegmhighselect.mif -datatype bit', show=False)
  run.command('mrcalc _crudegmlow.mif safe_sdm.mif ' + str(crudegmmedian) + ' -subtract -neg 0 -if - | mrthreshold - - -mask _crudegmlow.mif -invert | mrcalc _crudegmlow.mif - 0 -if _crudegmlowselect.mif -datatype bit', show=False)
  run.command('mrcalc _crudegmhighselect.mif 1 _crudegmlowselect.mif -if refined_gm.mif -datatype bit', show=False)
  statrefgmcount = image.statistic('refined_gm.mif', 'count', '-mask refined_gm.mif')
  app.console('  [ GM: ' + str(statcrudegmcount) + ' -> ' + str(statrefgmcount) + ' ]')

  # Refine CSF: recover lost CSF from crude WM SDM outliers, separate safer CSF from partial volumed voxels.
  # WM outliers whose SDM exceeds the minimum crude-CSF SDM are added to crude CSF;
  # the union is then automatically thresholded on (SDM - that minimum).
  app.console('* Refining CSF...')
  crudecsfmin = image.statistic('safe_sdm.mif', 'min', '-mask crude_csf.mif')
  run.command('mrcalc _crudewmoutliers.mif safe_sdm.mif 0 -if ' + str(crudecsfmin) + ' -gt 1 crude_csf.mif -if _crudecsfextra.mif -datatype bit', show=False)
  run.command('mrcalc _crudecsfextra.mif safe_sdm.mif ' + str(crudecsfmin) + ' -subtract 0 -if - | mrthreshold - - -mask _crudecsfextra.mif | mrcalc _crudecsfextra.mif - 0 -if refined_csf.mif -datatype bit', show=False)
  statrefcsfcount = image.statistic('refined_csf.mif', 'count', '-mask refined_csf.mif')
  app.console('  [ CSF: ' + str(statcrudecsfcount) + ' -> ' + str(statrefcsfcount) + ' ]')


  # FINAL VOXEL SELECTION AND RESPONSE FUNCTION ESTIMATION
  app.console('-------')
  app.console('Final voxel selection and response function estimation:')

  # Get final voxels for CSF response function estimation from refined CSF:
  # the requested percentage of refined CSF voxels with the highest SDM.
  app.console('* CSF:')
  app.console(' * Selecting final voxels (' + str(app.ARGS.csf) + '% of refined CSF)...')
  voxcsfcount = int(round(statrefcsfcount * app.ARGS.csf / 100.0))
  run.command('mrcalc refined_csf.mif safe_sdm.mif 0 -if - | mrthreshold - - -top ' + str(voxcsfcount) + ' -ignorezero | mrcalc refined_csf.mif - 0 -if - -datatype bit | mrconvert - voxels_csf.mif -axes 0,1,2', show=False)
  statvoxcsfcount = image.statistic('voxels_csf.mif', 'count', '-mask voxels_csf.mif')
  app.console('   [ CSF: ' + str(statrefcsfcount) + ' -> ' + str(statvoxcsfcount) + ' ]')
  # Estimate CSF response function
  app.console(' * Estimating response function...')
  run.command('amp2response dwi.mif voxels_csf.mif safe_vecs.mif response_csf.txt' + bvalues_option + ' -isotropic', show=False)

  # Get final voxels for GM response function estimation from refined GM:
  # the requested percentage of refined GM voxels whose SDM is closest to the
  # refined-GM median. The '1 -add' keeps voxels exactly at the median non-zero
  # so '-ignorezero' only discards voxels outside refined GM.
  app.console('* GM:')
  app.console(' * Selecting final voxels (' + str(app.ARGS.gm) + '% of refined GM)...')
  voxgmcount = int(round(statrefgmcount * app.ARGS.gm / 100.0))
  refgmmedian = image.statistic('safe_sdm.mif', 'median', '-mask refined_gm.mif')
  run.command('mrcalc refined_gm.mif safe_sdm.mif ' + str(refgmmedian) + ' -subtract -abs 1 -add 0 -if - | mrthreshold - - -bottom ' + str(voxgmcount) + ' -ignorezero | mrcalc refined_gm.mif - 0 -if - -datatype bit | mrconvert - voxels_gm.mif -axes 0,1,2', show=False)
  statvoxgmcount = image.statistic('voxels_gm.mif', 'count', '-mask voxels_gm.mif')
  app.console('   [ GM: ' + str(statrefgmcount) + ' -> ' + str(statvoxgmcount) + ' ]')
  # Estimate GM response function
  app.console(' * Estimating response function...')
  run.command('amp2response dwi.mif voxels_gm.mif safe_vecs.mif response_gm.txt' + bvalues_option + ' -isotropic', show=False)

  # Get final voxels for single-fibre WM response function estimation from WM using TOURNIER algorithm.
  app.console('* single-fibre WM:')
  app.console(' * Selecting final voxels (' + str(app.ARGS.sfwm) + '% of refined WM)...')
  voxsfwmcount = int(round(statrefwmcount * app.ARGS.sfwm / 100.0))
  app.console('   Running TOURNIER algorithm to select ' + str(voxsfwmcount) + ' single-fibre WM voxels.')
  # Propagate the user's -nocleanup choice to the nested dwi2response invocation.
  cleanopt = '' if app.DO_CLEANUP else ' -nocleanup'
  run.command('dwi2response tournier dwi.mif _respsfwmss.txt -sf_voxels ' + str(voxsfwmcount) + ' -iter_voxels ' + str(voxsfwmcount * 10) + ' -mask refined_wm.mif -voxels voxels_sfwm.mif -scratch ' + path.quote(app.SCRATCH_DIR) + cleanopt, show=False)
  statvoxsfwmcount = image.statistic('voxels_sfwm.mif', 'count', '-mask voxels_sfwm.mif')
  app.console('   [ WM: ' + str(statrefwmcount) + ' -> ' + str(statvoxsfwmcount) + ' (single-fibre by TOURNIER algorithm) ]')
  # Estimate SF WM response function (only the voxel selection from TOURNIER is used).
  app.console(' * Estimating response function...')
  run.command('amp2response dwi.mif voxels_sfwm.mif safe_vecs.mif response_sfwm.txt' + bvalues_option + sfwm_lmax_option, show=False)


  # OUTPUT AND SUMMARY
  app.console('-------')
  app.console('Generating outputs...')

  # Generate 4D binary images with voxel selections at major stages in algorithm (RGB: WM=blue, GM=green, CSF=red).
  run.command('mrcat crude_csf.mif crude_gm.mif crude_wm.mif check_crude.mif -axis 3', show=False)
  run.command('mrcat refined_csf.mif refined_gm.mif refined_wm.mif check_refined.mif -axis 3', show=False)
  run.command('mrcat voxels_csf.mif voxels_gm.mif voxels_sfwm.mif check_voxels.mif -axis 3', show=False)

  # Save results to output files
  bvalhdr = { 'b-values' : ','.join(map(str,bvalues)) }
  matrix.save_matrix(path.from_user(app.ARGS.out_sfwm, False), matrix.load_matrix('response_sfwm.txt'), header=bvalhdr, fmt='%.15g', footer={})
  matrix.save_matrix(path.from_user(app.ARGS.out_gm, False), matrix.load_matrix('response_gm.txt'), header=bvalhdr, fmt='%.15g', footer={})
  matrix.save_matrix(path.from_user(app.ARGS.out_csf, False), matrix.load_matrix('response_csf.txt'), header=bvalhdr, fmt='%.15g', footer={})
  if app.ARGS.voxels:
    # The output path goes into the command string, so it is escaped. mrconvert_keyval
    # takes a raw filesystem path, so it must NOT be shell-quoted (escape=False).
    run.command('mrconvert check_voxels.mif ' + path.from_user(app.ARGS.voxels), mrconvert_keyval=path.from_user(app.ARGS.input, False), force=app.FORCE_OVERWRITE, show=False)
  app.console('-------')